import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import functions as fu
import helperFunctions as hF

from CellData import CellData
from Baseline import BaselineCellData
from FiCurve import FICurveCellData, FICurveModel
import Figure_constants as consts
from ModelFit import get_best_fit

# Data directory of the example cell used by p_unit_example, fi_point_detection
# and test_fi_curve_colors.
EXAMPLE_CELL = "data/final/2012-12-20-ac-invivo-1"

def main():
    """Build the thesis figures that are currently enabled.

    The commented-out calls are exploratory plots (they open interactive
    windows via plt.show) and can be re-enabled when needed.
    """
    # Disabled exploratory plots:
    # data_isi_histogram()
    # data_mean_freq_step_stimulus_examples()
    # data_mean_freq_step_stimulus_with_detections()
    # data_fi_curve()
    # test_fi_curve_colors()

    enabled_figures = (p_unit_example, fi_point_detection, p_unit_heterogeneity)
    for build_figure in enabled_figures:
        build_figure()


def p_unit_heterogeneity():
    """Plot baseline traces and ISI histograms of three cells with increasing burstiness.

    Left column: the first ``duration`` seconds of each cell's first baseline
    voltage trace, with the detected spikes marked above it. Right column: the
    baseline ISI histogram of the same cell in units of EOD periods. Rows go
    from the non-bursty to the strongly bursty cell. The figure is saved as
    ``isi_hist_heterogeneity.pdf`` in ``consts.SAVE_FOLDER``.
    """
    data_dir = "data/final/"
    strong_bursty_cell = "2014-01-10-ae-invivo-1"
    bursty_cell = "2014-03-19-ad-invivo-1"
    non_bursty_cell = "2012-12-21-am-invivo-1"

    cells = [non_bursty_cell, bursty_cell, strong_bursty_cell]
    # Load every cell only once; both columns below reuse the same objects
    # (previously each cell was loaded from disk twice).
    cells_data = [CellData(data_dir + cell + "/") for cell in cells]
    last_row = len(cells_data) - 1

    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
    gs = gridspec.GridSpec(3, 2, width_ratios=(3, 1))

    time_offset = 0
    duration = 0.1

    # Left column: a bit of trace with detected spikes.
    # The whole left column is created before the right one so the axes order
    # (presumably what fig.label_axes() uses for panel letters) is unchanged.
    for i, cell_data in enumerate(cells_data):
        step = cell_data.get_sampling_interval()
        time = cell_data.get_base_traces(cell_data.TIME)[0]
        v1 = cell_data.get_base_traces(cell_data.V1)[0]
        spikes = cell_data.get_base_spikes()[0]

        idx_start = int(np.rint(time_offset / step))
        idx_end = int(np.rint((time_offset + duration) / step))
        v1_window = v1[idx_start:idx_end]

        ax = fig.add_subplot(gs[i, 0])
        ax.plot(np.array(time[idx_start:idx_end]) * 1000, v1_window, color=consts.COLOR_DATA)
        y_lims = ax.get_ylim()
        event_tick_length = (y_lims[1] - y_lims[0]) / 10
        ax.eventplot([s * 1000 for s in spikes if time_offset <= s < time_offset + duration],
                     colors="black", lineoffsets=max(v1_window) + 1.5, linelengths=event_tick_length)
        ax.set_ylabel("Voltage [mV]")
        # The time axis is absolute, so the visible window follows time_offset.
        ax.set_xlim((time_offset * 1000, (time_offset + duration) * 1000))
        if i == last_row:
            ax.set_xlabel("Time [ms]")
            ax.set_yticks([-5, 5, 15])

    # Right column: ISI histograms in EOD periods.
    for i, cell_data in enumerate(cells_data):
        eodf = cell_data.get_eod_frequency()

        # ISIs (presumably in seconds) times EOD frequency -> EOD periods.
        cell_isi = BaselineCellData(cell_data).get_interspike_intervals() * eodf
        bins = np.arange(0, 0.025, 0.0001) * eodf
        ax = fig.add_subplot(gs[i, 1])
        ax.hist(cell_isi, bins=bins, density=True, color=consts.COLOR_DATA)
        ax.set_ylabel("Density")
        # A format string keeps one-decimal labels in sync with whatever ticks
        # the locator picks; fixed tick labels on auto ticks trigger
        # matplotlib's FixedFormatter warning and can mismatch the ticks.
        ax.yaxis.set_major_formatter("{x:.1f}")
        if i == last_row:
            ax.set_xlabel("ISI [EOD periods]")

    plt.tight_layout()
    fig.align_ylabels()
    consts.set_figure_labels(xoffset=-2.5)

    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "isi_hist_heterogeneity.pdf", transparent=True)
    plt.close()


