import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
from my_util import functions as fu

from parser.CellData import CellData
from experiments.Baseline import BaselineCellData
from experiments.FiCurve import FICurveCellData, FICurveModel
import Figure_constants as consts
from fitting.ModelFit import get_best_fit

# Data directory of the example cell shown by p_unit_example(),
# fi_point_detection() and test_fi_curve_colors().
EXAMPLE_CELL = "data/final/2012-12-20-ac-invivo-1"


def main():
    """Create the figures that this script currently produces.

    Further exploratory plots are kept as commented-out calls so they can be
    switched on again quickly.
    """
    # Exploratory plots, disabled by default:
    # data_isi_histogram()
    # data_mean_freq_step_stimulus_examples()
    # data_mean_freq_step_stimulus_with_detections()
    # data_fi_curve()
    # test_fi_curve_colors()

    figure_builders = (p_unit_example, fi_point_detection, p_unit_heterogeneity)
    for build_figure in figure_builders:
        build_figure()


def p_unit_heterogeneity():
    """Compare the baseline firing of three P-units with increasing burstiness.

    One row per cell (non-bursty, bursty, strongly bursty):

    * left: the first ``duration`` seconds of the baseline voltage trace with
      the detected spikes drawn as event ticks above it;
    * right: density histogram of the baseline interspike intervals in EOD
      periods.

    The figure is written to ``consts.SAVE_FOLDER + "isi_hist_heterogeneity.pdf"``.
    """
    # Local import: only this figure needs a tick formatter.
    from matplotlib.ticker import FormatStrFormatter

    data_dir = "data/final/"
    strong_bursty_cell = "2014-01-10-ae-invivo-1"
    bursty_cell = "2014-03-19-ad-invivo-1"
    non_bursty_cell = "2012-12-21-am-invivo-1"

    cells = [non_bursty_cell, bursty_cell, strong_bursty_cell]
    last_row = len(cells) - 1

    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
    gs = gridspec.GridSpec(len(cells), 2, width_ratios=(3, 1))

    time_offset = 0
    duration = 0.1

    # Left column: a bit of trace with detected spikes.
    # Every recording is loaded only once; its EOD frequency and the ISIs
    # (already scaled to EOD periods) are kept for the histogram column.
    scaled_isis = []
    for i, cell in enumerate(cells):
        cell_data = CellData(data_dir + cell + "/")

        step = cell_data.get_sampling_interval()
        time = cell_data.get_base_traces(cell_data.TIME)[0]
        v1 = cell_data.get_base_traces(cell_data.V1)[0]
        spikes = cell_data.get_base_spikes()[0]

        idx_start = int(np.rint(time_offset / step))
        idx_end = int(np.rint((time_offset + duration) / step))
        v1_part = v1[idx_start:idx_end]

        ax = fig.add_subplot(gs[i, 0])
        ax.plot(np.array(time[idx_start:idx_end]) * 1000, v1_part, color=consts.COLOR_DATA)
        # Spike ticks are a tenth of the plotted y-range tall and sit just above the trace maximum.
        y_lims = ax.get_ylim()
        event_tick_length = (y_lims[1] - y_lims[0]) / 10
        ax.eventplot([s * 1000 for s in spikes if time_offset <= s < time_offset + duration],
                     colors="black", lineoffsets=max(v1_part) + 1.5, linelengths=event_tick_length)
        ax.set_ylabel("Voltage [mV]")
        ax.set_xlim((0, duration * 1000))
        if i == last_row:
            ax.set_xlabel("Time [ms]")
            # hand-picked ticks for the bottom (strongly bursty) trace
            ax.set_yticks([-5, 5, 15])

        eodf = cell_data.get_eod_frequency()
        scaled_isis.append((eodf, BaselineCellData(cell_data).get_interspike_intervals() * eodf))

    # Right column: ISI histograms. They are created after the whole left column
    # so the axes keep the same creation order as before (presumably relevant
    # for the panel labels added by fig.label_axes()).
    for i, (eodf, cell_isi) in enumerate(scaled_isis):
        bins = np.arange(0, 0.025, 0.0001) * eodf  # 0-25 ms in 0.1 ms steps, converted to EOD periods
        ax = fig.add_subplot(gs[i, 1])
        ax.hist(cell_isi, bins=bins, density=True, color=consts.COLOR_DATA)
        ax.set_ylabel("Density")
        # Use a formatter instead of set_yticklabels(): fixed labels without a fixed
        # locator trigger a matplotlib warning and can drift from the real ticks.
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.1f"))
        if i == last_row:
            ax.set_xlabel("ISI [EOD periods]")

    plt.tight_layout()
    fig.align_ylabels()
    consts.set_figure_labels(xoffset=-2.5)

    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "isi_hist_heterogeneity.pdf", transparent=True)
    plt.close()


