/* 
 * equations.c
 *
 * Functions defining the equations modelling a neuron.
 */

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include "mpi.h"
#include "equations.h"
#include "utility.h"
#include "nscanf.h"
#include "read_constants.h"
#include "globals.h"


/* Constants appearing in the model equations that you want to read in.
 * Add one definition here for every constant your equations need.
 */
double I_APP;   /* leading term of V' (see Compute_V_prime)              */
double G_NA;    /* coefficient of the i_na term of V'                    */
double G_K;     /* coefficient of the i_k term of V'                     */
double G_L;     /* coefficient of the i_l term of V'                     */
double V_NA;    /* subtracted from V in the i_na term                    */
double V_K;     /* subtracted from V in the i_k term                     */
double V_L;     /* subtracted from V in the i_l term                     */
double V_S;     /* subtracted from V in the synaptic (i_s) term          */
double G_S;     /* coefficient of the synaptic (i_s) term of V'          */
double TAU_S;   /* divides s(t) in the last term of s' (Compute_s_prime) */

/* Number of different neuron models in use. */
int MODEL_COUNT = 1;

/* Number of equations used to model each type of neuron, one entry per
 * model.  Synaptic current equations are NOT counted here.
 */
int EQUATION_COUNTS[] = { 3 };

/* For each model, the equation number (0,1,2,...) that holds the
 * membrane voltage.  The current version assumes this is 0.
 */
int VOLTAGE_EQN_OFFSETS[] = { 0 };

/********************************************************************/
/******** Declarations, definitions, etc beyond this point **********/
/******** may not be needed for your models                **********/
/********************************************************************/

/* Square root of the global step_size; assigned by Setup_f.  In the code
 * visible here it is otherwise only mentioned in the bracketed noise
 * term of the comment inside Compute_V_prime. */
double  sqrt_step_size;

/*===================================================================*/
/* Macros for simplifying array element and struct member access.
 * You shouldn't need to modify these.
 * They expand to plain expressions over neuron_types, syn_connections
 * and models, none of which is declared in this file
 * (NOTE(review): presumably declared in globals.h -- confirm).
 */

/* Number of equations for the neuron with global rank
 * global_neuron_rank, looked up through its type in neuron_types. */
#define EQN_COUNT(global_neuron_rank) \
            (EQUATION_COUNTS[neuron_types[global_neuron_rank]])

/* Dereferenced models[] entry for the type of the neuron with global
 * rank global_neuron_rank.  Not used in the code visible here.
 * NOTE(review): models is presumably a table of pointers to model
 * functions -- confirm against its declaration. */
#define MODEL_FCN(global_neuron_rank) \
            (*(models[neuron_types[global_neuron_rank]]))

/* Get the type of the jth neuron sending synaptic input to the neuron
 * with local rank local_neuron_rank */
#define INPUT_NEURON_TYPE(local_neuron_rank, j) \
            (neuron_types[syn_connections[local_neuron_rank].links[j]])

/* Get the rank of the jth neuron sending synaptic input to the neuron
 * with local rank local_neuron_rank.  The macros in this group use this
 * same links[j] value to index neuron_types and syn_input. */
#define INPUT_NEURON_RANK(local_neuron_rank, j) \
            (syn_connections[local_neuron_rank].links[j])

/* Get the signs[] entry for the jth neuron sending synaptic input to
 * the neuron with local rank local_neuron_rank.  Compute_V_prime
 * multiplies each synaptic input by this value. */
#define INPUT_NEURON_SIGN(local_neuron_rank, j) \
            (syn_connections[local_neuron_rank].signs[j])

/* Get the syn_input entry of the jth neuron sending synaptic input to
 * the neuron with local rank local_neuron_rank.  The expansion names
 * syn_input literally, so an identifier called syn_input must be in
 * scope at the point of use.
 * Since the syn_input array name is only used in the V_prime functions
 * this macro can't be used in other files */
#define SYNAPTIC_INPUT(local_neuron_rank, j) \
            (syn_input[syn_connections[local_neuron_rank].links[j]])

/*===================================================================*/
/* Prepare the equation routines for use: stores the square root of the
 * global step_size in sqrt_step_size.
 */
