import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import functions as fu
import helperFunctions as hF

from CellData import CellData
from Baseline import BaselineCellData
from FiCurve import FICurveCellData
import Figure_constants as consts
# Colours used when plotting data and model results.
DATA_COLOR = "blue"
MODEL_COLOR = "orange"

# Directory (relative to the working directory) for cached figure data.
DATA_SAVE_PATH = "data/figure_data/"


def main():
    """Render the figures that are currently enabled for the thesis.

    The exploratory helpers ``data_isi_histogram``,
    ``data_mean_freq_step_stimulus_examples``,
    ``data_mean_freq_step_stimulus_with_detections`` and ``data_fi_curve`` are
    not run by default; call them manually when needed.
    """
    for make_figure in (p_unit_example, fi_point_detection):
        make_figure()


def p_unit_example():
    """Create the overview figure of the example p-unit's behaviour.

    Draws a 3x2 grid and saves it to ``thesis/figures/p_unit_example.png``:

    * top row: the baseline voltage trace in the window
      ``[start, start + duration)`` with the detected spikes drawn as events
      above it (both shifted so the window starts at 0),
    * middle left: ISI histogram in units of EOD periods,
    * bottom left: serial correlation of the ISIs for lags 0-10,
    * middle right: mean frequency response to the last step stimulus, raw and
      smoothed with a moving average,
    * bottom right: f_0 (``x``) and f_inf (``o``) per contrast together with
      their fitted Boltzmann and clipped-line curves.

    The burstiness of the cell is printed to stdout.
    """
    # cell = "data/final/2013-04-17-ac-invivo-1"
    cell = "data/final/2012-04-20-af-invivo-1"
    cell_data = CellData(cell)

    base = BaselineCellData(cell_data)
    base.load_values(cell_data.get_data_path())
    print("burstiness of example cell:", base.get_burstiness())
    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), save_dir=cell_data.get_data_path())
    step = cell_data.get_sampling_interval()

    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM)
    gs = gridspec.GridSpec(3, 2)

    # --- Baseline trace with detected spikes --------------------------------
    ax = fig.add_subplot(gs[0, :])

    v1 = np.asarray(cell_data.get_base_traces(cell_data.V1)[0])
    time = np.asarray(cell_data.get_base_traces(cell_data.TIME)[0])
    spiketimes = cell_data.get_base_spikes()[0]
    start = 0
    duration = 0.10

    # Sample indices of the window [start, start + duration). The index
    # arithmetic assumes the baseline recording begins at t = 0 - TODO confirm.
    start_idx = int(start / step)
    end_idx = int((start + duration) / step)
    window_v1 = v1[start_idx:end_idx]

    # Shift trace AND spikes by `start` so both share the 0..duration axis.
    ax.plot(time[start_idx:end_idx] - start, window_v1)
    window_spikes = [s - start for s in spiketimes if start < s < start + duration]
    ax.eventplot(window_spikes, lineoffsets=max(window_v1) + 0.75, color="black")
    ax.set_ylabel('Voltage in mV')
    ax.set_xlabel('Time in s')
    ax.set_title("Baseline Firing")
    ax.set_xlim((0, duration))

    # --- ISI histogram -------------------------------------------------------
    ax = fig.add_subplot(gs[1, 0])
    eod_period = 1.0 / cell_data.get_eod_frequency()

    # ISIs expressed in multiples of the EOD period (not in ms).
    isi = np.array(base.get_interspike_intervals()) / eod_period
    bins = np.arange(0, max(isi) * 1.01, 0.1)
    ax.hist(isi, bins=bins)
    ax.set_ylabel("Count")
    ax.set_xlabel("ISI in EOD periods")
    ax.set_title("ISI-histogram")

    # --- Serial correlation --------------------------------------------------
    ax = fig.add_subplot(gs[2, 0])

    # Presumably returns the correlations for lags 1..10 (11 x-values below
    # include lag 0) - confirm against BaselineCellData.
    sc = base.get_serial_correlation(10)
    lags = range(11)
    ax.plot(lags, [0] * 11, color="darkgrey", alpha=0.8)  # zero reference line
    ax.plot(lags, [1] + list(sc))  # the correlation at lag 0 is 1 by definition
    ax.set_xlabel("Lag")
    ax.set_ylabel("SC")
    ax.set_title("Serial Correlation")
    ax.set_ylim((-1, 1))
    ax.set_xlim((0, 10))
    ax.set_xticks([0, 2, 4, 6, 8, 10])

    # --- Step response of the last contrast ----------------------------------
    ax = fig.add_subplot(gs[1, 1])

    f_trace_times, f_traces = fi.get_mean_time_and_freq_traces()

    part = 1 + 0.2 + 0.2  # stim duration + delay up front and a part of the "delay" at the back
    idx = int(part / step)
    raw_times = f_trace_times[-1][:idx]
    raw_freqs = f_traces[-1][:idx]

    ax.plot(raw_times, raw_freqs)
    # Moving average over `strength` samples. The "full" convolution is longer
    # than the input; skipping strength/2 samples centres it on the raw trace.
    # Slicing by the actual trace length keeps x and y the same length even if
    # the trace is shorter than `idx`.
    strength = 200
    half = int(strength / 2)
    smoothed = np.convolve(raw_freqs, np.ones(strength) / strength)
    ax.plot(raw_times, smoothed[half:half + len(raw_freqs)])
    ax.set_xlim((-0.2, part - 0.2))
    ylim = ax.get_ylim()
    ax.set_ylim((0, ylim[1]))
    ax.set_xlabel("Time in s")
    ax.set_ylabel("Frequency in Hz")
    ax.set_title("Step Response")

    # --- FI curve -------------------------------------------------------------
    ax = fig.add_subplot(gs[2, 1])

    contrasts = fi.stimulus_values
    ax.plot(contrasts, fi.get_f_zero_frequencies(), 'x')
    ax.plot(contrasts, fi.get_f_inf_frequencies(), 'o')

    # linspace (unlike arange with a step derived from the range) does not fail
    # when all contrasts are equal.
    x_values = np.linspace(min(contrasts), max(contrasts), 1001)
    f_zero_fit = [fu.full_boltzmann(x, *fi.f_zero_fit[:4]) for x in x_values]
    f_inf_fit = [fu.clipped_line(x, *fi.f_inf_fit[:2]) for x in x_values]
    ax.plot(x_values, f_zero_fit)
    ax.plot(x_values, f_inf_fit)
    ax.set_xlabel("Contrast")
    ax.set_ylabel("Frequency in Hz")
    ax.set_title("FI-Curve")

    fig.tight_layout()
    fig.savefig("thesis/figures/p_unit_example.png")
    plt.close(fig)