def p_unit_example():
    """Draw the overview figure of the example P-unit ``EXAMPLE_CELL``.

    Panels (3x2 grid):

    * top row: 100 ms of the baseline voltage trace with detected spikes;
    * baseline ISI histogram in EOD periods;
    * serial correlation of the ISIs for lags 0-10;
    * mean frequency response to the last step stimulus;
    * f-I curve: f_0 / f_inf points with their fitted curves.

    Also prints the cell's EOD frequency and burstiness. The figure is saved
    to ``thesis/figures/p_unit_example.pdf``.
    """
    cell_data = CellData(EXAMPLE_CELL)
    print("p-unit example eodf:", cell_data.get_eod_frequency())
    base = BaselineCellData(cell_data)
    base.load_values(cell_data.get_data_path())
    print("burstiness of example cell:", base.get_burstiness())
    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), save_dir=cell_data.get_data_path())
    step = cell_data.get_sampling_interval()

    # Overview figure for p-unit behaviour
    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_LARGE)
    gs = gridspec.GridSpec(3, 2)

    # a bit of trace with detected spikes
    ax = fig.add_subplot(gs[0, :])

    v1 = cell_data.get_base_traces(cell_data.V1)[0]
    time = cell_data.get_base_traces(cell_data.TIME)[0]
    spiketimes = cell_data.get_base_spikes()[0]
    start = 0
    duration = 0.10
    # Slice from ``start`` so trace and spikes stay aligned if ``start`` is changed.
    idx_start = int(start / step)
    idx_end = idx_start + int(duration / step)
    v1_part = v1[idx_start:idx_end]

    # Pass the colour by keyword: a positional third argument is parsed as a
    # format string, which only works for colour specs given as strings.
    ax.plot((np.array(time[idx_start:idx_end]) - start) * 1000, v1_part, color=consts.COLOR_DATA)
    ax.eventplot([(s - start) * 1000 for s in spiketimes if start < s < start + duration],
                 lineoffsets=max(v1_part) + 1.25, colors="black", linelengths=2)
    ax.set_ylabel('Voltage [mV]')
    ax.set_xlabel('Time [ms]')
    ax.set_title("Baseline Firing")
    ax.set_xlim((0, duration * 1000))

    # ISI-hist
    ax = fig.add_subplot(gs[1, 0])
    eod_period = 1.0 / cell_data.get_eod_frequency()

    isi = np.array(base.get_interspike_intervals()) / eod_period  # ISI in EOD periods
    bins = np.arange(0, max(isi) * 1.01, 0.1)
    ax.hist(isi, bins=bins, color=consts.COLOR_DATA, density=True)
    ax.set_ylabel("Density")
    ax.set_xlabel("ISI [EOD periods]")
    ax.set_title("ISI Histogram")

    # Serial correlation
    ax = fig.add_subplot(gs[1, 1])

    max_lag = 10
    lags = range(max_lag + 1)
    sc = base.get_serial_correlation(max_lag)
    sc_with_lag0 = [1] + list(sc)  # correlation at lag 0 is 1
    ax.plot(lags, [0] * len(lags), color="darkgrey", alpha=0.8)
    ax.plot(lags, sc_with_lag0, color=consts.COLOR_DATA)
    ax.plot(lags, sc_with_lag0, '+', color="black")

    ax.set_xlabel("Lag")
    ax.set_ylabel("SC")
    ax.set_title("Serial Correlation")
    ax.set_ylim((-1, 1))
    ax.set_xlim((0, max_lag))
    ax.set_xticks(list(range(0, max_lag + 1, 2)))

    # FI-Curve trace
    ax = fig.add_subplot(gs[2, 0])

    f_trace_times, f_traces = fi.get_mean_time_and_freq_traces()

    part = 0.4 + 0.2 + 0.2  # stim duration + delay up front and a part of the "delay" at the back
    idx = int(part / step)

    ax.plot(f_trace_times[-1][:idx], f_traces[-1][:idx], color=consts.COLOR_DATA)
    ax.set_xlim((-0.2, part - 0.2))
    ax.set_ylim((0, ax.get_ylim()[1]))
    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Frequency [Hz]")
    ax.set_title("Step Response")

    # FI-Curve
    ax = fig.add_subplot(gs[2, 1])

    contrasts = fi.stimulus_values
    f_zeros = fi.get_f_zero_frequencies()
    f_infties = fi.get_f_inf_frequencies()

    ax.plot(contrasts, f_zeros, ',', marker=consts.f0_marker, color=consts.COLOR_DATA_f0)
    ax.plot(contrasts, f_infties, ',', marker=consts.finf_marker, color=consts.COLOR_DATA_finf)

    # linspace includes both ends exactly; arange with a float step does not
    # guarantee the endpoint and fails when all contrasts are equal (step 0).
    x_values = np.linspace(min(contrasts), max(contrasts), 1001)
    f_zero_fit = [fu.full_boltzmann(x, fi.f_zero_fit[0], fi.f_zero_fit[1], fi.f_zero_fit[2], fi.f_zero_fit[3])
                  for x in x_values]
    f_inf_fit = [fu.clipped_line(x, fi.f_inf_fit[0], fi.f_inf_fit[1]) for x in x_values]
    ax.plot(x_values, f_zero_fit, color=consts.COLOR_DATA_f0)
    ax.plot(x_values, f_inf_fit, color=consts.COLOR_DATA_finf)

    ax.set_xlabel("Contrast")
    ax.set_ylabel("Frequency [Hz]")
    ax.set_xticks([-0.2, -0.1, 0, 0.1, 0.2])
    ax.set_xlim((-0.21, 0.2))
    ax.set_ylim((0, ax.get_ylim()[1]))
    ax.set_title("f-I Curve")

    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5, yoffset=2.2)
    fig.label_axes()
    plt.savefig("thesis/figures/p_unit_example.pdf", transparent=True)
    plt.close()


