import numpy as np
import nixio as nix
import os
from numpy.core.fromnumeric import repeat
from traitlets.traitlets import Instance
from chirp_ams import get_signals
from model import simulate, load_models
from IPython import embed
import matplotlib.pyplot as plt
import multiprocessing
from joblib import Parallel, delayed

# Folder into which the per-cell nix files ("cell_<cell>.nix") are written.
data_folder = "data"


def append_settings(section, sec_name, sec_type, settings):
    """Write a (possibly nested) settings mapping into a new metadata subsection.

    A subsection named ``sec_name`` of type ``sec_type`` is created below
    ``section``. Nested dicts become further subsections of type "settings".
    One-dimensional numpy arrays are stored as lists. Numpy arrays with any
    other number of dimensions are skipped. All other values are stored
    unchanged as properties.

    Parameters
    ----------
    section :
        Parent nix section (or file) providing ``create_section``.
    sec_name, sec_type : str
        Name and type of the subsection to create.
    settings : dict
        Key/value pairs to store.
    """
    subsection = section.create_section(sec_name, sec_type)
    for key, value in settings.items():
        if isinstance(value, dict):
            append_settings(subsection, key, "settings", value)
        elif not isinstance(value, np.ndarray):
            subsection[key] = value
        elif value.ndim == 1:
            subsection[key] = list(value)
        # numpy arrays that are not one-dimensional are intentionally not stored

    
def _append_sampled_trace(block, name, type_, data, label, unit, sampling_interval):
    """Create a float data array in ``block`` with a sampled time axis.

    Sets the array's label and unit, then appends a sampled dimension with the
    given sampling interval, labelled "time" in "s". Returns the new array.
    """
    da = block.create_data_array(name, type_, dtype=nix.DataType.Float, data=data)
    da.label = label
    da.unit = unit
    dim = da.append_sampled_dimension(sampling_interval)
    dim.label = "time"
    dim.unit = "s"
    return da


def save(filename, name, stimulus_settings, model_settings, self_signal, other_signal, self_freq, other_freq, complete_stimulus, responses, overwrite=False):
    """Store one simulated stimulus condition and its spike responses in a nix file.

    Creates a block ``name`` of type "nix.simulation" and a metadata section of
    the same name holding ``model_settings`` and ``stimulus_settings``. The block
    contains:

    - the complete stimulus;
    - the optional self/other frequency traces;
    - one spike-time data array per response;
    - a tag that references all responses and links the stimulus and frequency
      traces to them as untagged features.

    Parameters
    ----------
    filename : str
        Path of the nix file. It is opened read-write if it exists and
        ``overwrite`` is False; otherwise it is (re)created.
    name : str
        Name of the block and its metadata section. If a block with this name
        already exists, a message is printed and nothing is written.
    stimulus_settings, model_settings : dict
        Settings stored as metadata. ``model_settings["deltat"]`` is used as the
        sampling interval of all sampled data arrays.
    self_signal, other_signal :
        Not stored; kept for interface compatibility.
    self_freq, other_freq : array-like or None
        Frequency traces (Hz), skipped when None.
    complete_stimulus : array-like
        The stimulus that was fed to the model.
    responses : sequence of array-like
        Spike times (s), one entry per trial.
    overwrite : bool
        Recreate the file even if it already exists.
    """
    if os.path.exists(filename) and not overwrite:
        nf = nix.File.open(filename, nix.FileMode.ReadWrite)
    else:
        nf = nix.File.open(filename, mode=nix.FileMode.Overwrite,
                           compression=nix.Compression.DeflateNormal)
    try:
        if name in nf.blocks:
            print("Data with this name is already stored! ", name)
            return
        # look this up before writing anything, so a missing key does not leave a half-written block
        deltat = model_settings["deltat"]
        mdata = nf.create_section(name, "nix.simulation")
        append_settings(mdata, "model parameter", "nix.model.settings", model_settings)
        append_settings(mdata, "stimulus parameter", "nix.stimulus.settings", stimulus_settings)
        b = nf.create_block(name, "nix.simulation")
        b.metadata = mdata

        # save stimulus and (optional) frequency traces; label and unit go to separate attributes
        stim_da = _append_sampled_trace(b, "complete stimulus", "nix.timeseries.sampled.stimulus",
                                        complete_stimulus, "voltage", "mV/cm", deltat)
        self_freq_da = None
        if self_freq is not None:
            self_freq_da = _append_sampled_trace(b, "self frequency", "nix.timeseries.sampled.frequency",
                                                 self_freq, "frequency", "Hz", deltat)
        other_freq_da = None
        if other_freq is not None:
            other_freq_da = _append_sampled_trace(b, "other frequency", "nix.timeseries.sampled.frequency",
                                                  other_freq, "frequency", "Hz", deltat)

        # save responses, one spike-time array per trial
        for i, spike_times in enumerate(responses):
            da = b.create_data_array("response_%i" % i, "nix.timeseries.events.spike_times",
                                     dtype=nix.DataType.Float, data=spike_times)
            da.label = "time"
            da.unit = "s"
            da.append_alias_range_dimension()

        # bind it all together: the tag spans the whole stimulus
        tag = b.create_tag("chirp stimulation", "nix.stimulus", [0.0])
        tag.extent = [stim_da.shape[0] * stim_da.dimensions[0].sampling_interval]
        for da in b.data_arrays:
            if "response" in da.name.lower():
                tag.references.append(da)
        for feature_da in (stim_da, self_freq_da, other_freq_da):
            if feature_da is not None:
                tag.create_feature(feature_da, nix.LinkType.Untagged)
    finally:
        # always release the file handle, also when writing fails
        nf.close()