def fi_point_detection():
    """Illustrate the detection of f_0 / f_inf and the resulting FI curve.

    Left panel: mean frequency response to the last step stimulus, with the
    detected f_0 point, the f_inf samples drawn as a thick horizontal line, and
    grey vertical markers at stimulus start and end.
    Right panel: f_0 (``x``) and f_inf (``o``) per contrast with their fitted
    Boltzmann and clipped-line curves.
    Both panels share the y-axis. Its lower limit is set to 0 and its upper
    limit is taken from the autoscaled step-response panel.
    The figure is saved to ``thesis/figures/f_point_detection.png``.
    """
    # cell = "data/final/2013-04-17-ac-invivo-1"
    cell = "data/final/2012-04-20-af-invivo-1"
    cell_data = CellData(cell)

    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts())
    step = cell_data.get_sampling_interval()

    fig, (ax_trace, ax_fi) = plt.subplots(1, 2, figsize=consts.FIG_SIZE_SMALL_WIDE, sharey="row")

    f_trace_times, f_traces = fi.get_mean_time_and_freq_traces()
    times = f_trace_times[-1]

    part = 1 + 0.2 + 0.2  # stim duration + delay up front and a part of the "delay" at the back
    window_end = int(part / step)

    f_zeros = fi.get_f_zero_frequencies()
    f_infties = fi.get_f_inf_frequencies()
    f_zero = f_zeros[-1]
    f_inf = f_infties[-1]

    # The detection indices are used to index the mean time trace of the last
    # contrast, i.e. they refer to samples of that trace.
    f_zero_times = [times[i] for i in fi.indices_f_zero[-1]]
    f_inf_times = [times[i] for i in fi.indices_f_inf[-1]]

    ax_trace.plot(times[:window_end], f_traces[-1][:window_end])
    # y-values sized to match the number of detected indices.
    ax_trace.plot(f_zero_times, [f_zero] * len(f_zero_times), "o")
    ax_trace.plot(f_inf_times, [f_inf] * len(f_inf_times), color="orange", linewidth=4)

    # Mark stimulus start and end.
    stim_start = cell_data.get_stimulus_start()
    stim_end = cell_data.get_stimulus_end()
    # NOTE(review): marker height uses the baseline frequency of the FIRST
    # contrast while the trace shown is the LAST contrast - confirm intended.
    baseline_freq = fi.get_f_baseline_frequencies()[0]
    for t in (stim_start, stim_end):
        ax_trace.plot([t, t], (0, baseline_freq), color="darkgrey")
    ax_trace.set_xlim((-0.2, part - 0.2))
    # Capture the y-limits autoscaled from the trace panel alone (before the FI
    # panel is filled); they are reapplied at the end with 0 as lower bound.
    ylimits = ax_trace.get_ylim()

    ax_trace.set_xlabel("Time in s")
    ax_trace.set_ylabel("Frequency in Hz")
    ax_trace.set_title("Step Response")

    contrasts = fi.stimulus_values
    ax_fi.plot(contrasts, f_zeros, 'x')
    ax_fi.plot(contrasts, f_infties, 'o')

    # linspace does not fail when all contrasts are equal (arange with a zero
    # step would).
    x_values = np.linspace(min(contrasts), max(contrasts), 1001)
    f_zero_fit = [fu.full_boltzmann(x, *fi.f_zero_fit[:4]) for x in x_values]
    f_inf_fit = [fu.clipped_line(x, *fi.f_inf_fit[:2]) for x in x_values]
    ax_fi.plot(x_values, f_zero_fit)
    ax_fi.plot(x_values, f_inf_fit)

    ax_fi.set_xlabel("Contrast in %")
    ax_fi.set_title("FI-Curve")
    # sharey="row": this also sets the limits of the step-response panel.
    ax_fi.set_ylim((0, ylimits[1]))

    fig.tight_layout()
    fig.savefig("thesis/figures/f_point_detection.png")
    plt.close(fig)