def fi_point_detection():
    """Illustrate the detected f_0, f_inf and baseline frequencies of the example cell.

    Left panel: the mean frequency response of ``EXAMPLE_CELL`` to its last
    step stimulus. On top of it are:

    * the detected f_0 point;
    * thick horizontal bars for f_inf (orange-ish/finf colour) and the baseline
      frequency (grey), each drawn between its two detected indices;
    * a black bar marking the stimulus period.

    Right panel: the f-I curve with its fits, sharing the y-axis with the left
    panel. Saved to ``thesis/figures/f_point_detection.pdf``.
    """
    cell_data = CellData(EXAMPLE_CELL)

    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts())
    step = cell_data.get_sampling_interval()

    fig, axes = plt.subplots(1, 2, figsize=consts.FIG_SIZE_MEDIUM_WIDE, sharey="row")

    f_trace_times, f_traces = fi.get_mean_time_and_freq_traces()
    # Only the response to the last contrast is shown.
    trace_times = f_trace_times[-1]
    trace_freqs = f_traces[-1]

    part = 0.4 + 0.2 + 0.2  # stim duration + delay up front and a part of the "delay" at the back
    n_samples = int(part / step)

    f_zero = fi.get_f_zero_frequencies()[-1]
    f_zero_idx = fi.indices_f_zero[-1]
    f_inf = fi.get_f_inf_frequencies()[-1]
    f_inf_idx = fi.indices_f_inf[-1]
    f_baseline = fi.get_f_baseline_frequencies()[-1]
    f_base_idx = fi.indices_f_baseline[-1]

    axes[0].plot(trace_times[:n_samples], trace_freqs[:n_samples], color=consts.COLOR_DATA)
    axes[0].plot([trace_times[i] for i in f_zero_idx], (f_zero,), ",",
                 marker=consts.f0_marker, color=consts.COLOR_DATA_f0)
    axes[0].plot([trace_times[i] for i in f_inf_idx], (f_inf, f_inf),
                 color=consts.COLOR_DATA_finf, linewidth=4)
    axes[0].plot([trace_times[i] for i in f_base_idx], (f_baseline, f_baseline),
                 color="grey", linewidth=4)

    # mark stim start and end with a bar at 100 Hz
    stim_start = cell_data.get_stimulus_start()
    stim_end = cell_data.get_stimulus_end()
    axes[0].plot([stim_start, stim_end], (100, 100), color="black", linewidth=3)
    axes[0].set_xlim((-0.2, part - 0.2))
    # Read the upper limit before the f-I data is added to the shared y-axis,
    # so it is determined by the step response alone.
    ylimits = axes[0].get_ylim()

    axes[0].set_xlabel("Time [s]")
    axes[0].set_ylabel("Frequency [Hz]")
    axes[0].set_title("Step Response")

    contrasts = fi.stimulus_values
    f_zeros = fi.get_f_zero_frequencies()
    f_infties = fi.get_f_inf_frequencies()

    axes[1].plot(contrasts, f_zeros, ",", marker=consts.f0_marker, color=consts.COLOR_DATA_f0)
    axes[1].plot(contrasts, f_infties, ",", marker=consts.finf_marker, color=consts.COLOR_DATA_finf)

    # linspace includes both ends exactly; arange with a float step does not
    # guarantee the endpoint and fails when all contrasts are equal (step 0).
    x_values = np.linspace(min(contrasts), max(contrasts), 1001)
    f_zero_fit = [fu.full_boltzmann(x, fi.f_zero_fit[0], fi.f_zero_fit[1], fi.f_zero_fit[2], fi.f_zero_fit[3])
                  for x in x_values]
    f_inf_fit = [fu.clipped_line(x, fi.f_inf_fit[0], fi.f_inf_fit[1]) for x in x_values]
    axes[1].plot(x_values, f_zero_fit, color=consts.COLOR_DATA_f0)
    axes[1].plot(x_values, f_inf_fit, color=consts.COLOR_DATA_finf)

    axes[1].set_xlabel("Contrast")
    axes[1].set_title("f-I Curve")
    # y is shared ("row"), so this sets the limits of both panels.
    axes[1].set_ylim((0, ylimits[1]))

    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    fig.label_axes()
    plt.savefig("thesis/figures/f_point_detection.pdf", transparent=True)
    plt.close()


