/* 
 * equations.c
 *
 * Functions defining the equations modelling a neuron.
 */

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include "mpi.h"
#include "equations.h"
#include "utility.h"
#include "nscanf.h"
#include "read_constants.h"
#include "globals.h"


/* List all constants defining your equations that you want to read in */
/* These are tentative definitions with no initializer; nothing in this
 * file assigns them, so their values must be supplied from elsewhere
 * before Setup_f or Compute_f run.
 */
double 
    SIGMA,         /* Appears only in the bracketed (disabled) noise term
                    * of the V' comments; not read by any code in this file */
    G_NA,          /* Coefficient of the h*m^3*(V_NA - V) term of V'        */
    G_K,           /* Coefficient of the n^4*(V_K - V) term of V'           */
    G_L,           /* Coefficient of the (V_L - V) term of V'               */
    V_NA,          /* Voltage offset in the G_NA term of V'                 */
    V_K,           /* Voltage offset in the G_K term of V'                  */
    V_L,           /* Voltage offset in the G_L term of V'                  */
    V_IE,          /* Voltage offset in the g_ie term (Compute_V_prime_E)   */
    V_EE,          /* Not read by any code in this file                     */
    V_II,          /* Voltage offset in the g_ii term (Compute_V_prime_I)   */
    V_EI,          /* Voltage offset in the g_ei term (Compute_V_prime_I)   */
    G_EI,          /* Divided by the E cell count in Setup_f to give g_ei   */
    G_IE,          /* Divided by the I cell count in Setup_f to give g_ie   */
    G_II,          /* Divided by the I cell count in Setup_f to give g_ii   */
    AE0,           /* Applied E voltage min                          */
    AEF,           /* Applied E voltage max                          */
    AI,            /* Applied I voltage                              */
    SIGMOID_THETA, /* Used in calc of applied voltages.  See Setup_f */
    SIGMOID_POWER; /* Ditto                                          */

/* How many different neuron models are you using? */
/* Two here: type code E (0) and type code I (1), matching models[] below */
int MODEL_COUNT = 2;

/* How many equations are used to model each type of neuron? */
/* Don't count synaptic current equations                    */
/* Indexed by neuron type code.  Both models use four: V, h, n and m */
int EQUATION_COUNTS[] = {4,4};

/* Which equation number (0,1,2,...) is the membrane voltage? */
/* Indexed by neuron type code.  V is the first equation in both models */
int VOLTAGE_EQN_OFFSETS[] = {0,0};

/********************************************************************/
/******** Declarations, definitions, etc beyond this point **********/
/******** may not be needed for your models                **********/
/********************************************************************/

/* Pointers to functions that call equations modelling the different */
/* types.  Probably not needed if you have only one model            */
/* Indexed by neuron type code (see the MODEL_FCN macro), so the order
 * of the entries must match the E and I codes defined just below.
 */
model_fp models[2] = {Model_E, Model_I};

/* Neuron type codes: the values stored in neuron_types[] and the
 * indices into models[] and EQUATION_COUNTS[].
 */
#define E 0
#define I 1

/* Applied voltages.  Only contains voltages for *local* neurons. 
 * See Setup_f 
 * Allocated in Setup_f, indexed by local neuron rank, released in
 * Cleanup_f.
 */
double* applied_voltages;

/* These are set in the call to Setup_f */
int     e_cell_count;      /* Number of neurons whose type is not I */
int     i_cell_count;      /* Number of neurons whose type is I     */
double  g_ei, g_ie, g_ii;  /* Synaptic conductances scaled by the
                            * number of cells of each type 
                            */
double  sqrt_step_size;    /* sqrt(step_size).  Only the disabled noise
                            * term in the V' comments refers to it    */


/*===================================================================*/
/* Macros for simplifying array element and struct member access */
/* You shouldn't need to modify these                            */

/* Number of equations (synaptic equation not counted) used by the
 * neuron with the given *global* rank */
#define EQN_COUNT(global_neuron_rank) \
            (EQUATION_COUNTS[neuron_types[global_neuron_rank]])

/* The model function (Model_E or Model_I) for the neuron with the
 * given *global* rank.  Expands to a dereferenced function pointer,
 * so it is followed directly by the argument list */
#define MODEL_FCN(global_neuron_rank) \
            (*(models[neuron_types[global_neuron_rank]]))

