/* 
 * solver.c 
 *
 * Uses a parallelized 4th-order Runge-Kutta method to solve a system
 *    of ODEs modelling a network of neurons.
 */

#include "solver.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "mpi.h"
#include "solver.h"
#include "equations.h"
#include "output.h"
#include "utility.h"
#include "comm.h"
#include "prof.h"
#include "globals.h"

/* #define DEBUG */


/*===================================================================*/
/*
 * Runge_kutta
 *
 * Advance this process's share of the ODE system through num_steps
 * steps of size step_size (h) with the classical 4th-order
 * Runge-Kutta scheme
 *
 *     y_(i+1) = y_i + h/6 * (k_1 + 2 k_2 + 2 k_3 + k_4)
 *
 * Each process integrates num_my_equations state variables (taken
 * from y_init starting at my_first_var) and num_my_neurons synaptic
 * outputs (taken from syn_init starting at my_first_neuron).  After
 * each of the first three stages the local intermediate synaptic
 * values are combined into the total_neurons-sized temp_syn with
 * ALL_GATHER, and at the end of every step the new synaptic values
 * are combined into syn_init the same way.
 * NOTE(review): ALL_GATHER is not defined in this file; it is presumably
 * an MPI_Allgather-style macro from comm.h -- confirm.
 *
 * Output: the state is handed to Copy_to_output_buffer once before
 * the first step (stamped time_init) and then after every step that
 * is a multiple of print_freq (stamped time_init + step*step_size).
 * A print_freq of 0 disables the periodic output.
 *
 * Parameters:
 *   y_init    Initial state.  Read at offset my_first_var.  When
 *             DIST_OUTPUT is not defined it is also the MPI_Gather
 *             receive buffer for root 0 of comm, so the caller's
 *             contents are overwritten there at every print step.
 *   syn_init  Initial synaptic outputs.  Used directly as the
 *             "previous step" synaptic array (no copy is made) and
 *             passed as the receive buffer of the end-of-step
 *             ALL_GATHER, so the caller's array is modified.
 *
 * Besides the parameters, the function reads these names that are
 * not declared in this file: num_my_equations, num_my_neurons,
 * total_neurons, total_equations (DEBUG only), my_first_var,
 * my_first_neuron, time_init, step_size, num_steps, print_freq,
 * my_rank and comm.
 */
