import os   # compatibility with Windows
from IPython import embed
import numpy as np

def step_response(t, a1, a2, tau1, tau2):
    """Evaluate a two-component exponential step response.

    r(t) = a1 * (1 - exp(-t / tau1)) + a2 * (1 - exp(-t / tau2)) for t >= 0,
    and r(t) = 0 for t < 0.

    Parameters
    ----------
    t : float or array_like
        Time point(s); scalars, lists and numpy arrays are accepted.
    a1, a2 : float
        Amplitudes of the two exponential components.
    tau1, tau2 : float
        Time constants of the two components (same unit as ``t``).

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Response with the same shape as ``t`` (a numpy scalar for scalar ``t``).
    """
    t = np.asarray(t, dtype=float)
    # Evaluate on t clipped at 0 so very negative times cannot overflow exp();
    # those samples are forced to exactly 0 by the np.where below anyway.
    t_pos = np.maximum(t, 0.0)
    r_step = (a1 * (1 - np.exp(-t_pos / tau1))) + (a2 * (1 - np.exp(-t_pos / tau2)))
    r_step = np.where(t < 0, 0.0, r_step)
    # [()] unwraps a 0-d result to a scalar and is a no-op for n-d arrays.
    return r_step[()]

def sin_response(t, f, p, A):
    """Evaluate the sinusoid A * sin(2*pi*f*t + p).

    Parameters
    ----------
    t : float or array_like
        Time point(s).
    f : float
        Frequency (cycles per unit of ``t``).
    p : float
        Phase offset in radians.
    A : float
        Amplitude.

    Returns
    -------
    float or numpy.ndarray
        The sinusoid evaluated at ``t``.
    """
    # The original called a bare ``sin`` that is never imported (NameError);
    # numpy's sin also works element-wise on arrays.
    r_sin = A * np.sin(2 * np.pi * t * f + p)
    return r_sin

