import helperFunctions as hf
from CellData import icelldata_of_dir
import functions as fu
import numpy as np
import time
import matplotlib.pyplot as plt
import os
from scipy.signal import argrelmax
from thunderfish.eventdetection import detect_peaks
from stimuli.SinusAmplitudeModulation import SinusAmplitudeModulationStimulus
from models.LIFACnoise import LifacNoiseModel
from FiCurve import FICurveModel, get_fi_curve_class
from Baseline import get_baseline_class
from AdaptionCurrent import Adaption
from stimuli.StepStimulus import StepStimulus
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus


def time_test_function():
    """Benchmark element-wise ``fu.rectify`` on random samples.

    For every configured call count, 10000 normally distributed samples are
    drawn and rectified one by one that many times; the elapsed wall-clock
    time is printed afterwards.
    """
    call_counts = [1000]  # number of calls
    for calls in call_counts:
        print("calls:", calls)
        started_at = time.time()
        for _ in range(calls):
            samples = np.random.normal(size=10000)
            rectified = [fu.rectify(sample) for sample in samples]
        elapsed = time.time() - started_at
        print("time:", elapsed)


def test_cell_data():
    """Print each cell's data path and, if it has baseline traces, its EOD frequency.

    Iterates over every cell found under ``../data/``. Cells whose baseline
    TIME traces are empty only get their path printed.
    """
    for cell in icelldata_of_dir("../data/"):
        print()
        print(cell.get_data_path())

        baseline_times = cell.get_base_traces(cell.TIME)
        if len(baseline_times) == 0:
            # No baseline recording for this cell -- nothing more to report.
            continue

        print("Eod freq:", cell.get_eod_frequency())


def test_peak_detection(pieces=20, threshold_factor=3):
    """Visually check piece-wise spike-peak detection on the V1 baseline traces.

    Each V1 trace of every cell under ``../data/`` is split into ``pieces``
    consecutive segments. Peaks are detected in each segment with a threshold
    of ``threshold_factor`` times that segment's standard deviation. The trace
    is plotted with the segment starts (green) and detected peaks (red).

    Args:
        pieces: number of segments each trace is split into (>= 1).
        threshold_factor: multiple of the segment standard deviation used as
            the ``detect_peaks`` threshold.
    """
    for cell_data in icelldata_of_dir("../data/"):
        print()
        print(cell_data.get_data_path())
        times = cell_data.get_base_traces(cell_data.TIME)
        v1 = cell_data.get_base_traces(cell_data.V1)

        for time_trace, v1_trace in zip(times, v1):
            time_trace = np.asarray(time_trace)
            v1_trace = np.asarray(v1_trace)
            total = len(v1_trace)
            length = total // pieces  # loop invariant: computed once per trace
            all_peaks = []
            plt.plot(time_trace, v1_trace)

            for n in range(pieces):
                first_index = n * length
                # The last piece absorbs the remainder so no tail samples are
                # silently skipped when total is not divisible by pieces.
                last_index = total if n == pieces - 1 else (n + 1) * length
                segment = v1_trace[first_index:last_index]
                if len(segment) == 0:
                    continue

                plt.plot(time_trace[first_index], v1_trace[first_index], 'o', color="green")

                std = np.std(segment)
                if std <= 0:
                    # Flat segment: no peaks possible, and detect_peaks needs a
                    # positive threshold.
                    continue
                peaks, _ = detect_peaks(segment, std * threshold_factor)
                all_peaks.extend(np.asarray(peaks) + first_index)

            # Explicit int dtype: an empty list would otherwise become a float
            # array, which cannot be used as an index.
            all_peaks = np.asarray(all_peaks, dtype=int)

            plt.plot(time_trace[all_peaks], v1_trace[all_peaks], 'o', color='red')

            plt.show()


def test_simulation_speed():
    """Time a 10 s LIFAC noise-model simulation under an amplitude-modulated stimulus.

    Prints the voltage-trace length and spike count of each run, then the mean
    wall-clock time per run together with the integration step size in ms.
    """
    parameters = {
        'mem_tau': 21.348990483539083, 'delta_a': 20.41809814660199,
        'input_scaling': 3.0391541280864196, 'v_offset': 26.25, 'threshold': 1,
        'v_base': 0, 'step_size': 0.00005, 'tau_a': 158.0404259501454,
        'a_zero': 0, 'v_zero': 0, 'noise_strength': 2.87718460648148,
    }
    model = LifacNoiseModel(parameters)
    repetitions, seconds = 1, 10
    stimulus = SinusAmplitudeModulationStimulus(750, 1, 10, 1, 8)
    time_start = 0

    began = time.time()
    for _ in range(repetitions):
        voltage, spike_times = model.simulate(stimulus, seconds, time_start)
        print(len(voltage))
        print(len(spike_times))
    finished = time.time()

    seconds_per_run = round((finished - began) / repetitions, 5)
    step_size_ms = parameters["step_size"] * 1000
    print("took:", seconds_per_run, "seconds for " + str(seconds) + "s simulation",
          "step size:", step_size_ms, "ms")