def p_unit_example():
    """Create the overview figure of the example cell's baseline and step behaviour.

    Panels:
      * top: a ``duration``-long window of the first baseline voltage trace
        with the detected spikes marked above it,
      * middle left: baseline ISI histogram in EOD periods,
      * middle right: serial correlation of the ISIs for lags 0-10,
      * bottom left: mean frequency response to the last step stimulus,
      * bottom right: f-I curve (f_0 and f_inf points with their fits).

    Prints the cell's EOD frequency and burstiness and saves the figure to
    ``thesis/figures/p_unit_example.pdf``.
    """
    cell = EXAMPLE_CELL
    cell_data = CellData(cell)
    print("p-unit example eodf:", cell_data.get_eod_frequency())
    base = BaselineCellData(cell_data)
    base.load_values(cell_data.get_data_path())
    print("burstiness of example cell:", base.get_burstiness())
    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), save_dir=cell_data.get_data_path())
    step = cell_data.get_sampling_interval()

    # Overview figure for p-unit behaviour
    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_LARGE)
    gs = gridspec.GridSpec(3, 2)

    # a bit of trace with detected spikes
    ax = fig.add_subplot(gs[0, :])

    v1 = cell_data.get_base_traces(cell_data.V1)[0]
    time = cell_data.get_base_traces(cell_data.TIME)[0]
    spiketimes = cell_data.get_base_spikes()[0]
    start = 0
    duration = 0.10

    # Window [start, start + duration) of the trace. Trace and spikes are both
    # sliced/filtered to that window and shifted by `start`, so the panel
    # always spans 0..duration ms (previously the slice ignored `start`).
    idx_start = int(start / step)
    idx_end = idx_start + int(duration / step)
    v1_window = v1[idx_start:idx_end]

    ax.plot((np.array(time[idx_start:idx_end]) - start) * 1000, v1_window, color=consts.COLOR_DATA)
    ax.eventplot([(s - start) * 1000 for s in spiketimes if start <= s < start + duration],
                 lineoffsets=max(v1_window) + 1.25, colors="black", linelengths=2)
    ax.set_ylabel('Voltage [mV]')
    ax.set_xlabel('Time [ms]')
    ax.set_title("Baseline Firing")
    ax.set_xlim((0, duration * 1000))

    # ISI-hist
    ax = fig.add_subplot(gs[1, 0])
    eod_period = 1.0 / cell_data.get_eod_frequency()

    isi = np.array(base.get_interspike_intervals()) / eod_period  # ISI in EOD periods
    maximum = max(isi)
    bins = np.arange(0, maximum * 1.01, 0.1)
    ax.hist(isi, bins=bins, color=consts.COLOR_DATA, density=True)
    ax.set_ylabel("Density")
    ax.set_xlabel("ISI [EOD periods]")
    ax.set_title("ISI Histogram")

    # Serial correlation; the lag-0 value (1) is prepended by hand.
    ax = fig.add_subplot(gs[1, 1])

    max_lag = 10
    lags = range(max_lag + 1)
    sc = [1] + list(base.get_serial_correlation(max_lag))
    ax.plot(lags, [0] * len(lags), color="darkgrey", alpha=0.8)
    ax.plot(lags, sc, color=consts.COLOR_DATA)
    ax.plot(lags, sc, '+', color="black")

    ax.set_xlabel("Lag")
    ax.set_ylabel("SC")
    ax.set_title("Serial Correlation")
    ax.set_ylim((-1, 1))
    ax.set_xlim((0, max_lag))
    ax.set_xticks(list(range(0, max_lag + 1, 2)))

    # FI-Curve trace
    ax = fig.add_subplot(gs[2, 0])

    f_trace_times, f_traces = fi.get_mean_time_and_freq_traces()

    part = 0.4 + 0.2 + 0.2  # stim duration + delay up front and a part of the "delay" at the back
    idx = int(part / step)

    ax.plot(f_trace_times[-1][:idx], f_traces[-1][:idx], color=consts.COLOR_DATA)
    ax.set_xlim((-0.2, part - 0.2))
    ylim = ax.get_ylim()
    ax.set_ylim((0, ylim[1]))
    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Frequency [Hz]")
    ax.set_title("Step Response")

    # FI-Curve
    ax = fig.add_subplot(gs[2, 1])

    contrasts = fi.stimulus_values
    f_zeros = fi.get_f_zero_frequencies()
    f_infties = fi.get_f_inf_frequencies()

    ax.plot(contrasts, f_zeros, ',', marker=consts.f0_marker, color=consts.COLOR_DATA_f0)
    ax.plot(contrasts, f_infties, ',', marker=consts.finf_marker, color=consts.COLOR_DATA_finf)

    # Dense grid over the measured contrast range with both end points; unlike
    # an arange with step (max - min) / 1000 it does not fail on a zero range.
    x_values = np.linspace(min(contrasts), max(contrasts), 1001)
    f_zero_fit = [fu.full_boltzmann(x, fi.f_zero_fit[0], fi.f_zero_fit[1], fi.f_zero_fit[2], fi.f_zero_fit[3])
                  for x in x_values]
    f_inf_fit = [fu.clipped_line(x, fi.f_inf_fit[0], fi.f_inf_fit[1]) for x in x_values]
    ax.plot(x_values, f_zero_fit, color=consts.COLOR_DATA_f0)
    ax.plot(x_values, f_inf_fit, color=consts.COLOR_DATA_finf)

    ax.set_xlabel("Contrast")
    ax.set_ylabel("Frequency [Hz]")
    ax.set_xticks([-0.2, -0.1, 0, 0.1, 0.2])
    ax.set_xlim((-0.21, 0.2))
    ylim = ax.get_ylim()
    ax.set_ylim((0, ylim[1]))
    ax.set_title("f-I Curve")

    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5, yoffset=2.2)
    fig.label_axes()
    plt.savefig("thesis/figures/p_unit_example.pdf", transparent=True)
    plt.close()