def data_fi_curve():
    """Plot the f-I curve of cell 2013-04-17-ac-invivo-1 via ``FICurveCellData.plot_fi_curve``."""
    recording = CellData("data/final/2013-04-17-ac-invivo-1/")
    fi_curve = FICurveCellData(recording, recording.get_fi_contrasts())
    fi_curve.plot_fi_curve()


def closest_time_to_value(time, values, target):
    """Return ``time[i]`` for the first index ``i`` at which *values* is closest to *target*.

    NaN entries in *values* are ignored; raises ``ValueError`` if all are NaN.
    """
    distances = np.abs(np.asarray(values, dtype=float) - target)
    return time[int(np.nanargmin(distances))]


def data_mean_freq_step_stimulus_with_detections():
    """Show the mean step response of one cell with its detected f_0 and f_inf.

    The response to the last contrast of cell 2013-04-17-ac-invivo-1 is drawn.
    f_0 is marked as a black dot, and f_inf as a black line over 0.2 s < t < 0.4 s.
    """
    cell = "data/final/2013-04-17-ac-invivo-1/"
    cell_data = CellData(cell)
    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts())

    mean_times, mean_freqs = fi.get_mean_time_and_freq_traces()
    idx = -1  # last contrast
    time = np.array(mean_times[idx])
    freq = np.array(mean_freqs[idx])
    f_inf = fi.get_f_inf_frequencies()[idx]
    f_zero = fi.get_f_zero_frequencies()[idx]

    plt.plot(time, freq, color=consts.COLOR_DATA)
    # Use a nearest-value lookup: an exact float-equality mask raises IndexError
    # when f_zero is not bit-identical to one of the trace samples.
    plt.plot(closest_time_to_value(time, freq, f_zero), f_zero, "o", color="black")
    f_inf_time = time[(0.2 < time) & (time < 0.4)]
    plt.plot(f_inf_time, [f_inf] * len(f_inf_time), color="black")
    plt.xlim((-0.1, 0.6))
    plt.show()