def test_fi_curve_class(include_cell_data=False):
    """Plot the mean-frequency curves of an f-I curve computed from a fitted model.

    Contrasts from -0.4 up to (excluding) 0.4 in steps of 0.05 are passed to
    ``get_fi_curve_class`` together with a base stimulus frequency of 700.

    Args:
        include_cell_data: if True, also plot the mean-frequency curves for
            every cell under ``../data/`` using that cell's own FI contrasts.
            The default False reproduces the previous behaviour, in which this
            loop sat unreachable behind a bare ``return``.
    """
    model_parameters = {'v_offset': -15.234375, 'input_scaling': 64.94152780134829, 'step_size': 5e-05, 'a_zero': 2,
                        'threshold': 1, 'v_base': 0, 'delta_a': 0.04763179657857666, 'tau_a': 0.07891848949732623,
                        'mem_tau': 0.004828473985707999, 'noise_strength': 0.017132801387559883,
                        'v_zero': 0, 'dend_tau': 0.0015230454266819539}

    model = LifacNoiseModel(model_parameters)
    contrasts = np.arange(-0.4, 0.4, 0.05)

    ficurve = get_fi_curve_class(model, contrasts, 700)
    ficurve.plot_mean_frequency_curves()

    if not include_cell_data:
        return

    for cell_data in icelldata_of_dir("../data/"):
        fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts())
        fi_curve.plot_mean_frequency_curves()
        # fi_curve.plot_f_point_detections()


def test_adaption_class(include_cell_data=False):
    """Check how well ``Adaption`` recovers the model's ``delta_a``.

    The model's ``delta_a`` is swept over 0.1, 0.2, ..., 1.4. For each value
    an ``FICurveModel`` is built (contrasts -0.4 up to, excluding, 1 in steps
    of 0.05; arguments 750 and 10 as before). The ``Adaption`` estimate of
    ``delta_a`` is then printed next to the true value and their ratio ("error").
    The f-I curve is written to ``../figures/error_plots/``.

    Args:
        include_cell_data: if True, additionally run the adaption fit on every
            cell under ``../data/``. Defaults to False. That section used to sit
            unreachable behind ``quit()``, which also terminated the whole
            interpreter when this function was called.
    """
    model_parameters = {'v_offset': -15.234375, 'input_scaling': 64.94152780134829, 'step_size': 5e-05, 'a_zero': 2,
                        'threshold': 1, 'v_base': 0, 'delta_a': 0.04763179657857666, 'tau_a': 0.07891848949732623,
                        'mem_tau': 0.004828473985707999, 'noise_strength': 0.017132801387559883,
                        'v_zero': 0, 'dend_tau': 0.0015230454266819539}

    model = LifacNoiseModel(model_parameters)
    contrasts = np.arange(-0.4, 1, 0.05)

    figure_dir = "../figures/error_plots/"
    # Assumes plot_fi_curve saves to the given path -- make sure the directory exists.
    os.makedirs(figure_dir, exist_ok=True)

    for delta_a in np.arange(0.1, 1.5, 0.1):
        model.set_variable("delta_a", delta_a)
        # To sweep tau_a instead, iterate over np.arange(0.01, 0.1, 0.01) and
        # call model.set_variable("tau_a", tau_a).
        fi_curve = FICurveModel(model, contrasts, 750, 10)
        adaption = Adaption(fi_curve)
        # adaption.plot_exponential_fits()

        # tau values are only needed for the (disabled) tau print below; the
        # get_tau_real() call is kept in case the fits depend on it.
        m_tau = model.get_parameters()["tau_a"]
        approx_tau = adaption.get_tau_real()
        m_delta_a = model.get_parameters()["delta_a"]
        approx_delta_a = adaption.get_delta_a()

        fi_curve.plot_fi_curve(figure_dir + "adaption_test_{:.2f}_delta_a_with_{:.2f}_error.png".format(
            delta_a, approx_delta_a / m_delta_a))
        # print("model tau_a  \t: {:.4f} vs {:.4f} adaption estimate, error: {:.2}".format(m_tau, approx_tau, (approx_tau / m_tau)))
        print("model delta_a\t: {:.4f} vs {:.4f} adaption estimate, error: {:.2}".format(m_delta_a, approx_delta_a, (approx_delta_a / m_delta_a)))
        # NOTE(review): meaning of index 3 of f_zero_fit is defined in FiCurve -- confirm.
        print(fi_curve.f_zero_fit[3])

    if not include_cell_data:
        return

    for cell_data in icelldata_of_dir("../data/"):
        print()
        print(cell_data.get_data_path())
        # `FICurve` was never imported (NameError); use the same factory as
        # test_fi_curve_class to obtain the cell's f-I curve.
        fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts())
        adaption = Adaption(fi_curve)

        adaption.plot_exponential_fits()

        print("tau_effs:", adaption.get_tau_effs())
        print("tau_real:", adaption.get_tau_real())
        fi_curve.plot_fi_curve()


