import numpy as np
import matplotlib.pyplot as plt
import pyrelacs.DataLoader as Dl

# Directory of the recording that main() inspects.
# NOTE: the first assignment is overwritten immediately by the second, so only
# the last path takes effect (the first is presumably kept as a quick
# alternative to switch back to -- confirm before removing).
test_cell = "data/final/2010-11-08-al-invivo-1/"
test_cell = "data/final/2012-04-20-af-invivo-1/"  # mostly well detected


def main():
    """Plot every FICurve trace of ``test_cell`` with spike trains drawn on top.

    The spike trains are read once from ``fispikes1.dat`` in the cell
    directory. For each trace yielded by ``Dl.iload_traces`` the key and info
    are printed, the first channel is plotted against time, and the first
    nine spike trains are drawn as event rows stacked one unit apart,
    starting at the maximum of the plotted channel. One figure is shown per
    trace (``plt.show()`` is called inside the loop).
    """
    fi_spiketimes = get_unsorted_spiketimes(test_cell + "/fispikes1.dat")

    # Only the first few spike trains are drawn; the selection does not depend
    # on the trace, so it is made once up front.
    max_trains_shown = 9
    shown_trains = fi_spiketimes[:max_trains_shown]

    for info, key, time, x in Dl.iload_traces(test_cell, repro="FICurve", before=0.2, after=0.8):
        print(key)
        print(info)
        # Per the original note the channels in x are ordered
        # v1, eod, local_eod, stimulus -- only the first one is used here.
        v1 = x[0]

        if shown_trains:
            # The trace maximum is the same for every row, so compute it once
            # per trace instead of once per spike train.
            base_height = max(v1)
            for row, spikes in enumerate(shown_trains):
                plt.eventplot(np.array(spikes), colors="black", lineoffsets=base_height + row)
        plt.plot(np.array(time), v1)
        plt.show()


def get_unsorted_spiketimes(fi_file):
    """Return the scaled first data column of every block in ``fi_file``.

    Each block yielded by ``Dl.iload`` contributes one entry, in file order:
    its first column divided by 1000 (presumably a millisecond-to-second
    conversion -- confirm against the file format). Metadata and keys are
    ignored, and no filtering or sorting is applied.
    """
    return [block[:, 0] / 1000 for _metadata, _key, block in Dl.iload(fi_file)]

def _parse_trimmed_float(text):
    """Convert a metadata string to float after dropping its last two characters.

    The two trailing characters are presumably a unit suffix -- confirm
    against the metadata written by the recording software.
    """
    return float(text[:-2])


def get_fi_curve_spiketimes(fi_file):
    """Read the spike trains of an FI-curve file, grouped by stimulus intensity.

    The file is iterated with ``Dl.iload``. A block that carries metadata
    starts a new intensity group; blocks without metadata add their spike
    train to the most recently started group. Whenever the metadata reports a
    ``preduration`` other than zero, that block and all following blocks are
    skipped until metadata with a zero ``preduration`` is seen again.

    Spike times are the first data column divided by 1000 (the values are
    compared against one second afterwards, so the raw data is presumably in
    milliseconds). Trains with fewer than 10 spikes, or whose last spike lies
    before 1, are reported on stdout and left out.

    Returns:
        A tuple ``(trans_amplitudes, intensities, spiketimes)`` of three lists
        with one entry per intensity group, in file order. ``spiketimes[i]``
        is the list of spike-time arrays collected for group ``i``.

    Raises:
        RuntimeError: if a data block does not have exactly one column.
        IndexError: if spike data appears before any metadata block started
            a group.
    """
    control_key = '----- Control --------------------------------------------------------'
    pre_intensities_key = "----- Pre-Intensities ------------------------------------------------"

    spiketimes = []
    intensities = []
    trans_amplitudes = []
    skip = False
    trans_amplitude = float('nan')

    for metadata, _key, data in Dl.iload(fi_file):
        if len(metadata) != 0:
            header = metadata[0]
            # Index of the metadata entry expected to hold 'intensity'; it
            # moves to 1 when entry 0 is the control section.
            metadata_index = 0

            if control_key in header:
                metadata_index = 1
                pre_duration = _parse_trimmed_float(header[pre_intensities_key]["preduration"])
                trans_amplitude = _parse_trimmed_float(header["trans. amplitude"])
                skip = pre_duration != 0
            elif "preduration" in header:
                pre_duration = _parse_trimmed_float(header["preduration"])
                trans_amplitude = _parse_trimmed_float(header["trans. amplitude"])
                skip = pre_duration != 0
            # Otherwise the skip state of the previous header stays in effect.

            if skip:
                continue

            # Fall back to the other metadata entry when the expected one
            # does not carry the intensity.
            if 'intensity' in metadata[metadata_index]:
                stimulus_info = metadata[metadata_index]
            else:
                stimulus_info = metadata[1 - metadata_index]

            intensities.append(_parse_trimmed_float(stimulus_info['intensity']))
            trans_amplitudes.append(trans_amplitude)
            spiketimes.append([])

        if skip:
            continue

        if data.shape[1] != 1:
            raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n read data has more than one dimension!")

        spike_time_data = data[:, 0] / 1000
        if len(spike_time_data) < 10:
            print("# ignoring spike-train that contains less than 10 spikes.")
            continue
        if spike_time_data[-1] < 1:
            print("# ignoring spike-train that ends before one second.")
            continue

        # The current group is always the last one that was opened.
        spiketimes[-1].append(spike_time_data)

    # TODO: sort the groups by intensity and drop groups with fewer than three
    # spike trains; an earlier attempt at this was never verified.

    return trans_amplitudes, intensities, spiketimes


if __name__ == '__main__':
    main()