def fi_point_detection():
    """Illustrate the detected f_0, f_inf and f_baseline points for the example cell.

    Left: mean frequency response to the last stimulus value, with the f_0
    point, thick horizontal lines spanning the ``indices_f_inf`` and
    ``indices_f_baseline`` ranges, and a black bar marking stimulus start to
    end. Right: the resulting f-I curve with its fits (y-axis shared).
    Saved to ``thesis/figures/f_point_detection.pdf``.
    """
    cell = EXAMPLE_CELL
    cell_data = CellData(cell)

    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts())
    step = cell_data.get_sampling_interval()

    fig, axes = plt.subplots(1, 2, figsize=consts.FIG_SIZE_MEDIUM_WIDE, sharey="row")

    f_trace_times, f_traces = fi.get_mean_time_and_freq_traces()
    # Only the mean trace of the last stimulus value is shown.
    trace_times = f_trace_times[-1]
    trace_freqs = f_traces[-1]

    part = 0.4 + 0.2 + 0.2  # stim duration + delay up front and a part of the "delay" at the back
    idx = int(part / step)

    f_zero = fi.get_f_zero_frequencies()[-1]
    f_zero_idx = fi.indices_f_zero[-1]
    f_inf = fi.get_f_inf_frequencies()[-1]
    f_inf_idx = fi.indices_f_inf[-1]
    f_baseline = fi.get_f_baseline_frequencies()[-1]
    f_base_idx = fi.indices_f_baseline[-1]

    axes[0].plot(trace_times[:idx], trace_freqs[:idx], color=consts.COLOR_DATA)
    # The index lists are converted to times on the mean trace; f_zero_idx is
    # expected to hold one index, f_inf_idx / f_base_idx two (start, end).
    axes[0].plot([trace_times[i] for i in f_zero_idx], (f_zero, ), ",",
                 marker=consts.f0_marker, color=consts.COLOR_DATA_f0)
    axes[0].plot([trace_times[i] for i in f_inf_idx], (f_inf, f_inf),
                 color=consts.COLOR_DATA_finf, linewidth=4)
    axes[0].plot([trace_times[i] for i in f_base_idx], (f_baseline, f_baseline),
                 color="grey", linewidth=4)

    # mark stim start and end (drawn at a fixed height of 100 Hz):
    stim_start = cell_data.get_stimulus_start()
    stim_end = cell_data.get_stimulus_end()
    axes[0].plot([stim_start, stim_end], (100, 100), color="black", linewidth=3)
    axes[0].set_xlim((-0.2, part - 0.2))
    ylimits = axes[0].get_ylim()

    axes[0].set_xlabel("Time [s]")
    axes[0].set_ylabel("Frequency [Hz]")
    axes[0].set_title("Step Response")

    contrasts = fi.stimulus_values
    f_zeros = fi.get_f_zero_frequencies()
    f_infties = fi.get_f_inf_frequencies()

    axes[1].plot(contrasts, f_zeros, ",", marker=consts.f0_marker, color=consts.COLOR_DATA_f0)
    axes[1].plot(contrasts, f_infties, ",", marker=consts.finf_marker, color=consts.COLOR_DATA_finf)

    # Dense grid over the contrast range with both end points; unlike an
    # arange with step (max - min) / 1000 it does not fail on a zero range.
    x_values = np.linspace(min(contrasts), max(contrasts), 1001)
    f_zero_fit = [fu.full_boltzmann(x, fi.f_zero_fit[0], fi.f_zero_fit[1], fi.f_zero_fit[2], fi.f_zero_fit[3])
                  for x in x_values]
    f_inf_fit = [fu.clipped_line(x, fi.f_inf_fit[0], fi.f_inf_fit[1]) for x in x_values]
    axes[1].plot(x_values, f_zero_fit, color=consts.COLOR_DATA_f0)
    axes[1].plot(x_values, f_inf_fit, color=consts.COLOR_DATA_finf)

    axes[1].set_xlabel("Contrast")
    axes[1].set_title("f-I Curve")
    # y-axis is shared, so this also fixes the left panel's lower limit at 0.
    axes[1].set_ylim((0, ylimits[1]))

    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    fig.label_axes()
    plt.savefig("thesis/figures/f_point_detection.pdf", transparent=True)
    plt.close()