def test_parameters():
    """Simulate a hand-picked parameter set, plot it, and print its markers.

    The model is first plotted for 30 s under an amplitude-modulated stimulus
    at base frequency 350. Its baseline markers (frequency, vector strength,
    first serial correlation) and f-infinity markers for contrasts 0.7-1.3
    are then printed.
    """
    model_parameters = {
        'mem_tau': 21., 'delta_a': 0.1, 'input_scaling': 400.,
        'v_offset': 85.25, 'threshold': 0.1, 'v_base': 0, 'step_size': 0.00005, 'tau_a': 0.01,
        'a_zero': 0, 'v_zero': 0, 'noise_strength': 3,
    }
    lifac = LifacNoiseModel(model_parameters)

    eod_freq = 350
    plot_model_during_stimulus(lifac, SinusAmplitudeModulationStimulus(eod_freq, 1.2, 5, 5, 20), 30)

    baseline_freq, vector_strength, serial_corr = lifac.calculate_baseline_markers(eod_freq)
    fi_contrasts = [0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3]
    f_infinities, f_infinity_slope = lifac.calculate_fi_markers(fi_contrasts, eod_freq, 1)

    print("Baseline frequency: {:.2f}".format(baseline_freq))
    print("Vector strength:    {:.2f}".format(vector_strength))
    print("serial correlation: {:.2f}".format(serial_corr[0]))

    print("f infinity slope:   {:.2f}".format(f_infinity_slope))
    print("f infinities: \n", f_infinities)


def test_vector_strength_calculation(eod_frequency=600, duration=30):
    """Compare two vector-strength estimates for a noise-free model.

    The first estimate comes from ``calculate_baseline_markers`` (the
    "assumed eod durations" variant). The second is computed with
    ``hf.calculate_vector_strength_from_spiketimes`` from spikes of a fresh
    simulation under an unmodulated stimulus (the "detected eod durations"
    variant).

    Args:
        eod_frequency: base stimulus frequency used for both estimates; keeping
            it in one place guarantees both sides describe the same condition.
        duration: simulated time passed to ``simulate`` / ``as_array``.
    """
    model = LifacNoiseModel({"noise_strength": 0})
    sampling_interval = model.get_sampling_interval()

    bf, vs1, sc = model.calculate_baseline_markers(eod_frequency)

    base_stim = SinusAmplitudeModulationStimulus(eod_frequency, 0, 0)
    _, spiketimes = model.simulate(base_stim, duration)
    stimulus_trace = base_stim.as_array(0, duration, sampling_interval)
    # Derive the time axis from the stimulus length: a separate float np.arange
    # can differ by one sample from the stimulus array.
    time_trace = np.arange(len(stimulus_trace)) * sampling_interval

    vs2 = hf.calculate_vector_strength_from_spiketimes(time_trace, stimulus_trace, spiketimes, sampling_interval)

    print("with assumed eod durations  vs: {:.3f}".format(vs1))
    print("with detected eod durations vs: {:.3f}".format(vs2))