def data_fi_curve():
    """Plot the FI curve of cell 2013-04-17-ac via ``FICurveCellData.plot_fi_curve``."""
    cell_data = CellData("data/final/2013-04-17-ac-invivo-1/")
    fi_curve = FICurveCellData(cell_data, cell_data.get_fi_contrasts())
    fi_curve.plot_fi_curve()


def _nearest_index(values, target):
    """Return the index of the element of *values* closest to *target*.

    NaN entries are ignored. If *target* occurs exactly, the first such index is
    returned. Raises ValueError if *values* is empty or all-NaN.
    """
    return int(np.nanargmin(np.abs(np.asarray(values, dtype=float) - target)))


def data_mean_freq_step_stimulus_with_detections():
    """Show the mean frequency response of the last contrast with f_0 and f_inf.

    f_0 is marked as a black dot and f_inf as a black horizontal line over the
    0.2-0.4 s window. The figure is shown interactively.
    """
    cell = "data/final/2013-04-17-ac-invivo-1/"
    cell_data = CellData(cell)
    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts())

    mean_times, mean_freqs = fi.get_mean_time_and_freq_traces()
    idx = -1
    time = np.array(mean_times[idx])
    freq = np.array(mean_freqs[idx])
    f_inf = fi.f_inf_frequencies[idx]
    f_zero = fi.f_zero_frequencies[idx]

    plt.plot(time, freq, color=DATA_COLOR)
    # Mark f_0 at the sample closest to it. Exact float equality
    # (freq == f_zero) raised IndexError whenever f_0 was not literally a sample
    # value of the mean trace. With an exact match the result is unchanged.
    plt.plot(time[_nearest_index(freq, f_zero)], f_zero, "o", color="black")
    # NOTE(review): f_inf is drawn over a fixed 0.2-0.4 s window, not over the
    # window the detection actually used (fi.indices_f_inf) - confirm.
    f_inf_time = time[(0.2 < time) & (time < 0.4)]
    plt.plot(f_inf_time, [f_inf] * len(f_inf_time), color="black")
    plt.xlim((-0.1, 0.6))
    plt.show()