def data_fi_curve():
    """Plot the f-I curve of cell 2013-04-17-ac via FICurveCellData.plot_fi_curve()."""
    cell_data = CellData("data/final/2013-04-17-ac-invivo-1/")
    fi_curve = FICurveCellData(cell_data, cell_data.get_fi_contrasts())
    fi_curve.plot_fi_curve()


def data_mean_freq_step_stimulus_with_detections():
    """Interactively plot the mean response to the last stimulus with its f_0 and f_inf.

    f_0 is drawn as a black dot at the time point whose frequency is closest
    to the stored f_0 value; f_inf is drawn as a horizontal line over the
    samples with 0.2 s < t < 0.4 s.
    """
    cell = "data/final/2013-04-17-ac-invivo-1/"
    cell_data = CellData(cell)
    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts())

    mean_times, mean_freqs = fi.get_mean_time_and_freq_traces()
    idx = -1
    time = np.array(mean_times[idx])
    freq = np.array(mean_freqs[idx])
    f_inf = fi.f_inf_frequencies[idx]
    f_zero = fi.f_zero_frequencies[idx]

    plt.plot(time, freq, color=consts.COLOR_DATA)
    # Nearest value instead of exact float equality: `time[freq == f_zero][0]`
    # raised IndexError whenever no sample matched exactly. If an exact match
    # exists, nanargmin returns its first occurrence, as before.
    f_zero_time = time[np.nanargmin(np.abs(freq - f_zero))]
    plt.plot(f_zero_time, f_zero, "o", color="black")
    f_inf_time = time[(0.2 < time) & (time < 0.4)]
    plt.plot(f_inf_time, [f_inf] * len(f_inf_time), color="black")
    plt.xlim((-0.1, 0.6))
    plt.show()