def test_baseline_polar_plot():
    """Show the polar vector-strength plot of a fitted model's baseline response.

    The second argument 700 is passed straight to ``get_baseline_class``
    (presumably the base/EOD frequency in Hz -- confirm against Baseline).
    """
    fitted_model = LifacNoiseModel({
        'v_offset': -15.234375, 'input_scaling': 64.94152780134829, 'step_size': 5e-05, 'a_zero': 2,
        'threshold': 1, 'v_base': 0, 'delta_a': 0.04763179657857666, 'tau_a': 0.07891848949732623,
        'mem_tau': 0.004828473985707999, 'noise_strength': 0.017132801387559883,
        'v_zero': 0, 'dend_tau': 0.0015230454266819539,
    })

    get_baseline_class(fitted_model, 700).plot_polar_vector_strength()

    # Disabled cell-data variant: for each `data` in icelldata_of_dir("../data/"),
    # skip cells whose data.get_base_traces(trace_type=data.V1) is empty
    # (printing "NO V1 TRACE FOUND"), otherwise call
    # get_baseline_class(data).plot_polar_vector_strength().



def plot_model_during_stimulus(model: LifacNoiseModel, stimulus: SinusAmplitudeModulationStimulus, total_time):
    """Simulate ``model`` under ``stimulus`` and plot its internal traces.

    Five vertically stacked panels with a shared x axis show the stimulus, the
    rectified stimulus, the model voltage, the adaption trace, and the spike
    frequency over time.

    Args:
        model: model to simulate; its voltage/adaption traces are read after
            ``simulate`` returns.
        stimulus: stimulus passed to ``model.simulate`` and sampled via
            ``as_array``.
        total_time: simulated duration, in the same time unit as
            ``model.get_sampling_interval()``.
    """
    _, spiketimes = model.simulate(stimulus, total_time)
    sampling_interval = model.get_sampling_interval()

    stimulus_array = stimulus.as_array(0, total_time, sampling_interval)
    panels = [
        (stimulus_array, "Stimulus"),
        (rectify_stimulus_array(stimulus_array), "rectified Stimulus"),
        (model.get_voltage_trace(), "Voltage"),
        (model.get_adaption_trace(), "Adaption"),
    ]

    _, axes = plt.subplots(5, 1, figsize=(9, 4 * 2), sharex="all")
    for ax, (trace, title) in zip(axes, panels):
        # Each trace gets a time axis of its own length (named time_axis so the
        # `time` module is not shadowed). A single shared float np.arange can be
        # one sample longer/shorter than a trace and make plot() raise.
        time_axis = np.arange(len(trace)) * sampling_interval
        ax.plot(time_axis, trace)
        ax.set_title(title)

    f_time, f = hf.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
    axes[4].plot(f_time, f)
    axes[4].set_title("Frequency")

    plt.show()


def rectify_stimulus_array(stimulus_array: np.ndarray) -> np.ndarray:
    """Half-wave rectify a stimulus: keep positive samples, set all others to 0.

    Vectorised with ``np.where``. The result keeps the input's dtype; the old
    element-wise list version returned an *int* array when every sample was
    non-positive. Inputs of any dimensionality are accepted. NaN samples
    become 0, because ``nan > 0`` is False.

    Args:
        stimulus_array: array-like of stimulus samples.

    Returns:
        Array of the same shape as ``stimulus_array`` with non-positive
        entries replaced by 0.
    """
    samples = np.asarray(stimulus_array)
    return np.where(samples > 0, samples, 0)


if __name__ == '__main__':
    # Quick interactive check: plot the baseline response and the f-I curve of
    # one fitted model (second argument 650 passed to both factories).
    model_parameters = {'v_offset': -15.234375, 'input_scaling': 64.94152780134829, 'step_size': 5e-05, 'a_zero': 2,
                        'threshold': 1, 'v_base': 0, 'delta_a': 0.04763179657857666, 'tau_a': 0.07891848949732623,
                        'mem_tau': 0.004828473985707999, 'noise_strength': 0.017132801387559883,
                        'v_zero': 0, 'dend_tau': 0.0015230454266819539}

    model = LifacNoiseModel(model_parameters)

    base = get_baseline_class(model, 650)
    base.plot_baseline()
    fi = get_fi_curve_class(model, np.arange(0, 2, 0.2), 650)
    fi.plot_fi_curve()

    # Further manual tests -- uncomment the ones to run. (These used to sit
    # after a `quit()` call, so none of them ever executed.)
    # test_baseline_polar_plot()
    # time_test_function()
    # test_cell_data()
    # test_peak_detection()
    # test_simulation_speed()
    # test_parameters()
    # test_fi_curve_class()
    # test_adaption_class()
    # test_vector_strength_calculation()