void Setup_f(void)
{
    /* step_size is a global defined outside this file */
    sqrt_step_size = sqrt(step_size);
}  /* Setup_f */


/*===================================================================*/
/* Setup_f acquires no resources, so there is nothing to release here.
 * This function is a no-op.
 */
void Cleanup_f(void)
{
    /* Intentionally empty */
}  /* Cleanup_f */


/*===================================================================*/
/* Compute f(t, y) for every neuron with global rank in the range
 * my_first_neuron..my_last_neuron (inclusive).
 * This is a completely local function at this point.
 *
 *   y           state values; each neuron occupies EQN_COUNT consecutive
 *               entries, starting at index 0 for my_first_neuron
 *   syn_input   synaptic values, indexed by global neuron rank
 *   answer      receives the derivatives, laid out like y
 *   syn_answer  receives s', indexed by local neuron rank
 *
 * The offsets +1, +2 and +3 below are hard-wired to the three-equation
 * layout V, h, n; Compute_s_prime is handed the index one past n.
 */
void Compute_f(
         double      y[]                /* IN  */,
         double      syn_input[]        /* IN  */,
         double      answer[]           /* OUT */,
         double      syn_answer[]       /* OUT */) {

    int global_rank;
    int first_eqn = 0;  /* index in y/answer of this neuron's V equation */

#   ifdef DEBUG
    printf("Process %d > In Compute_f\n", my_rank);
    printf("Process %d > my_first_neuron = %d, my_last_neuron = %d\n",
            my_rank, my_first_neuron, my_last_neuron);
#   endif

    global_rank = my_first_neuron;
    while (global_rank <= my_last_neuron) {
        const int local_rank = global_rank - my_first_neuron;

        Compute_V_prime(y, answer, syn_input, first_eqn, local_rank);
        Compute_h_prime(y, answer, first_eqn + 1);
        Compute_n_prime(y, answer, first_eqn + 2);
        Compute_s_prime(y, syn_input, syn_answer, first_eqn + 3,
                 local_rank, global_rank);

#       ifdef DEBUG
        printf("Process %d > i = %d, Finished call to Model_%d\n",
            my_rank, first_eqn, neuron_types[global_rank]);
        printf("Process %d > global_neuron_rank = %d, local_neuron_rank = %d\n",
            my_rank, global_rank, local_rank);
        fflush(stdout);
#       endif

        /* Advance to the first equation of the next neuron */
        first_eqn += EQN_COUNT(global_rank);
        global_rank++;
    }  /* while */
}  /* Compute_f */


/*===================================================================*/
/* Compute V' for one neuron and store it in answer[i].
 *
 *   V' = I_APP - i_na - i_k - i_l - i_syn
 *       [- SIGMA*sqrt_step_size*rand()?]    (bracketed term not computed)
 *
 *   y          state values: y[i] = V(t), y[i+1] = h(t), y[i+2] = n(t)
 *   answer     only answer[i] is written
 *   syn_input  synaptic values, read through SYNAPTIC_INPUT
 *   i          index of this neuron's V equation in y and answer
 *   neuron     the *local* neuron rank (index into syn_connections)
 *
 * The synaptic term uses the mean of the signed inputs of all neurons
 * connected to this one; with no connections it is 0.
 */