def data_mean_freq_step_stimulus_examples():
    """Interactively plot mean step responses for three selected stimulus indices.

    One row per selected index (first, eighth and last); each title gives the
    stimulus value and the number of trials recorded for it.
    """
    # TODO: smooth the traces and add f_0, f_inf, f_base markers?
    cell_data = CellData("data/invivo/2013-04-17-ac-invivo-1/")
    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts())

    time_traces, _ = fi.get_time_and_freq_traces()
    mean_times, mean_freqs = fi.get_mean_time_and_freq_traces()

    selected_indices = (0, 7, -1)
    _, axes = plt.subplots(len(selected_indices), figsize=(8, 12), sharex=True, sharey=True)
    for ax, stim_idx in zip(axes, selected_indices):
        contrast = fi.stimulus_values[stim_idx]
        ax.plot(mean_times[stim_idx], mean_freqs[stim_idx], color=consts.COLOR_DATA)
        ax.set_ylabel("Frequency [Hz]")
        ax.set_xlim((-0.2, 0.6))
        n_trials = len(time_traces[stim_idx])
        ax.set_title("Contrast {:.2f} ({:} trials)".format(contrast, n_trials))

    # Only the bottom row gets the shared time label.
    axes[-1].set_xlabel("Time [s]")
    plt.show()


def data_isi_histogram(recalculate=True):
    """Interactively plot a histogram of one cell's baseline ISIs in ms.

    The ISIs are cached in ``isi_cell_data.npy`` inside ``consts.SAVE_FOLDER``.

    Args:
        recalculate: if False and the cache file exists, the ISIs are loaded
            from it; otherwise they are computed from the cell selected below
            and the cache is (re)written.

    Note: the cache is not keyed by cell. After switching the cell below, call
    with ``recalculate=True`` so stale ISIs are not plotted.
    """
    name = "isi_cell_data.npy"
    path = os.path.join(consts.SAVE_FOLDER, name)
    if os.path.exists(path) and not recalculate:
        isis = np.load(path)
        print("loaded")
    else:
        cell = "data/invivo/2013-04-17-ac-invivo-1/"  # not bursty
        # Alternative cells:
        # cell = "data/invivo/2014-12-03-ad-invivo-1/"  # half bursty
        # cell = "data/invivo/2015-01-20-ad-invivo-1/"  # does triple peaks...
        # cell = "data/invivo/2018-05-08-ae-invivo-1/"  # a bit bursty
        # cell = "data/invivo/2013-04-10-af-invivo-1/"  # a bit bursty
        cell_data = CellData(cell)
        base = BaselineCellData(cell_data)
        isis = np.array(base.get_interspike_intervals())

        # np.save does not create missing directories (FileNotFoundError).
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        np.save(path, isis)

    # Presumably seconds -> ms, matching the x-axis label.
    isis_ms = isis * 1000
    bins = np.arange(0, 30.1, 0.1)
    plt.hist(isis_ms, bins=bins, color=consts.COLOR_DATA)
    plt.xlabel("Inter spike intervals [ms]")
    plt.ylabel("Count")
    plt.tight_layout()

    plt.show()


