import os
from parser.CellData import CellData
import numpy as np
from fitting.ModelFit import ModelFit, get_best_fit
# from plottools.axes import labelaxes_params
import matplotlib.pyplot as plt
from run_Fitter import iget_start_parameters
# Matplotlib colour names cycled through (index modulo length) when drawing the
# selected spike trains in sam_tests().
colors = ["black", "red", "blue", "orange", "green"]


def main(cores=16, cells=40, seconds_per_start=1400):
    """Print a rough runtime estimate for fitting all start parameter sets.

    The number of start parameter sets is counted from
    ``iget_start_parameters()``; each set is assumed to take
    ``seconds_per_start`` seconds on a single core.

    :param cores: number of cores the fits are distributed over.
    :param cells: number of cells the per-cell estimate is multiplied by.
    :param seconds_per_start: assumed single-core seconds per start parameter set.
    """
    # Count lazily instead of materialising a throwaway list.
    number = sum(1 for _ in iget_start_parameters())
    single_core = number * seconds_per_start / 60 / 60  # hours
    print("start parameters:", number)
    print("single core time:", single_core, "h")
    print("single core time:", single_core / 24, "days")

    print(cores, "core time:", single_core / cores, "h")
    print(cores, "core time:", single_core / 24 / cores, "days")
    print(cores, "core time all", cells, "cells:", single_core / 24 / cores * cells, "days")

    # Start parameter sets that do not divide evenly across the cores.
    print("left over:", number % cores)


def sam_tests():
    """Interactively inspect the SAM recordings of the first suitable cell.

    Walks the cell folders in ``./data/final/`` in sorted order and skips those
    without a ``samspikes1.dat`` file. For the first cell that has one, it shows
    one figure per SAM delta frequency. The top panel is the local EOD trace.
    The bottom panel is the v1 trace, overlaid with rasters of the (up to five)
    spike trains whose mean v1 value at the spike times is highest. Stops after
    that first cell.
    """
    data_folder = "./data/final/"
    for cell in sorted(os.listdir(data_folder)):
        print(cell)
        cell_folder = os.path.join(data_folder, cell)
        if not os.path.exists(os.path.join(cell_folder, "samspikes1.dat")):
            continue

        cell_data = CellData(cell_folder)
        sampling_rate = int(round(1 / cell_data.get_sampling_interval()))
        sam_spikes = cell_data.get_sam_spiketimes()
        delta_freqs = cell_data.get_sam_delta_frequencies()

        # EOD and stimulus traces are returned as well but are not plotted here.
        time_traces, v1_traces, _eod_traces, local_eod_traces, _stimulus_traces = cell_data.get_sam_traces()
        print(len(time_traces))
        for i, delta_freq in enumerate(delta_freqs):
            time_trace = time_traces[i]
            v1_trace = v1_traces[i]

            fig, axes = plt.subplots(2, 1, sharex="all")

            axes[0].plot(time_trace, local_eod_traces[i])
            axes[0].set_title("Local EOD - dF {}".format(delta_freq))
            axes[1].plot(time_trace, v1_trace)
            axes[1].set_title("v1 trace")

            # NOTE(review): every trace i is scored against the whole sam_spikes
            # collection rather than a per-trial subset; confirm this is intended.
            ah_spike = average_spike_height(sam_spikes, v1_trace, sampling_rate)

            # Stack the rasters above the trace: the offset is the trace maximum
            # plus 1.5 per raster row. Take the max first, then add the offset.
            trace_top = np.max(v1_trace)
            for j, idx in enumerate(get_x_best(ah_spike)):
                axes[1].eventplot(sam_spikes[idx], lineoffsets=trace_top + 1.5 * (j + 1),
                                  colors=colors[j % len(colors)])
            plt.show()
            plt.close(fig)
        # Debugging helper: only the first cell with SAM spikes is inspected.
        break


def average_spike_height(spike_trains, local_eod, sampling_rate):
    """Return the mean trace value at the spike times of each spike train.

    :param spike_trains: iterable whose items hold their spike times at index
        ``0``, in the time unit whose reciprocal is ``sampling_rate``.
    :param local_eod: 1-D sequence of samples the spike times are looked up in.
        ``sam_tests`` passes the v1 voltage trace here.
    :param sampling_rate: samples per time unit. Spike time ``t`` maps to
        sample ``int(t * sampling_rate)`` (truncated toward zero).
    :return: list with one mean per spike train. The mean is ``nan`` for a
        train that has no spike inside the trace.
    """
    # Convert once: the trace is the same for every spike train.
    trace = np.asarray(local_eod)
    n_samples = len(trace)

    average_height = []
    for spike_train in spike_trains:
        # astype(int) truncates toward zero, like int() on each product.
        indices = (np.asarray(spike_train[0], dtype=float) * sampling_rate).astype(int)
        # Keep only in-range samples. Negative indices would otherwise silently
        # wrap around to the end of the trace.
        indices = indices[(indices >= 0) & (indices < n_samples)]
        average_height.append(np.mean(trace[indices]) if indices.size else np.nan)

    return average_height


def get_x_best(average_heights, x=5):
    """Return the indices of the ``x`` largest values in ``average_heights``.

    Indices are ordered from largest to smallest value. Equal values keep their
    original (ascending index) order. NaN entries rank below every number, so a
    spike train without any usable spike is never preferred over a real one.
    Fewer than ``x`` indices are returned when the input is shorter. The
    selected values are printed.

    :param average_heights: sequence of numbers (may contain NaN).
    :param x: maximum number of indices to return.
    :return: list of indices into ``average_heights``.
    """
    heights = list(average_heights)

    def _rank(i):
        # NaN compares false against everything; map it to -inf so it sorts last.
        value = heights[i]
        return -np.inf if np.isnan(value) else value

    # sorted(..., reverse=True) is stable, so ties keep ascending index order.
    biggest_idx = sorted(range(len(heights)), key=_rank, reverse=True)[:max(x, 0)]
    print([heights[i] for i in biggest_idx])
    return biggest_idx


# Print the runtime estimate only when executed as a script, not on import.
if __name__ == '__main__':
    main()