/*
 * input.c 
 *
 * Read in all data needed by solver *except* model equation
 * parameters
 */

#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include "mpi.h"
#include "input.h"
#include "utility.h"
#include "nscanf.h"
#include "globals.h"

/* Scratch buffer used by Get_syn_info to build the array names
 * ("syn_connections[i].links" / ".signs") passed to Check_malloc. */
static char array_name[MAX_NAME];

/* #define DEBUG */

/*===================================================================*/
/*
 * Read all solver input except the model equation parameters.
 *
 * Process 0 asks whether input comes from stdin or a data file and reads
 * the data; the helpers broadcast/scatter it to the other processes.
 *
 * y_init   : OUT, allocated here with total_equations doubles (the count
 *            after the synaptic variables were subtracted); caller frees.
 * syn_init : OUT, allocated here with total_neurons doubles; caller frees.
 *
 * Also allocates the global syn_connections array (one entry per neuron).
 */
void Read_data(
         double**      y_init              /* OUT */,
         double**      syn_init            /* OUT */) {

    /* Only process 0 opens the input.  The other processes never read
       these, but they are still passed to the helpers below, so give them
       well-defined values instead of leaving them indeterminate. */
    FILE*       fp = NULL;
    bool        print_prompt = false;

    if (my_rank == 0) Open_input_file(&fp, &print_prompt);

    Get_scalar_input(print_prompt, fp);

    Get_neuron_types(fp,  print_prompt);

    /* Find parameters that are specific to each process */
    Compute_my_params();

#   ifdef DEBUG
    printf("Process %d > My first variable is %d\n",
            my_rank, my_first_var);
    printf("Process %d > My last variable is %d\n",
            my_rank, my_last_var);
    printf("Process %d > Num_my_equations is %d\n",
            my_rank, num_my_equations);
    printf("Process %d > My first neuron is %d\n",
            my_rank, my_first_neuron);
    printf("Process %d > My last neuron is %d\n",
            my_rank, my_last_neuron);
    printf("Process %d > Num_my_neurons is %d\n",
            my_rank, num_my_neurons);
#   endif


    /* Allocate a global vector for y_init                     */
    /* Note that since synaptic currents are stored in a       */
    /* separate array, we allocate storage for number of vars  */
    /* - total_neurons                                         */
    *y_init = (double *) malloc(sizeof(double)*total_equations);
    Check_malloc(*y_init, "y_init");
    *syn_init = (double *) malloc(sizeof(double)*total_neurons);
    Check_malloc(*syn_init, "syn_init");
    Get_y_init(print_prompt, fp, *y_init, *syn_init);


    /* Allocate storage for synaptic info */
    syn_connections = (syn_conn_t*) malloc(sizeof(syn_conn_t)*
                       total_neurons);
    Check_malloc(syn_connections, "syn_connections");
    Get_syn_info(fp);

    /* stdin is left open; only a data file opened by process 0 is closed */
    if ((my_rank == 0) && (!print_prompt)) fclose(fp);

    return;
}  /* Read_data */


/*==================================================================*/
/*
 * Only called by process 0.
 *
 * Ask the user whether input comes from standard input ('s') or a data
 * file (any other answer).  For stdin, *fp = stdin and *print_prompt is
 * true, so later readers print prompts.  Otherwise the user-supplied file
 * name is opened for reading and *print_prompt is false.  If the file
 * cannot be opened, all processes are aborted.
 */
void Open_input_file(
         FILE**  fp            /* OUT */,
         bool*   print_prompt  /* OUT */) {

    char filename[MAX_NAME];
    char which_one;

    filename[0] = '\0';

    printf("Dear user:  Data file or standard in? ['d' or 's']\n");
    Skip_white_space();
    nscanf(stdin, "%c", &which_one);
    if (which_one == 's') {
        *print_prompt = true;
        *fp = stdin;
    } else {
        *print_prompt = false;
        printf("What's the file name?\n");
        /* NOTE(review): "%s" is unbounded; a name longer than
           MAX_NAME - 1 characters overflows filename.  Add a field width
           once nscanf's handling of widths is confirmed. */
        nscanf(stdin, "%s", filename);
        *fp = fopen(filename,"r");
    }

    /* Check the FILE* that fopen produced.  fp itself is the caller's
       out-pointer and is never NULL, so testing it missed every failure. */
    if (*fp == NULL) {
        fprintf(stderr,
            "Process 0 > Attempt to open %s failed.\n", filename);
        fflush(stderr);
        MPI_Abort(MPI_COMM_WORLD, -1);
    }

}  /* Open_input_file */