def data_mean_freq_step_stimulus_examples():
    """Show the mean frequency traces of one cell for a few step-stimulus contrasts.

    There is one row per selected contrast index (0, 7 and the last one), with
    shared x and y axes. Each title gives the contrast and the number of trials.
    """
    # todo smooth! add f_0, f_inf, f_base to it?
    cell = "data/invivo/2013-04-17-ac-invivo-1/"
    cell_data = CellData(cell)
    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts())

    time_traces, _freq_traces = fi.get_time_and_freq_traces()
    mean_times, mean_freqs = fi.get_mean_time_and_freq_traces()

    used_indices = (0, 7, -1)
    # squeeze=False keeps ``axes`` indexable even if only one contrast is selected.
    _, axes = plt.subplots(len(used_indices), figsize=(8, 12), sharex=True, sharey=True, squeeze=False)
    axes = axes[:, 0]
    for ax, idx in zip(axes, used_indices):
        sv = fi.stimulus_values[idx]
        ax.plot(mean_times[idx], mean_freqs[idx], color=consts.COLOR_DATA)
        ax.set_ylabel("Frequency [Hz]")
        ax.set_xlim((-0.2, 0.6))
        ax.set_title("Contrast {:.2f} ({:} trials)".format(sv, len(time_traces[idx])))
    axes[-1].set_xlabel("Time [s]")
    plt.show()


def data_isi_histogram(recalculate=True):
    """Plot a count histogram of one cell's baseline interspike intervals in ms.

    The ISIs are cached in ``isi_cell_data.npy`` inside ``consts.SAVE_FOLDER``.
    The cache is read only if the file exists and ``recalculate`` is false.
    Otherwise the ISIs are recomputed from the recording and the cache file is
    overwritten. The cache file name does not depend on the selected cell.
    """
    cache_path = os.path.join(consts.SAVE_FOLDER, "isi_cell_data.npy")
    use_cache = os.path.exists(cache_path) and not recalculate

    if use_cache:
        isis = np.load(cache_path)
        print("loaded")
    else:
        cell = "data/invivo/2013-04-17-ac-invivo-1/"  # not bursty
        # Alternative cells:
        # cell = "data/invivo/2014-12-03-ad-invivo-1/"  # half bursty
        # cell = "data/invivo/2015-01-20-ad-invivo-1/"  # does triple peaks...
        # cell = "data/invivo/2018-05-08-ae-invivo-1/"  # a bit bursty
        # cell = "data/invivo/2013-04-10-af-invivo-1/"  # a bit bursty
        baseline = BaselineCellData(CellData(cell))
        isis = np.array(baseline.get_interspike_intervals())
        np.save(cache_path, isis)

    isis_ms = isis * 1000
    histogram_bins = np.arange(0, 30.1, 0.1)
    plt.hist(isis_ms, bins=histogram_bins, color=consts.COLOR_DATA)
    plt.xlabel("Inter spike intervals [ms]")
    plt.ylabel("Count")
    plt.tight_layout()

    plt.show()