def save_baseline_response(filename, block_name, spikes, model_settings, overwrite=False):
    """Store a baseline spike train and the model settings in a nix file.

    Creates a block ``block_name`` of type "nix.simulation" with a metadata
    section holding ``model_settings``. The spike times go into a data array
    named "baseline_response" with an alias range dimension.

    Parameters
    ----------
    filename : str
        Path of the nix file. It is opened read-write if it exists and
        ``overwrite`` is False; otherwise it is (re)created.
    block_name : str
        Name of the block and its metadata section. If a block with this name
        already exists, a message is printed and nothing is written.
    spikes : array-like
        Spike times in seconds.
    model_settings : dict
        Model parameters stored as metadata.
    overwrite : bool
        Recreate the file even if it already exists.
    """
    if os.path.exists(filename) and not overwrite:
        nf = nix.File.open(filename, nix.FileMode.ReadWrite)
    else:
        nf = nix.File.open(filename, mode=nix.FileMode.Overwrite,
                           compression=nix.Compression.DeflateNormal)
    try:
        if block_name in nf.blocks:
            print("Data with this name is already stored! ", block_name)
            return
        mdata = nf.create_section(block_name, "nix.simulation")
        append_settings(mdata, "model parameter", "nix.model.settings", model_settings)
        b = nf.create_block(block_name, "nix.simulation")
        b.metadata = mdata
        da = b.create_data_array("baseline_response", "nix.timeseries.events.spike_times",
                                 dtype=nix.DataType.Float, data=spikes)
        da.label = "time"
        da.unit = "s"
        da.append_alias_range_dimension()
    finally:
        # always release the file handle, also when writing fails
        nf.close()


def get_baseline_response(model_params, duration=10):
    """Simulate the response of a model cell to its own, unperturbed EOD.

    The model is driven by a pre-stimulus (see ``get_pre_stimulus``) followed by
    ``duration`` seconds of a pure sine at the cell's EOD frequency. Only the
    spikes after the pre-stimulus are returned. They are shifted so that they
    are relative to the last pre-stimulus sample.

    Parameters
    ----------
    model_params : dict
        Model parameters. Must contain "cell", "EODf" and "deltat". All entries
        except "cell" and "EODf" are passed to ``simulate``.
    duration : float
        Duration of the baseline EOD in seconds.

    Returns
    -------
    array
        Spike times after the pre-stimulus, offset by the pre-stimulus end time.
    """
    eodf = model_params["EODf"]
    dt = model_params["deltat"]
    cell_params = model_params.copy()
    del cell_params["cell"]
    del cell_params["EODf"]
    # Sample the pre-stimulus with the model's own time step, as simulate_responses
    # does. Otherwise it is misaligned with the EOD appended below, and the cut-off
    # time pre_time[-1] does not match its length in the simulated stimulus.
    pre_time, pre_stim = get_pre_stimulus(eodf, dt=dt)
    baseline_time = np.arange(0.0, duration, dt)
    eod = np.sin(baseline_time * eodf * 2 * np.pi)
    stim = np.hstack((pre_stim, eod))
    # random v_zero drawn uniformly from [0, 1)
    cell_params["v_zero"] = np.random.rand(1)[0]
    spikes = simulate(stim, **cell_params)
    return spikes[spikes > pre_time[-1]] - pre_time[-1]