/* Get the type (E or I) of the jth neuron sending synaptic input to
 * the neuron with local rank local_neuron_rank */
#define INPUT_NEURON_TYPE(local_neuron_rank, j) \
            (neuron_types[syn_connections[local_neuron_rank].links[j]])

/* Get the rank of the jth neuron sending synaptic input to the neuron
 * with local rank local_neuron_rank.  The value is used elsewhere in
 * this file as an index into neuron_types[] and syn_input[] */ 
#define INPUT_NEURON_RANK(local_neuron_rank, j) \
            (syn_connections[local_neuron_rank].links[j])

/* Get the sign of the jth synaptic input to the neuron with local rank
 * local_neuron_rank.  Currently referenced only from commented-out
 * code in the V_prime functions */
#define INPUT_NEURON_SIGN(local_neuron_rank, j) \
            (syn_connections[local_neuron_rank].signs[j])

/* Get the synaptic variable of the jth neuron sending input to the
 * neuron with local rank local_neuron_rank.  The expansion names the
 * identifier syn_input directly, so it only works where an array of
 * that name is in scope (here: the V_prime functions, which take it as
 * a parameter).  For that reason it can't be used in other files */
#define SYNAPTIC_INPUT(local_neuron_rank, j) \
            (syn_input[syn_connections[local_neuron_rank].links[j]])

/*===================================================================*/
/* Prepare the file-level state used by the model functions:
 *   - sqrt_step_size
 *   - e_cell_count / i_cell_count (counted over all total_neurons)
 *   - g_ei, g_ie, g_ii (G_EI, G_IE, G_II divided by the cell counts)
 *   - applied_voltages[] for the neurons my_first_neuron..my_last_neuron,
 *     indexed by local rank
 * Call Cleanup_f to release applied_voltages.
 */
void Setup_f(void) {
    int i;
    int global_neuron_rank;
    int local_neuron_rank;
    double ae_med, aa;

    sqrt_step_size = sqrt(step_size);

    /* Any type code other than I is counted as an E cell */
    e_cell_count = i_cell_count = 0;
    for (i = 0; i < total_neurons; i++) {
        if (neuron_types[i] == I)
            i_cell_count++;
        else 
            e_cell_count++;
    }

    /* Scale the conductances by the size of the presynaptic population.
     * If a population is empty there are no such synapses, so use 0.0:
     * dividing by a zero count would give inf (or NaN), and inf times a
     * zero synaptic total is NaN, which would poison every V'.
     */
    g_ei = (e_cell_count > 0) ? G_EI/e_cell_count : 0.0;
    g_ie = (i_cell_count > 0) ? G_IE/i_cell_count : 0.0;
    g_ii = (i_cell_count > 0) ? G_II/i_cell_count : 0.0;

    /* This is the "sigmoid" version for E */
    applied_voltages = malloc(num_my_neurons*sizeof *applied_voltages);
    Check_malloc(applied_voltages, "applied_voltages");

    /* ae_med is the midpoint of AE0 and AEF, so the ratio below is 1
     * up to rounding whenever AE0 != AEF.  When the two coincide the
     * ratio is 0/0; use 1.0 so that every E cell simply receives AEF
     * instead of NaN.
     */
    ae_med = 0.5*(AE0 + AEF);
    if (ae_med - AEF == 0.0)
        aa = 1.0;
    else
        aa = (AE0 - ae_med)/(ae_med - AEF);

    /* Start with my first neuron */
    for (global_neuron_rank = my_first_neuron; 
         global_neuron_rank <= my_last_neuron; 
         global_neuron_rank++) {
        local_neuron_rank = global_neuron_rank - my_first_neuron;
        if (neuron_types[global_neuron_rank] == I) 
            applied_voltages[local_neuron_rank] = AI;
        else 
            /* This branch implies at least one E cell was counted
             * above, so e_cell_count is nonzero here.
             */
            applied_voltages[local_neuron_rank] = AEF + (AE0-AEF)/
                (1 + aa*pow((global_neuron_rank+1)*SIGMOID_THETA/e_cell_count,
                                SIGMOID_POWER));
    }
    
}  /* Setup_f */


/*===================================================================*/
/* Release the applied_voltages array allocated by Setup_f. */
void Cleanup_f(void) {

    free(applied_voltages);
    /* Don't leave a dangling global behind: with the pointer reset, a
     * second Cleanup_f call is a harmless free(NULL).
     */
    applied_voltages = NULL;
}  /* Cleanup_f */