def _evaluate_fi_fits(fi_curve, x_values):
    """Return the f_0 (full Boltzmann) and f_inf (clipped line) fit curves of *fi_curve* at *x_values*."""
    f_zero_params = fi_curve.f_zero_fit
    f_inf_params = fi_curve.f_inf_fit
    f_zero_curve = [fu.full_boltzmann(x, f_zero_params[0], f_zero_params[1], f_zero_params[2], f_zero_params[3])
                    for x in x_values]
    f_inf_curve = [fu.clipped_line(x, f_inf_params[0], f_inf_params[1]) for x in x_values]
    return f_zero_curve, f_inf_curve


def test_fi_curve_colors():
    """Visually compare the data and model f-I curve colours for the example cell.

    * Panel "Cell": measured f_0 / f_inf points and fits in the data colours.
    * Panel "Model": the same for the best model fit, in the model colours.
    * Third panel: the cell's fitted curves overlaid with the model's points.
    """
    example_cell_fit = "results/final_2/2012-12-20-ac-invivo-1"

    cell_data = CellData(EXAMPLE_CELL)
    fit = get_best_fit(example_cell_fit)

    _, axes = plt.subplots(1, 3)

    axes[0].set_title("Cell")
    cell_fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), save_dir=cell_data.get_data_path())
    contrasts = cell_data.get_fi_contrasts()
    axes[0].plot(contrasts, cell_fi.get_f_zero_frequencies(), ',',
                 marker=consts.f0_marker, color=consts.COLOR_DATA_f0)
    axes[0].plot(contrasts, cell_fi.get_f_inf_frequencies(), ',',
                 marker=consts.finf_marker, color=consts.COLOR_DATA_finf)

    # linspace includes the largest contrast, so the fits span all data points
    # (consistent with the other f-I figures in this file).
    x_values = np.linspace(min(contrasts), max(contrasts), 1001)
    cell_f_zero_fit, cell_f_inf_fit = _evaluate_fi_fits(cell_fi, x_values)
    axes[0].plot(x_values, cell_f_zero_fit, color=consts.COLOR_DATA_f0)
    axes[0].plot(x_values, cell_f_inf_fit, color=consts.COLOR_DATA_finf)

    axes[2].plot(x_values, cell_f_zero_fit, color=consts.COLOR_DATA_f0)
    axes[2].plot(x_values, cell_f_inf_fit, color=consts.COLOR_DATA_finf)

    axes[1].set_title("Model")
    model_fi = FICurveModel(fit.get_model(), contrasts, eod_frequency=cell_data.get_eod_frequency())

    model_f_zeros = model_fi.get_f_zero_frequencies()
    model_f_infs = model_fi.get_f_inf_frequencies()
    axes[1].plot(contrasts, model_f_zeros, ',', marker=consts.f0_marker, color=consts.COLOR_MODEL_f0)
    axes[1].plot(contrasts, model_f_infs, ',', marker=consts.finf_marker, color=consts.COLOR_MODEL_finf)

    model_f_zero_fit, model_f_inf_fit = _evaluate_fi_fits(model_fi, x_values)
    axes[1].plot(x_values, model_f_zero_fit, color=consts.COLOR_MODEL_f0)
    axes[1].plot(x_values, model_f_inf_fit, color=consts.COLOR_MODEL_finf)

    axes[2].plot(contrasts, model_f_zeros, ",", marker=consts.f0_marker, color=consts.COLOR_MODEL_f0)
    axes[2].plot(contrasts, model_f_infs, ",", marker=consts.finf_marker, color=consts.COLOR_MODEL_finf)

    plt.show()
    plt.close()


if __name__ == '__main__':
    # Build the figures only when run as a script, not when imported.
    main()