def get_pre_stimulus(eodf, duration=2, df=20, contrast=20, dt=1./20000):
    """Build the pre-stimulus that precedes every simulated stimulus.

    The signal is the fish's own EOD (a unit-amplitude sine at ``eodf``) plus a
    second sine at ``eodf + df``. The second sine has an amplitude of ``contrast``
    percent and is faded in and out with a Hann window over the full duration.

    Parameters
    ----------
    eodf : float
        EOD frequency in Hz.
    duration : float
        Length of the pre-stimulus in seconds.
    df : float
        Frequency offset of the second sine in Hz.
    contrast : float
        Amplitude of the second sine in percent of the own EOD.
    dt : float
        Sampling interval in seconds.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Time axis and pre-stimulus signal.
    """
    t = np.arange(0.0, duration, dt)
    own_eod = np.sin(t * eodf * 2 * np.pi)
    window = np.hanning(len(t))
    foreign_eod = np.sin(t * (eodf + df) * 2 * np.pi) * contrast / 100 * window
    return t, own_eod + foreign_eod


def simulate_responses(stimulus_params, model_params, repeats=10, deltaf=20):
    """Simulate and store the spike responses of one model cell to chirp stimuli.

    The loop runs over every contrast in ``stimulus_params["contrasts"]`` and
    over the conditions "other" and "self". For each combination:

    1. The stimulus from ``get_signals`` (self signal + other signal) gets a
       pre-stimulus prepended.
    2. It is simulated ``repeats`` times.
    3. The spike trains are written with ``save`` to
       ``<data_folder>/cell_<cell>.nix``.

    In the "self" condition the self signal alone is also simulated, with its
    trials interleaved with the full-signal trials, and stored under a
    "no-other" block name.

    Parameters
    ----------
    stimulus_params : dict
        Stimulus settings. Must contain "eodfs", "contrasts", "chirp_size" and
        "chirp_frequency". All entries except "contrasts" and "chirp_frequency"
        are passed to ``get_signals``, together with the current "contrast" and
        "condition".
    model_params : dict
        Model parameters including "cell", "EODf" and "deltat".
    repeats : int
        Number of trials per stimulus.
    deltaf : int
        Difference frequency in Hz; only used to build the block names.
    """
    cell_params = model_params.copy()
    cell = cell_params.pop("cell")
    cell_params.pop("EODf")
    filename = os.path.join(data_folder, "cell_%s.nix" % cell)
    chirp_size = stimulus_params["chirp_size"]
    pre_time, pre_stim = get_pre_stimulus(stimulus_params["eodfs"]["self"], dt=model_params["deltat"])
    cutoff = pre_time[-1]

    def run_trial(signal):
        # Draw a fresh random v_zero and store it in cell_params, so the value of
        # the last trial is what ends up in the saved metadata. Then prepend the
        # pre-stimulus and keep only the spikes that come after it.
        cell_params["v_zero"] = np.random.rand(1)[0]
        trial_spikes = simulate(np.hstack((pre_stim, signal)), **cell_params)
        return trial_spikes[trial_spikes > cutoff] - cutoff

    for contrast in stimulus_params["contrasts"]:
        params = stimulus_params.copy()
        params["contrast"] = contrast
        del params["contrasts"]
        del params["chirp_frequency"]

        for condition in ("other", "self"):
            print("\tcontrast: %s, condition: %s" % (contrast, condition), " " * 10, end="\r")
            params["condition"] = condition
            _time, self_signal, self_freq, other_signal, other_freq = get_signals(**params)
            full_signal = self_signal + other_signal
            also_self_only = condition == "self"

            spikes, no_other_spikes = [], []
            for _ in range(repeats):
                spikes.append(run_trial(full_signal))
                if also_self_only:
                    no_other_spikes.append(run_trial(self_signal))

            if also_self_only:
                no_other_name = "contrast_%.3f_condition_no-other_deltaf_%i_chirpsize_%i" % (contrast, deltaf, chirp_size)
                save(filename, no_other_name, params, cell_params, self_signal, None, self_freq, None, self_signal, no_other_spikes)
            block_name = "contrast_%.3f_condition_%s_deltaf_%i_chirpsize_%i" % (contrast, condition, deltaf, chirp_size)
            save(filename, block_name, params, cell_params, self_signal, other_signal, self_freq, other_freq, full_signal, spikes)
    print("\n")