void Runge_kutta(
         double        y_init[]          /* IN/OUT */,
         double        syn_init[]        /* IN/OUT */) {

#   ifdef DEBUG
    char message[80];
#   endif

    double *y_i,           /* local for previous timestep   */
           *y_i_plus_1,    /* local for current timestep    */
           *local_k,       /* local auxiliary for RK method */
           *my_temp_y,     /* local auxiliary for RK method */
           *syn_i,         /* global syn output prev step   */
           *my_syn_plus_1, /* local syn output curr step    */
           *local_syn_k,   /* local aux synaptic output     */
           *temp_syn;      /* global aux synaptic output    */

    int    m;
    int    step;
    double t;
    double step_size_over_6;
    double step_size_over_2;
    double step_size_over_3;

#ifdef PROF
    Setup_profiling();
#endif

    /* Work arrays for this process's slice of the state vector */
    y_i           = malloc(sizeof *y_i * num_my_equations);
    Check_malloc(y_i, "y_i");
    my_temp_y     = malloc(sizeof *my_temp_y * num_my_equations);
    Check_malloc(my_temp_y, "my_temp_y");
    y_i_plus_1    = malloc(sizeof *y_i_plus_1 * num_my_equations);
    Check_malloc(y_i_plus_1, "y_i_plus_1");
    local_k       = malloc(sizeof *local_k * num_my_equations);
    Check_malloc(local_k, "local_k");

    /* These allocations assume synaptic output is a scalar.       */
    /* syn_i is not allocated: it aliases the caller's syn_init.   */
    my_syn_plus_1 = malloc(sizeof *my_syn_plus_1 * num_my_neurons);
    Check_malloc(my_syn_plus_1, "my_syn_plus_1");
    local_syn_k   = malloc(sizeof *local_syn_k * num_my_neurons);
    Check_malloc(local_syn_k, "local_syn_k");
    temp_syn      = malloc(sizeof *temp_syn * total_neurons);
    Check_malloc(temp_syn, "temp_syn");


    /* Copy this process's initial conditions into y_i */
    memcpy(y_i, y_init + my_first_var,
           num_my_equations*sizeof(double) );
    /* No copy for the synaptic values: work directly in syn_init */
    syn_i = syn_init;

    t = time_init;

    Setup_f();

    /* Stage weights, hoisted out of the time loop */
    step_size_over_6 = step_size/6.0;
    step_size_over_2 = step_size/2.0;
    step_size_over_3 = step_size/3.0;

    /* Record the initial state, stamped with time_init */
#   ifdef PROF
    Start_output_prof();
#   endif
#   ifdef DIST_OUTPUT
    Copy_to_output_buffer(t, y_i);
#   else
    if (my_rank == 0) Copy_to_output_buffer(t, y_init);
#   endif  /* DIST_OUTPUT */
#   ifdef PROF
    Finish_output_prof();
#   endif

    /* Initialize y_i_plus_1, my_syn_plus_1 */
    memcpy(y_i_plus_1, y_i, num_my_equations*sizeof(double));
    memcpy(my_syn_plus_1, syn_i + my_first_neuron,
        num_my_neurons*sizeof(double));

    /*==============================================================*/
    for (step = 1; step <= num_steps; step++) {

        /* Time at the end of this step.  The offset time_init keeps
         * these stamps consistent with the initial record above.
         */
        t = time_init + step*step_size;

#       ifdef PROF
        Start_calc_prof();
#       endif

        /* My part of y_i_plus_1 and my_syn_plus_1 has values of y_i 
         * from last time step, so each stage only has to add its
         * weighted k to them.
         */

        /* Compute k_1 = f(t, y_i) */
        Compute_f(y_i, syn_i, local_k, local_syn_k);

#       ifdef DEBUG
        sprintf(message, "k_1 at step %d =", step);
        Print_global_vector(message, local_k, num_my_equations,
                            total_equations, my_rank);
        sprintf(message, "synaptic k_1 at step %d =", step);
        Print_global_vector(message, local_syn_k, num_my_neurons,
                            total_neurons, my_rank);
#       endif

        for (m = 0; m < num_my_equations; m++) {
            /* Add k_1*h/6 to y_(i+1) */
            y_i_plus_1[m] += step_size_over_6*local_k[m];

            /* Add k_1*h/2 to y_i for computing k_2 */
            my_temp_y[m] = y_i[m] + step_size_over_2*local_k[m];
        }
        for (m = 0; m < num_my_neurons; m++) {
            my_syn_plus_1[m] += step_size_over_6*local_syn_k[m];
            /* local_syn_k is reused in place: from here on it holds
             * the intermediate synaptic value, not the slope.
             */
            local_syn_k[m] = syn_i[m + my_first_neuron] + 
                             step_size_over_2*local_syn_k[m];
        }

#       ifdef PROF
        Finish_calc_prof();
        Start_all_gather_prof();
#       endif
        /* Collect the intermediate synaptic values into temp_syn */
        ALL_GATHER(local_syn_k, num_my_neurons, MPI_DOUBLE,
            temp_syn, num_my_neurons, MPI_DOUBLE, comm);
#       ifdef PROF
        Finish_all_gather_prof();
        Start_calc_prof();
#       endif

        /* Compute k_2 = f(t + h/2, y_i + k_1*h/2) */
        Compute_f(my_temp_y, temp_syn, local_k, local_syn_k); 

#       ifdef DEBUG
        sprintf(message, "k_2 at step %d =", step);
        Print_global_vector(message, local_k, num_my_equations,
                            total_equations, my_rank);
        sprintf(message, "synaptic k_2 at step %d =", step);
        Print_global_vector(message, local_syn_k, num_my_neurons,
                            total_neurons, my_rank);
#       endif

        for (m = 0; m < num_my_equations; m++) {
            /* Add k_2*h/3 to y_(i+1) */
            y_i_plus_1[m] += step_size_over_3*local_k[m];

            /* Add k_2*h/2 to y_i for computing k_3 */
            my_temp_y[m] = y_i[m] + step_size_over_2*local_k[m];
        }
        for (m = 0; m < num_my_neurons; m++) {
            my_syn_plus_1[m] += step_size_over_3*local_syn_k[m];
            local_syn_k[m] = syn_i[m + my_first_neuron] + 
                             step_size_over_2*local_syn_k[m];
        }

#       ifdef PROF
        Finish_calc_prof();
        Start_all_gather_prof();
#       endif
        ALL_GATHER(local_syn_k, num_my_neurons, MPI_DOUBLE,
            temp_syn, num_my_neurons, MPI_DOUBLE, comm);
#       ifdef PROF
        Finish_all_gather_prof();
        Start_calc_prof();
#       endif

        /* Compute k_3 = f(t + h/2, y_i + k_2*h/2) */
        Compute_f(my_temp_y, temp_syn, local_k, local_syn_k);    

#       ifdef DEBUG
        sprintf(message, "k_3 at step %d =", step);
        Print_global_vector(message, local_k, num_my_equations,
                            total_equations, my_rank);
        sprintf(message, "synaptic k_3 at step %d =", step);
        Print_global_vector(message, local_syn_k, num_my_neurons,
                            total_neurons, my_rank);
#       endif

        for (m = 0; m < num_my_equations; m++) {
            /* Add k_3*h/3 to y_(i+1) */
            y_i_plus_1[m] += step_size_over_3*local_k[m];

            /* Add k_3*h to y_i for computing k_4 */
            my_temp_y[m] = y_i[m] + step_size*local_k[m];
        }
        for (m = 0; m < num_my_neurons; m++) {
            my_syn_plus_1[m] += step_size_over_3*local_syn_k[m];
            local_syn_k[m] = syn_i[m + my_first_neuron] + 
                             step_size*local_syn_k[m];
        }

#       ifdef PROF
        Finish_calc_prof();
        Start_all_gather_prof();
#       endif
        ALL_GATHER(local_syn_k, num_my_neurons, MPI_DOUBLE,
            temp_syn, num_my_neurons, MPI_DOUBLE, comm);
#       ifdef PROF
        Finish_all_gather_prof();
        Start_calc_prof();
#       endif

        /* Compute k_4 = f(t + h, y_i + k_3*h) */
        Compute_f(my_temp_y, temp_syn, local_k, local_syn_k);    

#       ifdef DEBUG
        sprintf(message, "k_4 at step %d =", step);
        Print_global_vector(message, local_k, num_my_equations,
                            total_equations, my_rank);
        sprintf(message, "synaptic k_4 at step %d =", step);
        Print_global_vector(message, local_syn_k, num_my_neurons,
                            total_neurons, my_rank);
#       endif

        for (m = 0; m < num_my_equations; m++) {
            /* Add k_4*h/6 to y_(i+1) */
            y_i_plus_1[m] += step_size_over_6*local_k[m];
        }
        for (m = 0; m < num_my_neurons; m++) {
            my_syn_plus_1[m] += step_size_over_6*local_syn_k[m];
        }

#       ifdef PROF
        Finish_calc_prof();
#       endif

        /* Print data when step is a multiple of print_freq.  The
         * print_freq != 0 test avoids a modulo by zero and lets a
         * print_freq of 0 mean "no periodic output".
         */
        if (print_freq != 0 && (step % print_freq) == 0) {
#           ifdef DIST_OUTPUT
#               ifdef PROF
                Start_output_prof();
#               endif  
                Copy_to_output_buffer(t, y_i_plus_1);
#               ifdef PROF
                Finish_output_prof();
#               endif 
#           else  /* not DIST_OUTPUT */
                /* Process 0 needs all of y_i_plus_1 for output.
                 * y_init is reused as the receive buffer, so its
                 * original contents are lost on process 0.
                 */
#               ifdef PROF
                Start_gather_prof();
#               endif
                MPI_Gather(y_i_plus_1, num_my_equations, MPI_DOUBLE, 
                    y_init, num_my_equations, MPI_DOUBLE, 0, comm);
#               ifdef PROF
                Finish_gather_prof();
#               endif
                if (my_rank == 0) {
#                   ifdef PROF
                    Start_output_prof();
#                   endif
                    Copy_to_output_buffer(t, y_init);
#                   ifdef PROF
                    Finish_output_prof();
#                   endif
                }  /* my_rank == 0 */
#           endif  /* DIST_OUTPUT */
        }  /* Print data */


#       ifdef PROF
        Start_calc_prof();
#       endif
        /* y_i = y_i_plus_1 for next time step */
        memcpy(y_i, y_i_plus_1, num_my_equations*sizeof(double));
#       ifdef PROF
        Finish_calc_prof();
#       endif

#       ifdef PROF
        Start_all_gather_prof();
#       endif
#       ifdef DEBUG
        sprintf(message,"my_syn_plus_1 before final Allgather, step %d",
            step);
        Print_local_vector(message, my_syn_plus_1, num_my_neurons, 
            my_rank);
        fflush(stdout);
#       endif  /* DEBUG */
        /* Gather every process's new synaptic values into syn_i
         * (the caller's syn_init) for use as the previous-step
         * values in the next iteration.
         */
        ALL_GATHER(my_syn_plus_1, num_my_neurons, MPI_DOUBLE,
            syn_i, num_my_neurons, MPI_DOUBLE, comm);
#       ifdef PROF
        Finish_all_gather_prof();
#       endif
#       ifdef DEBUG
        sprintf(message, "syn_i after final Allgather, step %d", 
            step);
        Print_local_vector(message, syn_i, total_neurons, 
            my_rank);
        sprintf(message, "Global y_i_plus_1, step %d =", step);
        Print_global_vector(message, y_i_plus_1, num_my_equations,
                            total_equations, my_rank);
#       endif  /* DEBUG */

    }  /* for step . . . */
    /*==============================================================*/


    /* syn_i is not freed: it points at the caller's syn_init */
    free(y_i);
    free(my_temp_y);
    free(y_i_plus_1);
    free(local_k);
    free(my_syn_plus_1);
    free(local_syn_k);
    free(temp_syn);

    Cleanup_f();
#   ifdef PROF
    Finish_profiling();
#   endif
}  /* Runge_kutta */