def parse_dataset(dataset_name):
    """Parse a multi-loop recording file into per-loop data arrays and metadata.

    The file is processed line by line (each line stripped of surrounding
    whitespace):

    * A line containing ``#`` together with one of the keys ``EODf``,
      ``Delta f``, ``StimulusFrequency``, ``Duration`` or ``Pause`` appends one
      float to the matching metadata list. The value is the text after the last
      ``:``, stripped, with a fixed-length unit suffix removed (2 characters,
      e.g. ``Hz``, for the first three keys; 3 characters for Duration and Pause).
    * A line containing ``#Key`` starts a new loop; the rows collected for the
      previous loop (if any) are stored as numpy arrays.
    * Every non-empty line not starting with ``#`` is a data row. Its
      whitespace-separated columns are parsed as floats, and the first three
      are taken as time, frequency and amplitude.

    Parameters
    ----------
    dataset_name : str
        Path of the file to parse.

    Returns
    -------
    tuple
        ``(frequencies, times, amplitudes, eodfs, deltafs, stimulusfs,
        duration, pause)``. The first three are lists holding one 1-D numpy
        array per loop; the rest are lists of floats in file order.

    Raises
    ------
    FileNotFoundError
        If ``dataset_name`` does not exist.
    ValueError
        If a metadata value or data column cannot be parsed as a float.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not os.path.exists(dataset_name):
        raise FileNotFoundError(f"dataset not found: {dataset_name}")
    # ``with`` guarantees the file is closed even if reading fails.
    with open(dataset_name, 'r') as f:
        lines = f.readlines()

    # metadata lists, one entry per matching comment line
    eodfs = []
    deltafs = []
    stimulusfs = []
    duration = []
    pause = []

    # (key searched for in the line, destination list, length of unit suffix to drop)
    metadata_fields = (
        ("EODf", eodfs, 2),
        ("Delta f", deltafs, 2),
        ("StimulusFrequency", stimulusfs, 2),
        ("Duration", duration, 3),
        ("Pause", pause, 3),
    )

    # finished loops, one numpy array per loop
    times = []
    frequencies = []
    amplitudes = []

    # rows of the loop currently being read
    time = []
    ampl = []
    freq = []

    for raw_line in lines:
        line = raw_line.strip()

        if "#" in line:
            for key, target, unit_len in metadata_fields:
                if key in line:
                    # value after the last ':' minus its unit suffix, e.g. "600.0Hz" -> 600.0
                    target.append(float(line.split(':')[-1].strip()[:-unit_len]))

        if '#Key' in line:
            # A new loop begins: store the previous one (nothing collected yet before the first).
            if time:
                times.append(np.array(time))
                amplitudes.append(np.array(ampl))
                frequencies.append(np.array(freq))
            time = []
            ampl = []
            freq = []

        # ``!=``/startswith instead of the former ``is not '#'`` identity test,
        # which only worked by accident of CPython string caching.
        if line and not line.startswith('#'):
            values = [float(v) for v in line.split()]
            time.append(values[0])
            freq.append(values[1])
            ampl.append(values[2])

    # The last loop is not followed by another '#Key', so store it here.
    times.append(np.array(time))
    amplitudes.append(np.array(ampl))
    frequencies.append(np.array(freq))

    return frequencies, times, amplitudes, eodfs, deltafs, stimulusfs, duration, pause

def parse_infodataset(dataset_name):
    """Collect the ``Identifier`` entries from a file's comment lines.

    For every line containing both ``#`` and ``Identifier``, the text after the
    last ``:`` is stripped and characters 1 to 11 (slice ``[1:12]``) are kept.
    NOTE(review): the slice presumably drops a leading quote and truncates to a
    fixed-length ID; confirm against the file format.

    Parameters
    ----------
    dataset_name : str
        Path of the file to read.

    Returns
    -------
    list of str
        The extracted identifiers in file order.

    Raises
    ------
    FileNotFoundError
        If ``dataset_name`` does not exist.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not os.path.exists(dataset_name):
        raise FileNotFoundError(f"dataset not found: {dataset_name}")
    # ``with`` guarantees the file is closed even if reading fails.
    with open(dataset_name, 'r') as f:
        stripped = [line.strip() for line in f]

    identifier = [
        line.split(':')[-1].strip()[1:12]
        for line in stripped
        if "#" in line and "Identifier" in line
    ]
    return identifier

def mean_traces(start, stop, timespan, frequencies, time):
    """Resample every frequency trace onto one time grid and average them.

    The grid runs from ``start`` up to (excluding) ``stop``. Its step is
    ``timespan`` divided by the number of samples in the shortest trace.
    Samples whose frequency is not greater than -5 are discarded before
    interpolation (presumably sentinel values for failed detections; confirm).
    ``np.interp`` holds the first/last valid value outside a trace's range.

    Parameters
    ----------
    start, stop : float
        Bounds of the common time grid (``stop`` excluded).
    timespan : float
        Total duration used to derive the grid step.
    frequencies : sequence of numpy.ndarray
        One frequency trace per loop.
    time : sequence of numpy.ndarray
        Matching time stamps, indexed like ``frequencies``.

    Returns
    -------
    tuple
        ``(mf, tnew)``: the trace-wise mean on the grid and the grid itself.
    """
    shortest = min(len(trace_t) for trace_t in time)
    grid = np.arange(start, stop, timespan / shortest)

    resampled = np.zeros((len(frequencies), len(grid)))
    for row, f_trace in enumerate(frequencies):
        t_trace = time[row]
        keep = f_trace > -5
        resampled[row, :] = np.interp(grid, t_trace[keep], f_trace[keep])

    # average across traces (rows), leaving one value per grid point
    averaged = resampled.mean(axis=0)
    return averaged, grid

