import os
import pyrelacs.DataLoader as dl
import numpy as np
import matplotlib.pyplot as plt
from warnings import warn

def get_subfolder_paths(basepath):
    """Return the sorted paths of all direct subdirectories of *basepath*.

    Paths are built with ``os.path.join`` so *basepath* may be given with or
    without a trailing separator. Plain string concatenation produced paths
    like ``"/data" + "cell"`` == ``"/datacell"`` when the separator was
    missing, which are usually not directories, so subfolders were silently
    dropped.

    :param basepath: directory whose immediate children are inspected.
    :return: sorted list of paths (``basepath`` joined with the entry name)
        that are directories; regular files are ignored.
    """
    return sorted(
        path
        for path in (os.path.join(basepath, name) for name in os.listdir(basepath))
        if os.path.isdir(path)
    )


def get_traces(directory, trace_type, repro):
    """Load one kind of trace from every recording of a RELACS repro.

    :param directory: recording directory passed to ``dl.iload_traces``.
    :param trace_type: 1-based index of the trace to extract from each
        recording:
        1 = voltage p-unit, 2 = EOD, 3 = local EOD (~EOD + stimulus),
        4 = stimulus.
    :param repro: name of the repro whose traces are loaded.
    :return: ``(time_traces, value_traces)``, one entry per item yielded by
        ``dl.iload_traces``.
    :raises ValueError: if *trace_type* is smaller than 1 (a 0 or negative
        value would otherwise silently index from the end of the trace list).
    """
    if trace_type < 1:
        raise ValueError("trace_type must be >= 1, got {}".format(trace_type))

    time_traces = []
    value_traces = []

    for _info, _key, time, x in dl.iload_traces(directory, repro=repro):
        time_traces.append(time)
        value_traces.append(x[trace_type - 1])  # trace_type is 1-based

    if not time_traces:
        # Name the repro that was actually requested (previously this always
        # said "BaselineActivity", whatever repro was passed).
        print("iload_traces found nothing for the {} repro!".format(repro))

    return time_traces, value_traces


def get_all_traces(directory, repro):
    """Load the first four traces from every recording of a RELACS repro.

    :param directory: recording directory passed to ``dl.iload_traces``.
    :param repro: name of the repro whose traces are loaded.
    :return: ``(time_traces, traces)`` where ``traces`` is
        ``[v1_traces, eod_traces, local_eod_traces, stimulus_traces]``, built
        from ``x[0]`` .. ``x[3]`` of each item yielded by ``dl.iload_traces``.
    """
    time_traces = []
    v1_traces = []
    eod_traces = []
    local_eod_traces = []
    stimulus_traces = []

    for info, _key, time, x in dl.iload_traces(directory, repro=repro):
        time_traces.append(time)
        v1_traces.append(x[0])
        eod_traces.append(x[1])
        local_eod_traces.append(x[2])
        stimulus_traces.append(x[3])
        # Prints the `info` object yielded by iload_traces for each recording.
        print(info)

    if not time_traces:
        # Name the repro that was actually requested (previously this always
        # said "BaselineActivity", whatever repro was passed).
        print("iload_traces found nothing for the {} repro!".format(repro))

    return time_traces, [v1_traces, eod_traces, local_eod_traces, stimulus_traces]


def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
    """Pool the trials of stimulus intensities that lie close to each other.

    Two intensities count as similar when they differ by less than
    ``0.6666 * mean(np.diff(sorted(intensities)))``. The intensities are
    walked in increasing order, and every later entry similar to the current
    one is merged into it (see ``merge_intensities_similar_to_index``).

    The three lists are sorted together by intensity once, up front. The
    earlier implementation re-sorted only ``intensities`` and ``spiketimes``
    after each step. That left ``trans_amplitudes`` misaligned with them, and
    it could move unvisited entries in front of the loop index so they were
    never considered for merging.

    :param intensities: one intensity per entry.
    :param spiketimes: per entry, a list of trials.
    :param trans_amplitudes: per entry, a value that must agree among the
        entries that get merged.
    :return: new ``(intensities, spiketimes, trans_amplitudes)`` lists, sorted
        by increasing intensity and aligned with each other. The caller's
        lists are not modified.
    :raises RuntimeError: if entries to be merged have different trans
        amplitudes.
    """
    if len(intensities) < 2:
        # Nothing to merge; also avoids np.mean([]) (nan plus a RuntimeWarning).
        return list(intensities), list(spiketimes), list(trans_amplitudes)

    # Stable sort of all three lists by intensity, keeping them aligned.
    order = sorted(range(len(intensities)), key=lambda k: intensities[k])
    intensities = [intensities[k] for k in order]
    spiketimes = [spiketimes[k] for k in order]
    trans_amplitudes = [trans_amplitudes[k] for k in order]

    margin = np.mean(np.diff(intensities)) * 0.6666

    # Merging only removes entries after index i, so the lists stay sorted.
    # len(intensities) is re-evaluated on each pass because it shrinks while
    # merging.
    i = 0
    while i < len(intensities):
        intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(
            intensities, spiketimes, trans_amplitudes, i, margin)
        i += 1

    return intensities, spiketimes, trans_amplitudes


