import pyrelacs.DataLoader as dl
import os
import numpy as np


def get_subfolder_paths(basepath):
    """Return the sorted full paths of the immediate subdirectories of *basepath*.

    Child paths are built with ``os.path.join``, so *basepath* works both with
    and without a trailing path separator (plain string concatenation would
    produce ``"dirchild"`` when the separator is missing).

    Args:
        basepath: directory whose direct children are inspected.

    Returns:
        Sorted list of paths (``basepath`` joined with the child name) for every
        entry of *basepath* that is a directory. Regular files are skipped.
    """
    candidates = (os.path.join(basepath, name) for name in os.listdir(basepath))
    return sorted(path for path in candidates if os.path.isdir(path))


def get_traces(directory, trace_type, repro):
    """Load a single trace channel from every recording of *repro* in *directory*.

    Args:
        directory: data directory forwarded to ``dl.iload_traces``.
        trace_type: 1-based channel number; ``x[trace_type - 1]`` is taken from
            each loaded recording. Channel meanings (per the original author):
            1 = voltage (p-unit), 2 = EOD, 3 = local EOD (~EOD + stimulus),
            4 = stimulus.
        repro: repro name forwarded to ``dl.iload_traces(..., repro=repro)``.

    Returns:
        ``(time_traces, value_traces)``: two lists with one entry per yielded
        recording, in the order ``dl.iload_traces`` yields them. Both are empty
        (and a message is printed) when nothing is found.

    Raises:
        ValueError: if ``trace_type`` is smaller than 1. Without this check a 0
            or negative value would silently index from the end of ``x`` and
            return the wrong channel.
    """
    if trace_type < 1:
        raise ValueError(
            "trace_type must be a 1-based channel number (>= 1), got {}".format(trace_type))

    time_traces = []
    value_traces = []

    found_any = False
    for _info, _key, time, x in dl.iload_traces(directory, repro=repro):
        found_any = True
        time_traces.append(time)
        value_traces.append(x[trace_type - 1])

    if not found_any:
        # Report the repro that was actually requested (previously hard-coded).
        print("iload_traces found nothing for the {} repro!".format(repro))

    return time_traces, value_traces


def get_all_traces(directory, repro):
    """Load every recording of *repro* in *directory*, split into four trace lists.

    Iterates ``dl.iload_traces(directory, repro=repro)``. For each yielded
    ``(info, key, time, x)`` tuple it stores ``time`` and ``x[0]`` .. ``x[3]``,
    and prints ``info``.

    Args:
        directory: data directory forwarded to ``dl.iload_traces``.
        repro: repro name forwarded to ``dl.iload_traces(..., repro=repro)``.

    Returns:
        ``(time_traces, traces)`` where ``traces`` is
        ``[v1_traces, eod_traces, local_eod_traces, stimulus_traces]`` (built
        from ``x[0]``, ``x[1]``, ``x[2]``, ``x[3]`` respectively). Each list has
        one entry per recording, in the same order as ``time_traces``. A
        message is printed when nothing is found.
    """
    time_traces = []
    v1_traces = []
    eod_traces = []
    local_eod_traces = []
    stimulus_traces = []

    found_any = False
    for info, _key, time, x in dl.iload_traces(directory, repro=repro):
        found_any = True
        time_traces.append(time)
        v1_traces.append(x[0])
        eod_traces.append(x[1])
        local_eod_traces.append(x[2])
        stimulus_traces.append(x[3])
        print(info)

    traces = [v1_traces, eod_traces, local_eod_traces, stimulus_traces]

    if not found_any:
        # Report the repro that was actually requested (previously hard-coded).
        print("iload_traces found nothing for the {} repro!".format(repro))

    return time_traces, traces


def crappy_smoothing(signal:list, window_size:int = 5) -> list:
    """Smooth *signal* with a simple (uncentred) moving average.

    Output element ``i`` is ``np.mean`` of
    ``signal[max(0, i - window_size):min(len(signal), i + window_size)]``.
    That window holds up to ``window_size`` samples before ``i``, sample ``i``
    itself, and up to ``window_size - 1`` samples after it, so the window is
    slightly asymmetric. Near both ends the window is truncated and averages
    fewer samples.

    Args:
        signal: sliceable sequence of numbers.
        window_size: half-width of the averaging window.

    Returns:
        A list with one averaged value per input sample (empty for empty input).
    """
    length = len(signal)
    return [
        np.mean(signal[max(0, idx - window_size):min(length, idx + window_size)])
        for idx in range(length)
    ]


def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None):
    """Plot the mean ISI-frequency curve for each selected f-I contrast.

    Args:
        cell_data: object providing ``get_fi_contrasts()``,
            ``get_time_axes_fi_curve_mean_frequencies()`` and
            ``get_mean_fi_curve_isi_frequencies()``. Presumably the three
            results are parallel sequences indexed by contrast (TODO confirm).
        save_path: if None, the figure is shown interactively. Otherwise it is
            saved to ``save_path + "mean_frequency_curves.png"``. This is plain
            string concatenation, so pass a directory with a trailing separator
            or a filename prefix.
        indices: indices of the contrasts to plot; defaults to all of them.
    """
    # plt is not among the module's imports; import it here so the function
    # does not fail with a NameError.
    import matplotlib.pyplot as plt

    contrast = cell_data.get_fi_contrasts()
    time_axes = cell_data.get_time_axes_fi_curve_mean_frequencies()
    mean_freqs = cell_data.get_mean_fi_curve_isi_frequencies()

    if indices is None:
        indices = np.arange(len(contrast))

    for i in indices:
        plt.plot(time_axes[i], mean_freqs[i], label=str(round(contrast[i], 2)))

    # Each curve carries its contrast as a label; without a legend those
    # labels are never displayed.
    plt.legend()

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path + "mean_frequency_curves.png")
    plt.close()