def mean_noise_cut(frequencies, time, n):
    """Average ``frequencies`` over consecutive, non-overlapping windows of ``n`` samples.

    The last window may be shorter than ``n``. Each window is represented by
    the time stamp of its first sample.

    Parameters
    ----------
    frequencies : sequence of float
        Samples to average.
    time : sequence
        Time stamps indexed like ``frequencies``.
    n : int
        Window length in samples.

    Returns
    -------
    tuple of list
        ``(cutf, cutt)``: per-window means and per-window start times.
    """
    window_starts = np.arange(0, len(frequencies), n)
    cutt = []
    cutf = []
    for first in window_starts:
        cutt.append(time[first])
        cutf.append(np.mean(frequencies[first:first + n]))
    return cutf, cutt

def norm_function(cf_arr, ct_arr, onset_point, offset_point):
    """Normalize a trace so its pre-onset level maps to 0 and its pre-offset level to 1.

    The baseline is the median of ``cf_arr`` where ``ct_arr`` lies in
    ``[onset_point - 10, onset_point)``. After subtracting the baseline, the
    trace is divided by its median in ``[offset_point - 10, offset_point)``.
    If that median is 0 the division yields inf/nan.

    Parameters
    ----------
    cf_arr : numpy.ndarray
        Values to normalize.
    ct_arr : numpy.ndarray
        Time stamps aligned with ``cf_arr``.
    onset_point, offset_point : float
        Right (exclusive) edges of the two 10-unit reference windows.

    Returns
    -------
    numpy.ndarray
        The normalized trace.
    """
    def window_median(values, window_end):
        # median of the samples with time in [window_end - 10, window_end)
        window_start = window_end - 10
        in_window = (ct_arr >= window_start) & (ct_arr < window_end)
        return np.median(values[in_window])

    shifted = cf_arr - window_median(cf_arr, onset_point)
    return shifted / window_median(shifted, offset_point)

def base_eod(frequencies, time, onset_point):
    base_eod = []

    onset_end = onset_point - 10

    base = np.median(frequencies[(time >= onset_end) & (time < onset_point)])
    base_eod.append(base)
    return base_eod


def JAR_eod(frequencies, time, offset_point):
    """Return the median of ``frequencies`` over the 10 time units before ``offset_point``.

    The window is ``[offset_point - 10, offset_point)`` on ``time``.

    Parameters
    ----------
    frequencies : numpy.ndarray
        Values to summarize.
    time : numpy.ndarray
        Time stamps aligned with ``frequencies``.
    offset_point : float
        Right (exclusive) edge of the window.

    Returns
    -------
    list
        A one-element list holding the median.
    """
    in_window = (time >= offset_point - 10) & (time < offset_point)
    window_median = np.median(frequencies[in_window])
    return [window_median]

def sort_values(values):
    """Reorder step-response fit parameters so the time constants ascend.

    ``values`` is laid out as ``(a1, ..., ak, tau1, ..., tauk)``, e.g.
    ``(a1, a2, tau1, tau2)`` as used by ``step_response``. The amplitudes are
    permuted together with their time constants. This keeps each ``a_i`` paired
    with its ``tau_i``, so the reordered parameters describe the same curve.
    Previously only the taus were sorted, which silently mismatched the pairs
    whenever ``tau1 > tau2``.

    Parameters
    ----------
    values : array_like
        Flat parameter vector of even length (amplitudes first, then taus).

    Returns
    -------
    numpy.ndarray
        Flat array ``[a..., tau...]`` with the taus in ascending order.

    Raises
    ------
    ValueError
        If ``values`` has odd length.
    """
    values = np.asarray(values)
    if len(values) % 2:
        raise ValueError("expected an even number of parameters (amplitudes, then taus)")
    n_comp = len(values) // 2
    a = values[:n_comp]
    tau = values[n_comp:]
    # stable sort keeps the original order for equal time constants
    order = np.argsort(tau, kind="stable")
    values_flat = np.concatenate([a[order], tau[order]])
    return values_flat