def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
    """Merge every entry after *index* whose intensity is within *margin* of it.

    The trials of the merged entries are appended (last entry first) to the
    trial list at *index*. The merged entries are then deleted from all three
    lists. The outer lists are modified in place and also returned. The trial
    list at *index* is replaced by a new list rather than extended in place,
    so the caller's trial lists are not mutated.

    :param index: position of the entry that absorbs its similar successors.
    :param margin: strict upper bound on ``|intensity - intensities[index]|``.
    :return: ``(intensities, spiketimes, trans_amplitudes)``.
    :raises RuntimeError: if the trans amplitudes of the entries to merge
        differ from each other.
    """
    intensity = intensities[index]
    to_merge = [k for k in range(index + 1, len(intensities))
                if np.abs(intensities[k] - intensity) < margin]
    if not to_merge:
        return intensities, spiketimes, trans_amplitudes

    # NOTE(review): only the entries being merged are compared with each
    # other, not with trans_amplitudes[index]. Confirm whether the kept entry
    # should be required to match as well.
    first_amplitude = trans_amplitudes[to_merge[0]]
    if any(trans_amplitudes[k] != first_amplitude for k in to_merge[1:]):
        raise RuntimeError("Trans_amplitudes not the same....")

    merged_trials = list(spiketimes[index])
    # Delete from the back so the remaining indices in to_merge stay valid.
    for k in reversed(to_merge):
        merged_trials.extend(spiketimes[k])
        del spiketimes[k]
        del intensities[k]
        del trans_amplitudes[k]
    spiketimes[index] = merged_trials

    return intensities, spiketimes, trans_amplitudes


def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval):
    """Compute the trial-averaged ISI frequency trace for every group of trials.

    :param spiketimes: a sequence of groups (presumably one per stimulus
        intensity), each a sequence of trials; every trial is passed to
        ``calculate_isi_frequency`` as its spike times.
    :param time_start: start time forwarded to ``calculate_isi_frequency``.
    :param sampling_interval: sampling interval forwarded to
        ``calculate_isi_frequency``.
    :return: ``(times, mean_frequencies)`` with one entry per group, as
        produced by ``calculate_mean_frequency`` for that group's trials.
    """
    times = []
    mean_frequencies = []

    for trials in spiketimes:
        per_trial = [calculate_isi_frequency(trial, time_start, sampling_interval)
                     for trial in trials]
        trial_times = [trial_time for trial_time, _ in per_trial]
        trial_freqs = [trial_freq for _, trial_freq in per_trial]

        group_time, group_mean = calculate_mean_frequency(trial_times, trial_freqs)
        times.append(group_time)
        mean_frequencies.append(group_mean)

    return times, mean_frequencies


