from Baseline import get_baseline_class
from CellData import CellData, icelldata_of_dir
from models.LIFACnoise import LifacNoiseModel
from Baseline import BaselineCellData, BaselineModel
from os import listdir
import numpy as np
from IPython import embed
import pyrelacs.DataLoader as Dl
from ModelFit import ModelFit
from FiCurve import FICurveModel, FICurveCellData
import os
import matplotlib.pyplot as plt
import functions as fu
from scipy.optimize import curve_fit
from scipy.signal import find_peaks
from thunderfish.eventdetection import threshold_crossing_times, threshold_crossings, detect_peaks



import sys

# Quick check: for each of these example cells, print the number of traces
# returned by get_base_traces(CellData.V1).
cells = [
    "data/invivo/2013-04-16-ad-invivo-1",
    "data/invivo/2014-01-16-aj-invivo-1",
    "data/invivo/2015-01-20-ad-invivo-1",
    "data/invivo_bursty/2012-05-07-aa-invivo-1",
]

for cell in cells:
    data = CellData(cell)
    print(len(data.get_base_traces(CellData.V1)))

# quit() is a site-module helper intended for the interactive shell only;
# sys.exit() is the programmatic way to stop. NOTE: the script ends here,
# nothing below this line is executed.
sys.exit()


def indices_of_peaks_of_distribution(y_values, stepsize_ms, eod_freq):
    """Return the bin indices of the peaks of a binned distribution.

    Neighbouring peaks must be at least 0.75 EOD periods apart. When two
    peaks are closer than that, ``scipy.signal.find_peaks`` removes the
    smaller one.

    :param y_values: bin heights of the distribution, one value per bin.
    :param stepsize_ms: bin width in ms.
    :param eod_freq: EOD frequency, assumed to be in Hz (matching the
        ``eod_freq / 1000`` -> per-ms conversion used elsewhere in this file).
    :return: numpy array with the indices of the detected peaks.
    """
    # The minimum peak spacing is a time: one EOD period in ms.
    eod_period_ms = 1000 / eod_freq
    # Convert 0.75 periods to a number of bins. find_peaks requires
    # distance >= 1, and coarse bins could otherwise round it down to 0.
    distance = max(1, int(0.75 * eod_period_ms / stepsize_ms))
    peaks, _ = find_peaks(np.asarray(y_values), distance=distance)

    return peaks


def remove_close_peaks(maxima_idx, peaks, closeness=2):
    """Drop maxima that have a strictly larger value close by.

    An index ``idx`` from *maxima_idx* is removed when any ``peaks[j]`` with
    ``|j - idx| <= closeness`` (and ``j`` inside ``peaks``) is strictly
    greater than ``peaks[idx]``. The surviving indices keep their original
    order, including any duplicates.

    :param maxima_idx: iterable of candidate indices into *peaks*.
    :param peaks: sequence of values the indices refer to.
    :param closeness: half-width (in samples) of the neighbourhood checked.
    :return: list of the indices that were not dominated by a neighbour.
    """
    n_peaks = len(peaks)

    def _dominated(idx):
        # True if a strictly larger value lies within +-closeness of idx.
        lo = max(idx - closeness, 0)
        hi = min(idx + closeness + 1, n_peaks)
        return any(peaks[j] > peaks[idx] for j in range(lo, hi))

    # A single filtering pass instead of collecting deletions and calling
    # list.remove() for each one, which rescanned the list every time.
    return [idx for idx in maxima_idx if not _dominated(idx)]


def find_local_maxima(values):
    """Return the indices of all local maxima in *values*.

    An index qualifies when neither of its direct neighbours holds a strictly
    larger value. Points on a plateau and the two end points (which have only
    one neighbour) can therefore also be reported.
    """
    last = len(values) - 1
    maxima = []
    for idx, current in enumerate(values):
        left_higher = idx > 0 and values[idx - 1] > current
        right_higher = idx < last and values[idx + 1] > current
        if not (left_higher or right_higher):
            maxima.append(idx)
    return maxima


def rms(array):
    """Return the root-mean-square of the values in *array*."""
    values = np.asarray(array)
    mean_square = np.mean(np.square(values))
    return np.sqrt(mean_square)

def perc_smaller_value(isis, value):
    """Return the fraction of entries in *isis* that are strictly below *value*.

    Despite the name, the result is a fraction in [0, 1], not a percentage.
    Empty input is not guarded against (the division is 0 / 0).
    """
    below = np.asarray(isis) < value
    return below.sum() / len(below)

# Gather the CellData objects yielded for the in-vivo directory (scanned with
# test_for_v1_trace=False) followed by those of the bursty in-vivo directory.
# Single-cell alternative: [CellData("data/invivo/2014-12-03-ad-invivo-1/")]
cell_datas = []
cell_datas.extend(icelldata_of_dir("data/invivo/", test_for_v1_trace=False))
cell_datas.extend(icelldata_of_dir("data/invivo_bursty/"))