def test_fi_curve_colors():
    """Visually compare the f-I curve colours of the example cell and its best-fit model.

    Panels: (0) cell f_0 / f_inf points with their fits in data colours,
    (1) the model's f-I points and fits in model colours, (2) the cell's
    fits overlaid with the model's points. Shown interactively.
    NOTE(review): despite the name this is a plotting helper, not a unit test.
    """
    example_cell_fit = "results/final_2/2012-12-20-ac-invivo-1"
    cell = EXAMPLE_CELL

    cell_data = CellData(cell)
    fit = get_best_fit(example_cell_fit)

    _, axes = plt.subplots(1, 3)

    def plot_points(ax, xs, f_zeros, f_infs, f0_color, finf_color):
        """Draw f_0 and f_inf markers (no connecting lines) onto ax."""
        ax.plot(xs, f_zeros, ',', marker=consts.f0_marker, color=f0_color)
        ax.plot(xs, f_infs, ',', marker=consts.finf_marker, color=finf_color)

    def fit_curves(fi_curve, xs):
        """Evaluate the Boltzmann f_0 fit and the clipped-line f_inf fit at xs."""
        p0 = fi_curve.f_zero_fit
        pinf = fi_curve.f_inf_fit
        f_zero_curve = [fu.full_boltzmann(x, p0[0], p0[1], p0[2], p0[3]) for x in xs]
        f_inf_curve = [fu.clipped_line(x, pinf[0], pinf[1]) for x in xs]
        return f_zero_curve, f_inf_curve

    # Cell data: points + fits on panel 0, fits also on panel 2.
    axes[0].set_title("Cell")
    cell_fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), save_dir=cell_data.get_data_path())
    contrasts = cell_data.get_fi_contrasts()
    plot_points(axes[0], contrasts, cell_fi.get_f_zero_frequencies(), cell_fi.get_f_inf_frequencies(),
                consts.COLOR_DATA_f0, consts.COLOR_DATA_finf)

    # Grid including max(contrasts) (the old arange grid stopped one step
    # short of it) that also works for a zero contrast range.
    x_values = np.linspace(min(contrasts), max(contrasts), 1001)
    cell_f0_curve, cell_finf_curve = fit_curves(cell_fi, x_values)
    for ax in (axes[0], axes[2]):
        ax.plot(x_values, cell_f0_curve, color=consts.COLOR_DATA_f0)
        ax.plot(x_values, cell_finf_curve, color=consts.COLOR_DATA_finf)

    # Model: points + fits on panel 1, points also on panel 2.
    axes[1].set_title("Model")
    model = fit.get_model()
    model_fi = FICurveModel(model, contrasts, eod_frequency=cell_data.get_eod_frequency())
    model_f_zeros = model_fi.get_f_zero_frequencies()
    model_f_infs = model_fi.get_f_inf_frequencies()
    plot_points(axes[1], contrasts, model_f_zeros, model_f_infs,
                consts.COLOR_MODEL_f0, consts.COLOR_MODEL_finf)

    model_f0_curve, model_finf_curve = fit_curves(model_fi, x_values)
    axes[1].plot(x_values, model_f0_curve, color=consts.COLOR_MODEL_f0)
    axes[1].plot(x_values, model_finf_curve, color=consts.COLOR_MODEL_finf)

    plot_points(axes[2], contrasts, model_f_zeros, model_f_infs,
                consts.COLOR_MODEL_f0, consts.COLOR_MODEL_finf)

    plt.show()
    plt.close()


# Build the figures only when this file is run as a script, not on import.
if __name__ == '__main__':
    main()