import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as ss
from read_chirp_data import *
from utility import *
from IPython import embed

# Recording parameters and input locations for this analysis.
sampling_rate = 40  # acquisition sampling rate, kHz
data_dir = "../data"  # parent directory of the per-recording folders
# Upper bound for the window that is trimmed to a whole number of beat cycles.
# It is compared against 1/df, so it must use the reciprocal unit of df.
# NOTE(review): the unit is not stated anywhere (s? ms?) — confirm.
cut_window = 20

# Recordings from the 2018-11-13 session to process, identified by their tag.
data = [
    "2018-11-13-{}-invivo-1".format(tag)
    for tag in ("aa", "ac", "ad", "ah", "ai", "aj", "ak", "al")
]

def _beat_window(df, window):
    """Return the longest span <= ``window`` holding a whole number of beat cycles.

    The beat period is ``1 / |df|``. Only the magnitude of ``df`` matters, so
    negative difference frequencies are handled. Returns 0 when ``df`` is 0
    (no beat) or when one beat cycle is already longer than ``window``.
    ``window`` must be in the reciprocal unit of ``df``. For example, if
    ``df`` is in Hz, ``window`` and the result are in seconds
    (NOTE(review): units of the callers' values are not stated — confirm).
    """
    beat_frequency = abs(df)
    if beat_frequency == 0:
        return 0
    # Count whole cycles directly instead of accumulating 1/df in a loop:
    # repeated float addition drifts and can drop the last full cycle.
    n_beats = int(np.floor(window * beat_frequency))
    return n_beats / beat_frequency


for dataset in data:
    spikes = read_chirp_spikes(os.path.join(data_dir, dataset))
    df_map = map_keys(spikes)
    print(dataset)
    for df, reps in df_map.items():
        # Largest multiple of the beat period that fits inside cut_window.
        beat_window = _beat_window(df, cut_window)
        for rep in reps:
            # Only the first phase entry of each repetition is used for now.
            # Skip repetitions with no phases instead of silently reusing the
            # previous repetition's response.
            phases = iter(spikes[rep])
            try:
                first_phase = next(phases)
            except StopIteration:
                continue
            response = spikes[rep][first_phase]
            # TODO: cut `response` down to `beat_window` (unfinished in the
            # original as `#cut = response[response[]]`).