import sys

# Rank all cells by burstiness and show their ISI histograms (ISIs scaled by
# 1000, presumably s -> ms) from the least to the most bursty cell.
baselines = [BaselineCellData(cell_data) for cell_data in cell_datas]
burstiness = [base.get_burstiness() for base in baselines]

# Positions in cell_datas ordered by ascending burstiness. sorted() is stable,
# so cells with equal burstiness keep their input order, the same result as
# sorting (burstiness, index) pairs. Without cells, both lists are simply
# empty rather than raising an unpacking error.
cell_data_idx = sorted(range(len(cell_datas)), key=lambda idx: burstiness[idx])
burstiness = [burstiness[idx] for idx in cell_data_idx]

bins = np.arange(0, 30.1, 0.2)
for burst, idx in zip(burstiness, cell_data_idx):
    # Reuse the BaselineCellData built above instead of constructing it again.
    isis = np.array(baselines[idx].get_interspike_intervals()) * 1000
    plt.hist(isis, bins=bins)
    plt.title(str(burst))
    plt.show()
    plt.close()

# quit() is intended for the interactive shell only; sys.exit() is the
# programmatic way to stop. The script ends here.
sys.exit()
# Score every cell as (fraction of ISIs shorter than 2.5 / (eod_freq / 1000),
# i.e. 2.5 EOD periods in ms if eod_freq is in Hz) * (mean ISI) and sort it
# into one of three groups by that score. Both containers must exist before
# the loop appends to them.
dif_mean_median = []
cells_sorted = {"below_one": [], "below_three": [], "other": []}
for cell_data in cell_datas:
    base = BaselineCellData(cell_data)
    isis = np.array(base.get_interspike_intervals()) * 1000
    eod_freq = cell_data.get_eod_frequency()

    # Only referenced by the commented-out diagnostic plots below.
    bins = np.arange(0, 30.1, 0.2)
    # y_values = plt.hist(isis, bins=bins, cumulative=True, density=True, alpha=0.5)
    # y_values2 = plt.hist(isis, bins=bins, density=True)

    value = perc_smaller_value(isis, 2.5/(eod_freq/1000)) * np.mean(isis)
    dif_mean_median.append(value)
    # plt.title("Diff % < 2.5eod / mean= {:.2f}".format(value))
    # peaks, _ = detect_peaks(y_values[0], 0.5*np.std(y_values[0]))


    # hist_x = bins[peaks]
    # hist_peaks = y_values[0][peaks]
    # plt.plot(hist_x, hist_peaks, '+')
    # plt.plot([2.5/(eod_freq/1000)]*2, (0, 1), ":", color="black")
    # plt.plot([np.median(isis)]*2, (0, 1), "--", color="darkblue")
    # plt.plot([np.mean(isis)]*2, (0, 1), "--", color="darkgreen")
    # plt.plot([rms(isis)]*2, (0, 1), "--", color="red")

    if value < 1:
        cells_sorted["below_one"].append(cell_data)
    elif value < 3:
        cells_sorted["below_three"].append(cell_data)
    else:
        cells_sorted["other"].append(cell_data)

import sys

# For each score group, show density-normalised ISI histograms for at most the
# first MAX_PLOTS_PER_GROUP cells, then print the size of every group.
MAX_PLOTS_PER_GROUP = 10
group_titles = (
    ("below_one", "Value < 1: {:.2f}"),
    ("below_three", "1 < Value < 3: {:.2f}"),
    ("other", "Value >=3: {:.2f}"),
)
bins = np.arange(0, 30.1, 0.2)
for group, title_fmt in group_titles:
    for cell_data in cells_sorted[group][:MAX_PLOTS_PER_GROUP]:
        base = BaselineCellData(cell_data)
        isis = np.array(base.get_interspike_intervals()) * 1000
        eod_freq = cell_data.get_eod_frequency()
        # Same score formula as the classification loop that fills cells_sorted.
        value = perc_smaller_value(isis, 2.5 / (eod_freq / 1000)) * np.mean(isis)

        plt.title(title_fmt.format(value))
        plt.hist(isis, bins=bins, density=True)
        plt.show()
        plt.close()


print("< one:", len(cells_sorted["below_one"]))
print("< three:", len(cells_sorted["below_three"]))
print("< more:", len(cells_sorted["other"]))

# quit() is intended for the interactive shell only; sys.exit() is the
# programmatic way to stop. The script ends here.
sys.exit()


# Compute the burstiness of every cell found under "data/"; the returned
# value is not used.
for cell_data in icelldata_of_dir("data/"):
    get_baseline_class(cell_data).get_burstiness()