/*===================================================================*/
/* This function will compute f(t, y).
 * This is a completely local function at this point.
 * syn_input is global.  other arrays are local.
 */
void Compute_f(
         double      y[]                /* IN  */,
         double      syn_input[]        /* IN  */,
         double      answer[]           /* OUT */,
         double      syn_answer[]       /* OUT */) {

    /* Walk over this process's neurons (global ranks my_first_neuron
     * through my_last_neuron) and let each neuron's model function fill
     * in its slice of answer[] and its entry of syn_answer[].
     * first_eqn is the index in y[]/answer[] of the neuron's first
     * equation; it advances by the neuron's equation count each pass.
     */
    int first_eqn = 0;
    int g_rank    = my_first_neuron;
    int l_rank;

#   ifdef DEBUG
    printf("Process %d > In Compute_f\n", my_rank);
    printf("Process %d > my_first_neuron = %d, my_last_neuron = %d\n",
            my_rank, my_first_neuron, my_last_neuron);
#   endif

    while (g_rank <= my_last_neuron) {
        l_rank = g_rank - my_first_neuron;

        /* Dispatch on the neuron's type */
        MODEL_FCN(g_rank) (y, answer, syn_input, syn_answer,
                           first_eqn, l_rank, g_rank);

#       ifdef DEBUG
        printf("Process %d > i = %d, Finished call to Model_%d\n",
            my_rank, first_eqn, neuron_types[g_rank]);
        printf("Process %d > global_neuron_rank = %d, local_neuron_rank = %d\n",
            my_rank, g_rank, l_rank);
        fflush(stdout);
#       endif

        first_eqn += EQN_COUNT(g_rank);
        g_rank++;
    }
}  /* Compute_f */


/*===================================================================*/
/* Note:  neuron is the *local* neuron rank */
void Model_E (
         double      y[]                /* IN  */,
         double      answer[]           /* OUT */, 
         double      syn_input[]        /* IN  */,
         double      syn_answer[]       /* OUT */,
         int         i                  /* IN  */,
         int         neuron             /* IN  */,
         int         global_neuron_rank /* IN  */) {

    /* Evaluate every right-hand side of the E model for one neuron.
     * i is the index of the neuron's V equation; h, n and m follow it.
     * The synaptic routine is handed i+4 and reaches back four slots
     * for V; its result goes to syn_answer[], not answer[].
     */
    const int h_eqn   = i + 1;
    const int n_eqn   = i + 2;
    const int m_eqn   = i + 3;
    const int s_index = i + 4;

    Compute_V_prime_E(y, answer, syn_input, i, neuron);
    Compute_h_prime_E(y, answer, h_eqn);
    Compute_n_prime_E(y, answer, n_eqn);
    Compute_m_prime_E(y, answer, m_eqn);
    Compute_s_prime_E(y, syn_input, syn_answer, s_index, neuron, 
                      global_neuron_rank);
}  /* Model_E */


