import numpy as np
import os


def load_chirp_spikes(dataset):
    """Load per-chirp spike times from ``<dataset>/chirpspikess1.dat``.

    The file is read line by line. Lines containing ``#`` are metadata:

    * ``index`` (but not ``chirp index``), ``deltaf``, ``contrast`` and
      ``chirpsize`` set the current stimulus description. ``deltaf`` and
      ``contrast`` lines that mention ``true`` are ignored.
    * ``#Key`` opens a new, empty stimulus block.
    * ``chirp index`` and ``beat phase`` open a new chirp inside that block.

    Every other non-empty line is parsed as a single float (a spike time)
    and appended to the current chirp.

    Parameters
    ----------
    dataset : str
        Path of the dataset directory.

    Returns
    -------
    dict
        ``{(index, deltaf, contrast, chirpsize): {(chirp_index, beat_phase): [spike_times]}}``.
        ``index`` and ``chirp_index`` are ints. ``beat_phase`` and the spike
        times are floats. ``deltaf``, ``contrast`` and ``chirpsize`` are the
        raw text after the last ``:`` on their line, not stripped. An empty
        dict is returned, after printing a notice, if the file does not exist.

    Raises
    ------
    ValueError
        If a ``beat phase`` line appears before any matching ``#Key`` block,
        if a spike line appears before its ``#Key``/``beat phase`` headers,
        or if a value cannot be parsed as a number.
    """
    spikes_file = os.path.join(dataset, "chirpspikess1.dat")
    if not os.path.exists(spikes_file):
        print("found no chirps!")
        return {}
    spikes = {}
    # Current header state. It stays None until the corresponding comment
    # line has been seen.
    index = df = contrast = cs = None
    ci = phase = None
    with open(spikes_file, 'r') as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if "index" in line and "chirp" not in line:
                index = int(line.split(":")[-1])
            if "deltaf" in line and "true" not in line:
                df = line.split(":")[-1]
            if "contrast" in line and "true" not in line:
                contrast = line.split(":")[-1]
            if "chirpsize" in line:
                cs = line.split(":")[-1]
            # Stimulus key built from the metadata seen so far.
            key = (index, df, contrast, cs)
            if "#Key" in line:
                spikes[key] = {}
            if "chirp index" in line:
                ci = int(line.split(":")[-1])
            if "beat phase" in line:
                phase = float(line.split(":")[-1])
                if key not in spikes:
                    raise ValueError(
                        f"{spikes_file}:{lineno}: 'beat phase' header "
                        f"before a '#Key' header for stimulus {key}")
                spikes[key][(ci, phase)] = []
            if line and "#" not in line:
                chirps = spikes.get(key, {})
                # A spike line must follow a '#Key' and a 'beat phase' header
                # of the current block. Otherwise report where it is instead
                # of failing with an opaque NameError/KeyError.
                if (ci, phase) not in chirps:
                    raise ValueError(
                        f"{spikes_file}:{lineno}: spike time {line!r} "
                        f"appears before its '#Key'/'beat phase' headers")
                chirps[(ci, phase)].append(float(line))
    return spikes


def load_chirp_eod(dataset):
    """Read the EOD amplitude traces stored in ``<dataset>/chirpeodampls.dat``.

    Comment lines give the ``index``, ``deltaf``, ``contrast`` and
    ``chirpsize`` of the current stimulus. ``deltaf`` and ``contrast`` lines
    mentioning ``true`` are skipped. Each ``#Key`` line starts a new entry.
    Every following non-empty, non-comment line is read as whitespace-
    separated ``time amplitude`` columns.

    Parameters
    ----------
    dataset : str
        Path of the dataset directory.

    Returns
    -------
    dict
        Maps ``(index, deltaf, contrast, chirpsize)`` to a tuple
        ``(times, amplitudes)`` of two float lists. ``index`` is an int. The
        other key parts are the raw text after the last ``:`` on their line.
        An empty dict is returned, after printing a notice, if the file does
        not exist.
    """
    path = os.path.join(dataset, "chirpeodampls.dat")
    if os.path.exists(path):
        with open(path, 'r') as fh:
            raw_lines = fh.readlines()
    else:
        print("found no chirpeodampls.dat!")
        return {}

    result = {}
    for raw in raw_lines:
        line = raw.strip()
        value = line.split(":")[-1]  # text after the last colon
        if "index" in line and "chirp" not in line:
            idx = int(value)
        if "deltaf" in line and "true" not in line:
            delta_f = value
        if "contrast" in line and "true" not in line:
            contr = value
        if "chirpsize" in line:
            size = value
        if "#Key" in line:
            result[(idx, delta_f, contr, size)] = ([], [])
        if line and "#" not in line:
            columns = line.split()
            t = float(columns[0])
            a = float(columns[1])
            times, ampls = result[(idx, delta_f, contr, size)]
            times.append(t)
            ampls.append(a)
    return result


def load_chirp_times(dataset):
    """Read chirp times from ``<dataset>/chirpss.dat``.

    Comment lines set the ``index``, ``deltaf``, ``contrast`` and
    ``chirpsize`` of the current stimulus. ``deltaf`` and ``contrast`` lines
    mentioning ``true`` are skipped. A ``#Key`` line starts a new entry. From
    each following non-empty, non-comment line, the second whitespace-
    separated column is stored as a float.

    Parameters
    ----------
    dataset : str
        Path of the dataset directory.

    Returns
    -------
    dict
        Maps ``(index, deltaf, contrast, chirpsize)`` to a list of floats.
        ``index`` is an int. The other key parts are the raw text after the
        last ``:`` on their line. An empty dict is returned, after printing a
        notice, if the file does not exist.
    """
    fname = os.path.join(dataset, "chirpss.dat")
    if not os.path.exists(fname):
        print("found no chirpss.dat!")
        return {}
    with open(fname, 'r') as handle:
        content = handle.readlines()

    times_by_key = {}
    for text in map(str.strip, content):
        tail = text.rpartition(":")[2]  # text after the last colon
        if "index" in text and "chirp" not in text:
            stim_index = int(tail)
        if "deltaf" in text and "true" not in text:
            delta = tail
        if "contrast" in text and "true" not in text:
            con = tail
        if "chirpsize" in text:
            size = tail
        if "#Key" in text:
            times_by_key[(stim_index, delta, con, size)] = []
        if text and "#" not in text:
            times_by_key[(stim_index, delta, con, size)].append(float(text.split()[1]))
    return times_by_key


if __name__ == "__main__":
    # Example run: load the chirp spikes, chirp times and EOD amplitudes of
    # one recording directory.
    data_dir = "../data"
    dataset = "2018-11-09-ad-invivo-1"
    dataset_path = os.path.join(data_dir, dataset)
    spikes = load_chirp_spikes(dataset_path)
    chirp_times = load_chirp_times(dataset_path)
    chirp_eod = load_chirp_eod(dataset_path)