/*==================================================================*/
/*
 * Read the run's scalar parameters on process 0 and broadcast them to
 * every process: t_0, h, number of timesteps, m (number of equations),
 * number of neurons, number of rows and columns, print frequency, and
 * the output filename.
 *
 * On return, total_equations holds m minus the neuron count, because the
 * per-neuron synaptic variables are stored in a separate array.  Process 0
 * aborts the run if the counts are inconsistent or if the neurons cannot
 * be split evenly among the processes.
 */
void Get_scalar_input(
         bool     print_prompt     /* IN  */,
         FILE*    fp               /* IN  */) {

    if (my_rank == 0) {
        if (print_prompt) {
            printf("Please enter t_0, h, N (number of timesteps), ");
            printf("m (number of equations),\n");
            printf("the number of neurons, the number of rows,\n ");
            printf("the number of columns, the print frequency,\n");
            printf("and the output filename\n");
        }

        nscanf(fp, "%lf", &time_init);
        nscanf(fp, "%lf", &step_size);
        nscanf(fp, "%d", &num_steps);
        nscanf(fp, "%d", &total_equations);
        nscanf(fp, "%d", &total_neurons);
        nscanf(fp, "%d", &row_count);
        nscanf(fp, "%d", &col_count);
        nscanf(fp, "%d", &print_freq);
        nscanf(fp, "%s", output_filename);

#ifdef DEBUG
        printf("Process 0 > t_0 = %f\n", time_init);
        printf("Process 0 > h = %f\n", step_size);
        printf("Process 0 > num_steps = %d\n", num_steps);
        printf("Process 0 > total_equations = %d\n", total_equations);
        printf("Process 0 > total_neurons = %d\n", total_neurons);
        printf("Process 0 > row_count = %d\n", row_count);
        printf("Process 0 > col_count = %d\n", col_count);
        printf("Process 0 > print_freq = %d\n", print_freq);
        printf("Process 0 > filename = %s\n", output_filename);
        fflush(stdout);
#endif
        /* The synaptic variables (one per neuron) live outside y */
        total_equations = total_equations - total_neurons;

        /* These counts size later mallocs; a negative int there becomes
           a huge size_t, and zero neurons leaves nothing to simulate. */
        if ( (total_neurons <= 0) || (total_equations < 0) ) {
            fprintf(stderr, "Process 0 > Need a positive neuron count ");
            fprintf(stderr, "and m at least as large as the neuron count.\n");
            fflush(stderr);
            MPI_Abort(MPI_COMM_WORLD, -1);
        }

        if ( (total_neurons % num_processes) != 0 ) {
            fprintf(stderr, "Process 0 > Neuron count must ");
            fprintf(stderr, "be evenly divisible by process count.\n");
            fflush(stderr);
            MPI_Abort(MPI_COMM_WORLD, -1);
        }

    }   /* if (my_rank == 0) */

    MPI_Bcast(&time_init,       1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
    MPI_Bcast(&step_size,       1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
    MPI_Bcast(&num_steps,       1, MPI_INT,    0, MPI_COMM_WORLD);
    MPI_Bcast(&total_equations, 1, MPI_INT,    0, MPI_COMM_WORLD);
    MPI_Bcast(&total_neurons,   1, MPI_INT,    0, MPI_COMM_WORLD);
    MPI_Bcast(&row_count,       1, MPI_INT,    0, MPI_COMM_WORLD);
    MPI_Bcast(&col_count,       1, MPI_INT,    0, MPI_COMM_WORLD);
    MPI_Bcast(&print_freq,      1, MPI_INT,    0, MPI_COMM_WORLD);
    MPI_Bcast(output_filename, MAX_NAME, MPI_CHAR,   0, MPI_COMM_WORLD);

#   ifdef DEBUG
    printf("Process %d > t_0 = %f\n", my_rank, time_init);
    printf("Process %d > h = %f\n", my_rank, step_size);
    printf("Process %d > num_steps = %d\n", my_rank, num_steps);
    printf("Process %d > total_equations = %d\n", 
         my_rank, total_equations);
    printf("Process %d > total_neurons = %d\n", 
         my_rank, total_neurons);
    printf("Process %d > row_count = %d\n", my_rank, row_count);
    printf("Process %d > col_count = %d\n", my_rank, col_count);
    printf("Process %d > print_freq = %d\n", my_rank, print_freq);
    printf("Process %d > filename   = %s\n", my_rank, output_filename);
    fflush(stdout);
#   endif

}  /* Get_scalar_input */


/*==================================================================*/
void Get_neuron_types(
         FILE*  fp             /* IN  */,
         bool   print_prompt   /* IN  */) {

    /*
     * Allocate the global neuron_types array (total_neurons ints) on every
     * process.  Process 0 fills it from fp, one integer type per neuron,
     * and then broadcasts it to all the other processes.
     */
    int  idx;
    int* slot;

    neuron_types = (int*) malloc(sizeof(int)*total_neurons);
    Check_malloc(neuron_types, "neuron_types");

    if (my_rank == 0) {
        if (print_prompt)
            printf("Please enter the types of the neurons\n");

        /* Walk a cursor through the array while counting positions */
        for (idx = 0, slot = neuron_types; idx < total_neurons;
             idx++, slot++) {
            nscanf(fp, "%d", slot);
#           ifdef DEBUG
            printf("neuron_types[%d] = %d ", idx, *slot);
#           endif
        }
#       ifdef DEBUG
        printf("\n");
        fflush(stdout);
#       endif
    }  /* process 0 only */

    /* Every process, root included, must take part in the broadcast */
    MPI_Bcast(neuron_types, total_neurons, MPI_INT, 0, MPI_COMM_WORLD);

#   ifdef DEBUG
    Print_local_int_vector("neuron_types", neuron_types, 
        total_neurons, my_rank);
    fflush(stdout);
#   endif

}  /* Get_neuron_types */


/*==================================================================*/
void Get_y_init (
         bool     print_prompt     /* IN  */,
         FILE*    fp               /* IN  */,
         double   y_init[]         /* OUT */,
         double   syn_init[]       /* OUT */) {

    int i;
    int neuron;
    int eqn_count;  /* keep track of position in y_init */

    if (my_rank == 0) {
        if (print_prompt)
            printf("Please enter y_0\n");

        eqn_count = 0;
        for (neuron = 0; neuron < total_neurons; neuron++) {
            for (i = 0; i < EQUATION_COUNTS[neuron_types[neuron]]; 
                 i++) {
                nscanf(fp, "%lf", &(y_init[eqn_count]));
#               ifdef DEBUG
                printf("y_0[%d] = %f ", eqn_count, y_init[eqn_count]);
#               endif
                eqn_count++;
            }  /* for i */
            nscanf(fp, "%lf", &(syn_init[neuron]));
#           ifdef DEBUG
            printf("syn_0[%d] = %f ", neuron, syn_init[neuron]);
#           endif
        }  /* for neuron */
#       ifdef DEBUG
        printf("\n");
        fflush(stdout);
#       endif
    }  /* my_rank == 0 */

    MPI_Bcast(y_init, total_equations, MPI_DOUBLE, 0, MPI_COMM_WORLD);
    MPI_Bcast(syn_init, total_neurons, MPI_DOUBLE, 0, MPI_COMM_WORLD);

#   ifdef DEBUG
    Print_local_vector("y_0", y_init, total_equations, my_rank);
    Print_local_vector("syn_0", syn_init, total_neurons, my_rank);
    fflush(stdout);
#   endif

}  /* Get_y_init */


/*==================================================================*/
/*
 * Fill in this process's entries of the global syn_connections array.
 *
 * Reads or receives every local neuron's connection count (dimension),
 * allocates its links and signs arrays when the count is positive, and
 * then reads or receives the link/sign pairs themselves.  Neurons with no
 * connections get NULL links/signs pointers.
 */
void Get_syn_info(
         FILE*       fp                 /* IN  */) {

    /* temp_dimension (all neurons' counts) and max_dim are only
       produced on process 0; elsewhere they stay NULL / 0 so that no
       indeterminate values are handed to the helpers. */
    int*       temp_dimension = NULL;
    int*       my_dimensions;
    int        i;
    int        max_dim = 0;


    my_dimensions = (int*) malloc(num_my_neurons*sizeof(int));
    Check_malloc(my_dimensions, "my_dimensions");
    if (my_rank == 0) {
        temp_dimension = (int*) malloc(sizeof(int)*total_neurons);
        Check_malloc(temp_dimension, "temp_dimension"); 
    }
    
    Get_dimensions(fp, temp_dimension, my_dimensions, &max_dim);

    /* Store syn_connections.dimension and allocate */
    /* space for the links and types                */
    for (i = 0; i < num_my_neurons; i++) {
        syn_connections[i].dimension = my_dimensions[i];

        if ( my_dimensions[i] > 0 ) {
            syn_connections[i].links = (int*) malloc(sizeof(int)*
                 my_dimensions[i]);
            /* snprintf: the name grows with i and must not overrun */
            snprintf(array_name, sizeof(array_name),
                "syn_connections[%d].links", i);
            Check_malloc(syn_connections[i].links, array_name);
            syn_connections[i].signs = (int*)malloc(sizeof(int)*
                 my_dimensions[i]);
            snprintf(array_name, sizeof(array_name),
                "syn_connections[%d].signs", i);
            Check_malloc(syn_connections[i].signs, array_name);
        } else {
            /* Leave no dangling garbage pointers for unconnected neurons */
            syn_connections[i].links = NULL;
            syn_connections[i].signs = NULL;
        }
    }  /* for i */

    Get_links_and_types(max_dim, temp_dimension, my_dimensions, fp);

#   ifdef DEBUG 
{
    int j;
    for(i = 0; i < num_my_neurons; i++) {
        for(j = 0; j < syn_connections[i].dimension; j++) {
            printf("Process %d > syn_connections[%d].links[%d] = %d\n",
                my_rank, i, j, syn_connections[i].links[j]);
            printf("Process %d > syn_connections[%d].signs[%d] = %d\n",
                my_rank, i, j, syn_connections[i].signs[j]);
        }
    }
}
#   endif

    free(my_dimensions);
    free(temp_dimension);   /* NULL on processes other than 0 */
}  /* Get_syn_info */


/*==================================================================*/
void Get_dimensions(
         FILE*  fp                  /* IN  */,
         int    temp_dimension[]    /* OUT */,
         int    my_dimensions[]     /* OUT */,
         int*   max_dim             /* OUT */) {

    /*
     * Process 0 reads one connection count per neuron into temp_dimension
     * and records the largest of them (never less than 0) in *max_dim.
     * The counts are then scattered so that every process receives its
     * num_my_neurons entries in my_dimensions.  On the other processes,
     * temp_dimension and *max_dim are left untouched.
     */
    int neuron;
    int largest;

    if (my_rank == 0) {
        largest = 0;
        for (neuron = 0; neuron < total_neurons; neuron++) {
           nscanf(fp, "%d", temp_dimension + neuron);
           if (largest < temp_dimension[neuron])
               largest = temp_dimension[neuron];
#          ifdef DEBUG
           printf("Process %d > temp_dimension[%d] = %d \n",
               my_rank, neuron, temp_dimension[neuron]); 
#          endif
        }
        *max_dim = largest;
#       ifdef DEBUG
        fflush(stdout);
#       endif
    }  /* process 0 reads */

    MPI_Scatter(temp_dimension, num_my_neurons, MPI_INT,
        my_dimensions, num_my_neurons, MPI_INT, 0, comm);

#   ifdef DEBUG
    Print_local_int_vector("my_dimensions", my_dimensions, 
        num_my_neurons, my_rank);
#   endif
}  /* Get_dimensions */


/*==================================================================*/
/*
 * Distribute each neuron's synaptic links and signs.
 *
 * Process 0 reads the input in neuron order: temp_dimension[neuron]
 * "link sign" pairs per neuron.  Pairs for its own neurons are stored
 * directly in syn_connections.  Pairs for another process's neurons are
 * staged in temporary buffers and sent to the owning process, links
 * first and then signs.  The other processes receive into syn_connections
 * arrays already sized from my_dimensions.
 *
 * Assumes every process owns num_my_neurons consecutive neurons, with
 * process p owning neurons p*num_my_neurons onward.
 * max_dim and temp_dimension are used only on process 0;
 * my_dimensions is used only on the other processes.
 */
void Get_links_and_types(
         int         max_dim            /* IN  */,
         int         temp_dimension[]   /* IN  */,
         int         my_dimensions[]    /* IN  */,
         FILE*       fp                 /* IN  */) {

    int  neuron;
    int  i, j;
    int  process;
    int  first_neuron;
    int  last_neuron;
    int  buf_len;

    int* temp_links;
    int* temp_types;

    MPI_Status status;

    if (my_rank == 0) {
        /* First get those belonging to process 0 */
        for (neuron = 0; neuron < num_my_neurons; neuron++) {
            for (j = 0; j < temp_dimension[neuron]; j++) {
                nscanf(fp, "%d %d",
                    &(syn_connections[neuron].links[j]),
                    &(syn_connections[neuron].signs[j]) );
#               ifdef DEBUG
                printf("Process %d > For neuron %d: ", my_rank, neuron);
                printf("links[%d] = %d, types[%d] = %d\n",
                        j, syn_connections[neuron].links[j],
                        j, syn_connections[neuron].signs[j]);
#               endif
            }  /* for j */
        } /* for neuron */


        /* Now get links and types for other processes.  If no neuron has
           any connections, max_dim is 0, and malloc(0) may legally return
           NULL, which Check_malloc would presumably report as a failure.
           So always request at least one slot. */
        buf_len = (max_dim > 0) ? max_dim : 1;
        temp_links = (int*) malloc(sizeof(int)*buf_len);
        Check_malloc(temp_links, "temp_links");
        temp_types = (int*) malloc(sizeof(int)*buf_len);
        Check_malloc(temp_types, "temp_types");

        for (process = 1; process < num_processes; process++) {
            first_neuron = process*num_my_neurons;
            /* Assumes each process has the same number of neurons! */
            last_neuron = first_neuron + num_my_neurons;

            for (neuron = first_neuron; neuron < last_neuron; neuron++){
                if (temp_dimension[neuron] > 0) {
                    /* Read in the links and types for the neuron */
                    for (j = 0; j < temp_dimension[neuron]; j++) {
                        nscanf(fp, "%d %d", &temp_links[j],
                            &temp_types[j]);
#                       ifdef DEBUG
                        printf("Process %d > For neuron %d: ",
                            my_rank, neuron);
                        printf("temp_links[%d] = %d ", j, 
                            temp_links[j]);
                        printf("temp_types[%d] = %d\n", j, 
                            temp_types[j]);
#                       endif
                    }
#                   ifdef DEBUG
                    printf("\n");
                    fflush(stdout);
#                   endif

                    /* Send links and types.  Both use tag 0; MPI's
                       non-overtaking rule keeps them in this order. */
                    MPI_Send(temp_links, temp_dimension[neuron],
                        MPI_INT, process, 0, comm);
                    MPI_Send(temp_types, temp_dimension[neuron],
                        MPI_INT, process, 0, comm);
                }  /* temp_dimension > 0 */
            }  /* for neuron */
        }  /* for process */

        free(temp_links);
        free(temp_types);

    } else {   /* my_rank != 0 */

        for (i = 0; i < num_my_neurons; i++) {
            if ( my_dimensions[i] > 0 ) {
                MPI_Recv(syn_connections[i].links, my_dimensions[i],
                    MPI_INT, 0, 0, comm, &status);
                MPI_Recv(syn_connections[i].signs, my_dimensions[i],
                    MPI_INT, 0, 0, comm, &status);
            }
        }
    }  /* my_rank != 0 */

}  /* Get_links_and_types */


/*================================================================*/
void Compute_my_params(void) {

    /*
     * Derive this process's share of the problem from the broadcast
     * globals: a block of num_my_neurons consecutive neurons
     * (my_first_neuron .. my_last_neuron).  Also derive the matching range
     * of state variables: my_first_var is the number of variables owned by
     * all earlier neurons, and num_my_equations is the number owned by this
     * block.  Variable counts come from EQUATION_COUNTS[neuron_types[k]].
     */
    int neuron;

    /* Process 0 verified on input that total_neurons is divisible by
       num_processes, so every process gets the same share. */
    num_my_neurons  = total_neurons/num_processes;
    my_first_neuron = num_my_neurons*my_rank;
    my_last_neuron  = my_first_neuron + num_my_neurons - 1;

    /* One pass over neurons 0..my_last_neuron: those before my block
       count toward my offset, the rest toward my own equation count. */
    my_first_var     = 0;
    num_my_equations = 0;
    for (neuron = 0; neuron <= my_last_neuron; neuron++) {
        if (neuron < my_first_neuron)
            my_first_var += EQUATION_COUNTS[neuron_types[neuron]];
        else
            num_my_equations += EQUATION_COUNTS[neuron_types[neuron]];
    }

    my_last_var = my_first_var + num_my_equations - 1;

}  /* Compute_my_params */


/*================================================================*/
/*
 * Consume whitespace from stdin so that the next read (e.g. a "%c"
 * conversion) sees the first non-whitespace character.  That character
 * is pushed back onto stdin; on end-of-file nothing is pushed back.
 *
 * The character is held in an int so that EOF stays distinguishable from
 * the byte 0xFF.  isspace also skips '\r', '\f' and '\v', so a CRLF line
 * ending is not handed to the next "%c" read as the answer.
 */
void Skip_white_space(void) {
    int c;

    do {
        c = getchar();
    } while (c != EOF && isspace(c));

    if (c != EOF)
        ungetc(c, stdin);
}  /* Skip_white_space */