/*===================================================================*/
/* Note:  neuron is the *local* neuron rank */
void Compute_V_prime_E(
         double       y[]                /* IN  */,
         double       answer[]           /* OUT */,
         double       syn_input[]        /* IN  */,
         int          i                  /* IN  */,
         int          neuron             /* IN  */) {

    /* Stores into answer[i]:
     *
     * V' = G_NA*h*m^3*(V_NA - V) + 
     *      G_K   *n^4*(V_K  - V) +
     *      G_L       *(V_L  - V) +
     *      E_appl                +
     *      g_ie*syn_total_from_I*(V_IE - V)
     *
     * with V = y[i], h = y[i+1], n = y[i+2], m = y[i+3] and
     * E_appl = applied_voltages[neuron].  The noise term
     * [SIGMA*step_size^0.5*Gasdev(??)] is not included.
     */
    const int    input_count = (syn_connections[neuron]).dimension;
    const double volt   = y[i];
    const double h_gate = y[i+1];
    const double n_gate = y[i+2];
    const double m_gate = y[i+3];
    double from_I = 0.0;   /* sum of synaptic input from non-E cells */
    double from_E = 0.0;   /* sum from E cells; reported in DEBUG only */
    int    k;

#   ifdef DEBUG
    printf("Process %d > In Compute_V_prime, i = %d, neuron = %d\n",
        my_rank, i, neuron);
    printf("Process %d > dim = %d\n", my_rank, input_count);
    fflush(stdout);
#   endif

    /* Split the incoming synaptic variables by the sender's type.  The
     * per-link sign (INPUT_NEURON_SIGN) is not applied.
     */
    for (k = 0; k < input_count; k++) {
        const double s_in = SYNAPTIC_INPUT(neuron,k);

        if (INPUT_NEURON_TYPE(neuron,k) != E)
            from_I += s_in;
        else
            from_E += s_in;
    }

#   ifdef DEBUG
    {
        double i_na, i_k, i_l, i_s;

        i_na = G_NA * h_gate * m_gate * m_gate * m_gate * (V_NA - volt);
        i_k  = G_K  * n_gate * n_gate * n_gate * n_gate * (V_K  - volt);
        i_l  = G_L  * (V_L  - volt);
        i_s  = g_ie * from_I * (V_IE - volt);
        printf("Process %d > i = %d, s_tot_E = %f, s_tot_I = %f\n",
            my_rank, i, from_E, from_I);
        printf("Process %d > i = %d, i_na = %f, i_k = %f, i_l = %f\n",
            my_rank, i, i_na, i_k, i_l);
        printf("Process %d > i = %d, appl_volt = %f, i_s = %f\n",
            my_rank, i, applied_voltages[neuron], i_s);
        fflush(stdout);
    }
#   endif

    answer[i] = G_NA * h_gate * m_gate * m_gate * m_gate * (V_NA - volt)
              + G_K  * n_gate * n_gate * n_gate * n_gate * (V_K  - volt)
              + G_L  * (V_L  - volt)
              + applied_voltages[neuron]
              + g_ie * from_I * (V_IE - volt);

#   ifdef DEBUG
    printf("Process %d > i = %d, v_prime = %f\n", 
           my_rank, i, answer[i]);
    fflush(stdout);
#   endif

}  /* Compute_V_prime_E */


/*===================================================================*/
void Compute_h_prime_E(
         double   y[]           /* IN  */,
         double   answer[]      /* OUT */,
         int      i             /* IN  */) {

    /* h' = ah(V)*(1 - h) - bh(V)*h
     * y[i-1] = V(t), 
     * y[i]   = h(t), 
     * y[i+1] = n(t), 
     * y[i+2] = m(t), 
     */

#   define ah(v) (0.128*exp(-(v+50.0)/18.0))
#   define bh(v) (4.0/(1.0+exp(-(v+27.0)/5.0)))

    answer[i] = ah(y[i-1])*(1.0 - y[i]) - bh(y[i-1])*y[i];

}  /* Compute_h_prime_E */


/*===================================================================*/
void Compute_n_prime_E(
         double   y[]           /* IN  */,
         double   answer[]      /* OUT */,
         int      i             /* IN  */) {

    /* n' = an(V)*(1 - n) - bn(V)*n
     * y[i-2] = V(t), 
     * y[i-1] = h(t), 
     * y[i]   = n(t), 
     * y[i+1] = n(t), 
     */

#   define an(v) (0.032*(v+52.0)/(1.0-exp(-(v+52.0)/5.0)))
#   define bn(v) (0.5*exp(-(v+57.0)/40.0))

    answer[i] = an(y[i-2])*(1 - y[i]) - bn(y[i-2])*y[i];

}  /* Compute_n_prime_E */


/*===================================================================*/
void Compute_m_prime_E(
         double   y[]           /* IN  */,
         double   answer[]      /* OUT */,
         int      i             /* IN  */) {

    /* m' = am(V)*(1 - m) - bm(V)*m
     * y[i-3] = V(t), 
     * y[i-2] = h(t), 
     * y[i-1] = n(t), 
     * y[i]   = m(t), 
     */

#   define am(v) (0.32*(v+54.0)/(1.0-exp(-(v+54.0)/4.0)))
#   define bm(v) (0.28*(v+27.0)/(exp((v+27.0)/5.0)-1.0))

    answer[i] = am(y[i-3])*(1 - y[i]) - bm(y[i-3])*y[i];

}  /* Compute_m_prime_E */