void Compute_V_prime(
         double       y[]                /* IN  */,
         double       answer[]           /* OUT */,
         double       syn_input[]        /* IN  */,
         int          i                  /* IN  */,
         int          neuron             /* IN  */) {

    const int    input_count = (syn_connections[neuron]).dimension;
    const double v = y[i];
    double       syn_mean = 0.0;
    int          k;

#   define M_inf(x) ( 1.0/(1.0 + exp(-0.08*((x) + 26.0))) )

#   ifdef DEBUG
    printf("Process %d > In Compute_V_prime, i = %d, neuron = %d\n",
        my_rank, i, neuron);
    printf("Process %d > dim = %d\n", my_rank, input_count);
    fflush(stdout);
#   endif

    /* Sum the signed synaptic inputs, then average them */
    for (k = 0; k < input_count; k++) {
        syn_mean += SYNAPTIC_INPUT(neuron,k)
                    *INPUT_NEURON_SIGN(neuron,k);
    }
    if (input_count > 0) {
        syn_mean /= input_count;
    }

#   ifdef DEBUG
    {
        /* Individual currents, printed for inspection only */
        double i_na, i_k, i_l, i_s;

        i_na = G_NA * pow(M_inf(v),3.0) * y[i+1] * (v - V_NA);
        i_k  = G_K  * pow(y[i+2],4.0)            * (v - V_K );
        i_l  = G_L                               * (v - V_L );
        i_s  = G_S  * syn_mean                   * (v - V_S );
        printf("i = %d, i_na = %f, i_k = %f, i_l = %f, i_s = %f\n",
            i, i_na, i_k, i_l, i_s);
    }
#   endif

    /* Kept as one expression so the evaluation order is unchanged */
    answer[i] = I_APP
          - G_NA * pow(M_inf(v),3.0) * y[i+1] * (v - V_NA)
          - G_K  * pow(y[i+2],4.0)            * (v - V_K )
          - G_L                               * (v - V_L )
          - G_S  * syn_mean                   * (v - V_S );

#   ifdef DEBUG
    printf("Process %d > i = %d, v_prime = %f\n", 
           my_rank, i, answer[i]);
    fflush(stdout);
#   endif

}  /* Compute_V_prime */


/*===================================================================*/
void Compute_h_prime(
         double   y[]           /* IN  */,
         double   answer[]      /* OUT */,
         int      i             /* IN  */) {

    /* h' = (H_inf(V) - h(t))/H_tau(V)
     * y[i-1] = V(t),
     * y[i]   = h(t),
     * y[i+1] = n(t),
     */

#   define H_inf(x) ( 1.0/(1.0 + exp(0.13*((x) + 38.0))) )

#   define H_tau(x) ( 0.6/(1.0 + exp(-0.12*((x) + 67.0))) )

    answer[i] = ( H_inf(y[i-1]) - y[i] ) / H_tau(y[i-1]);

}  /* Compute_h_prime */


/*===================================================================*/
void Compute_n_prime(
         double   y[]           /* IN  */,
         double   answer[]      /* OUT */,
         int      i             /* IN  */) {

    /* n' = ( N_inf(V) - n(t) )/N_tau(V)
     * y[i-2] = V(t),
     * y[i-1] = h(t),
     * y[i]   = n(t)
     */

#   define N_inf(x) ( 1.0/(1 + exp(-0.045*((x) + 10.0))) )

#   define N_tau(x) ( 0.5 + 2.0/(1 + exp(0.045*((x) - 50.0))) )

    answer[i] = ( N_inf(y[i-2]) - y[i] ) / N_tau(y[i-2]);
}  /* Compute_n_prime */


/*===================================================================*/
/* Compute s' for one neuron and store it in syn_answer[neuron].
 *
 *   s' = F(V) * (1 - s(t)) - s(t)/TAU_S
 *
 *   y                   state values: y[i-3] = V(t), y[i-2] = h(t),
 *                       y[i-1] = n(t)
 *   syn_input           s(t) is read from syn_input[global_neuron_rank]
 *   syn_answer          only syn_answer[neuron] is written
 *   i                   index one past the n equation; y[i-3] is read
 *   neuron              local neuron rank (index into syn_answer)
 *   global_neuron_rank  global neuron rank (index into syn_input)
 */
void Compute_s_prime(
         double   y[]                 /* IN  */,
         double   syn_input[]         /* IN  */,
         double   syn_answer[]        /* OUT */,
         int      i                   /* IN  */,
         int      neuron              /* IN  */,
         int      global_neuron_rank  /* IN  */) {

#   define F(x) ( 1.0/(1.0 + exp(-(x))) )

    const double v = y[i-3];
    const double s = syn_input[global_neuron_rank];

    /* Kept as one expression so the evaluation order is unchanged */
    syn_answer[neuron] = F(v)*(1.0 - s) - s/TAU_S;

}  /* Compute_s_prime */