def data_mean_freq_step_stimulus_examples():
    """Plot mean frequency responses for three example contrasts, stacked.

    Uses the contrasts at positions 0, 7 and -1 of ``fi.stimulus_values``. Each
    panel title reports the contrast and the number of trials. The panels share
    both axes; the figure is shown interactively.
    """
    # todo smooth! add f_0, f_inf, f_base to it?
    cell_data = CellData("data/invivo/2013-04-17-ac-invivo-1/")
    fi = FICurveCellData(cell_data, cell_data.get_fi_contrasts())

    time_traces, _ = fi.get_time_and_freq_traces()
    mean_times, mean_freqs = fi.get_mean_time_and_freq_traces()

    used_indices = (0, 7, -1)
    fig, axes = plt.subplots(len(used_indices), figsize=(8, 12), sharex=True, sharey=True)
    for ax, contrast_idx in zip(axes, used_indices):
        contrast = fi.stimulus_values[contrast_idx]
        n_trials = len(time_traces[contrast_idx])

        ax.plot(mean_times[contrast_idx], mean_freqs[contrast_idx], color=DATA_COLOR)
        ax.set_ylabel("Frequency [Hz]")
        ax.set_xlim((-0.2, 0.6))
        ax.set_title("Contrast {:.2f} ({:} trials)".format(contrast, n_trials))

    # Shared x-axis: only the bottom panel needs a label.
    axes[-1].set_xlabel("Time [s]")
    plt.show()


def data_isi_histogram(recalculate=True):
    """Plot a histogram of the inter-spike intervals (ISIs) of an example cell.

    The ISIs are cached as ``isi_cell_data.npy`` inside ``DATA_SAVE_PATH``.

    :param recalculate: if False and the cache file exists, the ISIs are loaded
        from it. Otherwise they are recomputed from the cell's baseline
        recording, and the cache is (re)written.
    """
    path = os.path.join(DATA_SAVE_PATH, "isi_cell_data.npy")
    if os.path.exists(path) and not recalculate:
        isis = np.load(path)
        print("loaded")
    else:
        # Not cached (or recalculation requested): compute from the cell.
        cell = "data/invivo/2013-04-17-ac-invivo-1/"  # not bursty
        # cell = "data/invivo/2014-12-03-ad-invivo-1/"  # half bursty
        # cell = "data/invivo/2015-01-20-ad-invivo-1/"  # does triple peaks...
        # cell = "data/invivo/2018-05-08-ae-invivo-1/"  # a bit bursty
        # cell = "data/invivo/2013-04-10-af-invivo-1/"  # a bit bursty
        cell_data = CellData(cell)
        base = BaselineCellData(cell_data)
        isis = np.array(base.get_interspike_intervals())

        # np.save does not create missing directories; without this the first
        # run on a fresh checkout fails with FileNotFoundError.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        np.save(path, isis)

    # Convert to ms to match the axis label.
    isis = isis * 1000
    # Fixed 0-30 ms range in 0.1 ms bins; ISIs outside it are not counted.
    bins = np.arange(0, 30.1, 0.1)
    plt.hist(isis, bins=bins, color=DATA_COLOR)
    plt.xlabel("Inter spike intervals [ms]")
    plt.ylabel("Count")
    plt.tight_layout()

    plt.show()


if __name__ == "__main__":
    # Script entry point: render the figures enabled in main().
    main()