/*===================================================================*/
void Compute_s_prime_E(
         double   y[]                 /* IN  */,
         double   syn_input[]         /* IN  */,
         double   syn_answer[]        /* OUT */,
         int      i                   /* IN  */,
         int      neuron              /* IN  */,
         int      global_neuron_rank  /* IN  */) {

    /* s' = kE(V)*(1 - s) - s/2
     * y[i-4] = V(t), 
     * y[i-3] = h(t), 
     * y[i-2] = n(t), 
     * y[i-1] = m(t), 
     */

#   define kE(v)  (5.0*(1.0+tanh(v/4.0)))

    syn_answer[neuron] = 
        kE(y[i-4])*(1 - syn_input[global_neuron_rank]) - 
               syn_input[global_neuron_rank]/2;

}  /* Compute_s_prime_E */


/*===================================================================*/
/* Note:  neuron is the *local* neuron rank */
void Model_I (
         double      y[]                /* IN  */,
         double      answer[]           /* OUT */, 
         double      syn_input[]        /* IN  */,
         double      syn_answer[]       /* OUT */,
         int         i                  /* IN  */,
         int         neuron             /* IN  */,
         int         global_neuron_rank /* IN  */) {

    /* Evaluate every right-hand side of the I model for one neuron.
     * i is the index of the neuron's V equation; h, n and m follow it.
     * The synaptic routine is handed i+4 and reaches back four slots
     * for V; its result goes to syn_answer[], not answer[].
     */
    const int h_eqn   = i + 1;
    const int n_eqn   = i + 2;
    const int m_eqn   = i + 3;
    const int s_index = i + 4;

    Compute_V_prime_I(y, answer, syn_input, i, neuron);
    Compute_h_prime_I(y, answer, h_eqn);
    Compute_n_prime_I(y, answer, n_eqn);
    Compute_m_prime_I(y, answer, m_eqn);
    Compute_s_prime_I(y, syn_input, syn_answer, s_index, neuron,
                      global_neuron_rank);
}  /* Model_I */


/*===================================================================*/
/* Note:  neuron is the *local* neuron rank */
void Compute_V_prime_I(
         double       y[]                /* IN  */,
         double       answer[]           /* OUT */,
         double       syn_input[]        /* IN  */,
         int          i                  /* IN  */,
         int          neuron             /* IN  */) {

    /* Stores into answer[i]:
     *
     * V' = G_NA*h*m^3*(V_NA - V) + 
     *      G_K   *n^4*(V_K  - V) +
     *      G_L       *(V_L  - V) +
     *      I_appl +
     *      g_ii*stot_I*(V_II - V) +
     *      g_ei*stot_E*(V_EI - V)
     *
     * with V = y[i], h = y[i+1], n = y[i+2], m = y[i+3] and
     * I_appl = applied_voltages[neuron].  The noise term
     * [- SIGMA*sqrt_step_size*gasdev(?)] is not included.
     */
    const int    input_count = (syn_connections[neuron]).dimension;
    const double volt   = y[i];
    const double h_gate = y[i+1];
    const double n_gate = y[i+2];
    const double m_gate = y[i+3];
    double from_I = 0.0;   /* sum of synaptic input from non-E cells */
    double from_E = 0.0;   /* sum of synaptic input from E cells     */
    int    k;

#   ifdef DEBUG
    printf("Process %d > In Compute_V_prime, i = %d, neuron = %d\n",
        my_rank, i, neuron);
    printf("Process %d > dim = %d\n", my_rank, input_count);
    fflush(stdout);
#   endif

    /* Split the incoming synaptic variables by the sender's type.  The
     * per-link sign (INPUT_NEURON_SIGN) is not applied.
     */
    for (k = 0; k < input_count; k++) {
        const double s_in = SYNAPTIC_INPUT(neuron,k);

        if (INPUT_NEURON_TYPE(neuron,k) != E)
            from_I += s_in;
        else
            from_E += s_in;
    }

#   ifdef DEBUG
    {
        /* This block queries MPI for the rank itself; the local
         * my_rank shadows any outer one for the prints below.
         */
        int my_rank;
        double i_na_i, i_k_i, i_l_i, i_si_i, i_se_i;

        MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
        i_na_i = G_NA * h_gate * m_gate * m_gate * m_gate * (V_NA - volt);
        i_k_i  = G_K  * n_gate * n_gate * n_gate * n_gate * (V_K  - volt);
        i_l_i  = G_L  * (V_L  - volt);
        i_si_i = g_ii * from_I * (V_II - volt);
        i_se_i = g_ei * from_E * (V_EI - volt);
        printf("Process %d > i = %d, s_tot_E = %f, s_tot_I = %f\n",
            my_rank, i, from_E, from_I);
        printf("Process %d > i = %d, i_na_i = %f, i_k_i = %f, i_l_i = %f\n",
            my_rank, i, i_na_i, i_k_i, i_l_i);
        printf("Process %d > i = %d, appl_volt = %f, i_si_i = %f, i_se_i = %f\n",
            my_rank, i, applied_voltages[neuron], i_si_i, i_se_i);
        fflush(stdout);
    }
#   endif

    answer[i] = G_NA * h_gate * m_gate * m_gate * m_gate * (V_NA - volt)
              + G_K  * n_gate * n_gate * n_gate * n_gate * (V_K  - volt)
              + G_L  * (V_L  - volt)
              + applied_voltages[neuron]
              + g_ii * from_I * (V_II - volt)
              + g_ei * from_E * (V_EI - volt);

#   ifdef DEBUG
    printf("Process %d > i = %d, v_prime_I = %f\n", 
           my_rank, i, answer[i]);
    fflush(stdout);
#   endif

}  /* Compute_V_prime_I */