def simulate_cell(cell_id, models):
    """Run the complete chirp-stimulation protocol for one model cell.

    First a 30 s baseline response is simulated and stored in
    ``<data_folder>/cell_<cell>.nix``. Then ``simulate_responses`` runs with
    25 repeats for every combination of chirp size and difference frequency.

    Parameters
    ----------
    cell_id : int
        Index into ``models``.
    models :
        Model parameter sets as returned by ``load_models``. Each entry needs
        at least the keys "cell", "EODf" and "deltat".
    """
    # Hz, EOD frequency of the other fish minus that of the simulated fish
    delta_frequencies = [-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200]
    chirp_sizes = [40, 60, 100]
    stimulus_params = dict(
        eodfs={"self": 0.0, "other": 0.0},  # EOD frequencies in Hz, overwritten per run
        contrasts=[20, 10, 5, 2.5, 1.25, 0.625, 0.3125],
        chirp_size=100,  # Hz, frequency excursion, overwritten per run
        chirp_duration=0.015,  # s
        chirp_amplitude_dip=0.05,  # amplitude drop during a chirp
        chirp_frequency=5,  # Hz, chirp rate
        duration=5.,  # s, total duration of the stimulus
        dt=1,  # s, simulation step size, overwritten per run
    )

    model_params = models[cell_id]
    baseline_spikes = get_baseline_response(model_params, duration=30)
    cell = model_params["cell"]
    save_baseline_response(os.path.join(data_folder, "cell_%s.nix" % cell),
                           "baseline response", baseline_spikes, model_params)

    print("Cell: %s" % cell)
    for chirp_size in chirp_sizes:
        for deltaf in delta_frequencies:
            stimulus_params.update(
                eodfs={"self": model_params["EODf"], "other": model_params["EODf"] + deltaf},
                dt=model_params["deltat"],
                chirp_size=chirp_size,
            )
            print("\t Deltaf: %i" % deltaf)
            # chirps start one chirp duration into the stimulus and recur at the chirp rate
            stimulus_params["chirp_times"] = np.arange(
                stimulus_params["chirp_duration"],
                stimulus_params["duration"] - stimulus_params["chirp_duration"],
                1. / stimulus_params["chirp_frequency"],
            )
            simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf)


def main(model_file="models.csv", max_cells=20, reserved_cores=6):
    """Simulate the chirp protocol for the first ``max_cells`` models in parallel.

    Parameters
    ----------
    model_file : str
        File passed to ``load_models``.
    max_cells : int or None
        Number of models (from the start of the list) to simulate; None
        simulates all of them.
    reserved_cores : int
        Number of CPU cores left free for other work.
    """
    models = load_models(model_file)
    # create the output folder up front, before the workers try to open files in it
    os.makedirs(data_folder, exist_ok=True)
    # joblib rejects n_jobs=0 and reads negative values as "all but N" cores,
    # so never request fewer than one worker
    num_cores = max(1, multiprocessing.cpu_count() - reserved_cores)
    n_cells = len(models[:max_cells])
    Parallel(n_jobs=num_cores)(delayed(simulate_cell)(cell_id, models) for cell_id in range(n_cells))


# Start the simulation only when this file is executed as a script.
if __name__ == "__main__":
    main()