# TODO remove additional time vector calculation!
def calculate_isi_frequency(spiketimes, time_start, sampling_interval):
    """Convert spike times into a sampled, step-wise instantaneous frequency.

    Each inter-spike interval contributes a step of height ``1 / isi``. The
    first interval is measured from *time_start*. Each step lasts
    ``round(isi / sampling_interval)`` samples. Zero-length intervals are
    skipped with a warning.

    :param spiketimes: spike times of one trial (assumed increasing).
    :param time_start: reference time for the first interval and start of the
        returned time axis.
    :param sampling_interval: spacing of the returned time axis.
    :return: ``(time, full_frequency)``. ``time`` is
        ``np.arange(time_start, spiketimes[-1], sampling_interval)``.
        ``full_frequency`` is a list. If their lengths differ by exactly one
        sample, the longer one is truncated and a warning is issued.
    :raises RuntimeError: if the lengths differ by more than one sample.
    """
    first_interval = spiketimes[0] - time_start
    intervals = [first_interval, *np.diff(spiketimes)]
    time = np.arange(time_start, spiketimes[-1], sampling_interval)

    samples_per_time_unit = 1 / sampling_interval
    full_frequency = []
    for interval in intervals:
        if interval == 0:
            warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()")
            continue
        # Each interval is rounded on its own, so rounding errors can add up;
        # the length check below tolerates a drift of one sample.
        step_length = int(round(interval * samples_per_time_unit))
        full_frequency.extend([1 / interval] * step_length)

    length_diff = len(full_frequency) - len(time)
    if length_diff == 0:
        return time, full_frequency

    if abs(length_diff) != 1:
        print("ERROR PRINT:")
        print("freq:", len(full_frequency), "time:", len(time), "diff:", length_diff)
        raise RuntimeError("FiCurve:__calculate_mean_isi_frequency__():\n"
                           "Frequency and time are not the same length!")

    warn("FiCurve:__calculate_mean_isi_frequency__():\nFrequency and time were one of in length!")
    if length_diff < 0:
        time = time[:len(full_frequency)]
    else:
        full_frequency = full_frequency[:len(time)]
    return time, full_frequency


def calculate_mean_frequency(trial_times, trial_freqs):
    """Average several frequency traces sample by sample.

    All traces are cut to the length of the shortest time axis before
    averaging.

    :param trial_times: one time axis per trial.
    :param trial_freqs: one frequency trace per trial, aligned with
        *trial_times*.
    :return: ``(time, mean_freq)``. ``time`` is the first trial's time axis
        truncated to the shortest length. ``mean_freq`` is the list of
        per-sample means across trials.
    """
    shortest = min(map(len, trial_times))
    time = trial_times[0][:shortest]

    truncated = [freqs[:shortest] for freqs in trial_freqs]
    mean_freq = []
    for samples in zip(*truncated):
        mean_freq.append(sum(samples) / len(samples))

    return time, mean_freq


def crappy_smoothing(signal: list, window_size: int = 5) -> list:
    """Smooth *signal* with a simple moving average.

    Output sample ``i`` is ``np.mean(signal[lo:hi])`` with
    ``lo = max(i - window_size, 0)`` and ``hi = min(i + window_size, len(signal))``.
    The window is therefore asymmetric: it holds up to ``window_size`` samples
    before ``i`` but at most ``window_size - 1`` after it. It is clipped at
    both ends of the signal.

    :param signal: sequence of numbers to smooth.
    :param window_size: number of samples taken before each point.
    :return: list of smoothed values, same length as *signal*.
    """
    n_samples = len(signal)
    return [
        np.mean(signal[max(pos - window_size, 0):min(pos + window_size, n_samples)])
        for pos in range(n_samples)
    ]


def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None):
    """Plot the mean ISI-frequency curve of each (selected) contrast of a cell.

    :param cell_data: object providing ``get_fi_contrasts()``,
        ``get_time_axes_mean_frequencies()`` and ``get_mean_isi_frequencies()``,
        each indexable by contrast position.
    :param save_path: if ``None`` the figure is shown interactively.
        Otherwise it is saved to ``save_path + "mean_frequency_curves.png"``.
        This is plain string concatenation, so *save_path* should end with a
        path separator or be a file-name prefix.
    :param indices: positions of the contrasts to plot; all by default.
    """
    contrasts = cell_data.get_fi_contrasts()
    time_axes = cell_data.get_time_axes_mean_frequencies()
    mean_freqs = cell_data.get_mean_isi_frequencies()

    if indices is None:
        indices = range(len(contrasts))

    plotted_any = False
    for i in indices:
        plt.plot(time_axes[i], mean_freqs[i], label=str(round(contrasts[i], 2)))
        plotted_any = True

    plt.xlabel("time")
    plt.ylabel("mean ISI frequency")
    # The curves are labelled by contrast, but without a legend those labels
    # were never displayed. Skip the legend when nothing was plotted, to
    # avoid matplotlib's "no artists with labels" warning.
    if plotted_any:
        plt.legend(title="contrast")

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path + "mean_frequency_curves.png")
    plt.close()