/*===================================================================*/
void Compute_h_prime_I(
         double   y[]           /* IN  */,
         double   answer[]      /* OUT */,
         int      i             /* IN  */) {

    /* h' = ah(V)*(1-h) - bh(V)*h
     * y[i-1] = V(t), 
     * y[i]   = h(t), 
     * y[i+1] = n(t), 
     * y[i+2] = m(t), 
     */

#   define ah(v) (0.128*exp(-(v+50.0)/18.0))
#   define bh(v) (4.0/(1.0+exp(-(v+27.0)/5.0)))

    answer[i] = ah(y[i-1])*(1.0 - y[i]) - bh(y[i-1])*y[i];

}  /* Compute_h_prime_I */


/*===================================================================*/
void Compute_n_prime_I(
         double   y[]           /* IN  */,
         double   answer[]      /* OUT */,
         int      i             /* IN  */) {

    /* n' = an(V)*(1 - n) - bn(V)*n
     * y[i-2] = V(t), 
     * y[i-1] = h(t), 
     * y[i]   = n(t), 
     * y[i+1] = m(t), 
     */

#   define an(v) (0.032*(v+52.0)/(1.0-exp(-(v+52.0)/5.0)))
#   define bn(v) (0.5*exp(-(v+57.0)/40.0))

    answer[i] = an(y[i-2])*(1 - y[i]) - bn(y[i-2])*y[i];

}  /* Compute_n_prime_I */


/*===================================================================*/
void Compute_m_prime_I(
         double   y[]           /* IN  */,
         double   answer[]      /* OUT */,
         int      i             /* IN  */) {

    /* m' = am(V)*(1 - m) - bm(V)*m
     * y[i-3] = V(t), 
     * y[i-2] = h(t), 
     * y[i-1] = n(t), 
     * y[i]   = m(t), 
     */

#   define am(v) (0.32*(v+54.0)/(1.0-exp(-(v+54.0)/4.0)))
#   define bm(v) (0.28*(v+27.0)/(exp((v+27.0)/5.0)-1.0))

    answer[i] = am(y[i-3])*(1 - y[i]) - bm(y[i-3])*y[i];

}  /* Compute_m_prime_I */


/*===================================================================*/
void Compute_s_prime_I(
         double   y[]                 /* IN  */,
         double   syn_input[]         /* IN  */,
         double   syn_answer[]        /* OUT */,
         int      i                   /* IN  */,
         int      neuron              /* IN  */,
         int      global_neuron_rank  /* IN  */) {

    /* s' = kI(V)*(1 - s) - s/10;
     * y[i-4] = V(t), 
     * y[i-3] = h(t), 
     * y[i-2] = n(t), 
     * y[i-1] = m(t), 
     */

#   define kI(v)  (2.0*(1.0+tanh(v/4.0)))
    syn_answer[neuron] = 
        kI(y[i-4])*(1 - syn_input[global_neuron_rank]) - 
                         syn_input[global_neuron_rank]/10;

}  /* Compute_s_prime_I */
