"""Visual comparison of stored and re-detected spike times.

For every cell folder under ``DATA_FOLDER`` (except those listed in
``failure_to_read_sam``), every repro that has an entry in the parser's
``spike_files`` is loaded. Each of its traces is plotted together with two
event rows:

* the spike times returned by ``Controller.get_traces_with_spiketimes`` (black)
* the spike times re-detected from the trace by ``redetector.detect_spiketimes`` (red)
"""
from DatParser import DatParser
from Controller import Controller, average_spike_height
import os
import numpy as np
import matplotlib.pyplot as plt
from redetector import detect_spiketimes

DATA_FOLDER = "../neuronModel/data/final/"

# Cell folders that main() skips (judging by the name, their data could not
# be read -- confirm).
failure_to_read_sam = ["2012-06-27-an-invivo-1", "2012-12-13-ag-invivo-1"]


def main(save_folder=None):
    """Plot every trace of every usable repro for all cells in DATA_FOLDER.

    Args:
        save_folder: Optional directory. When given, each figure is saved as
            ``<save_folder>/<cell>_<repro><trial>.png`` (the directory is
            created if needed) instead of being shown interactively. Default
            ``None`` shows every figure with ``plt.show()``.
    """
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)

    for cell in sorted(os.listdir(DATA_FOLDER)):
        if cell in failure_to_read_sam:
            continue

        data_provider = Controller(os.path.join(DATA_FOLDER, cell))
        repros = test_getting_repros(data_provider)
        print("\n", cell)

        for repro in repros:
            # Only repros with an entry in spike_files are compared.
            if repro not in data_provider.parser.spike_files:
                continue
            print(repro)

            traces, spiketimes, rec_times = data_provider.get_traces_with_spiketimes(repro)
            sampling_interval = data_provider.parser.get_sampling_interval()

            for trial, trace in enumerate(traces):
                # Indexed (not zipped) so a traces/spiketimes length mismatch
                # still raises instead of being silently truncated.
                spikes = spiketimes[trial]
                # NOTE(review): every trial is shifted by rec_times[0], not
                # rec_times[trial]; confirm this is intended.
                time_axis = np.arange(len(trace)) * sampling_interval - rec_times[0]

                save_path = None
                if save_folder is not None:
                    save_path = os.path.join(save_folder, f"{cell}_{repro}{trial}.png")
                _plot_trace_with_spikes(time_axis, trace, spikes, save_path)


def _plot_trace_with_spikes(time_axis, trace, spikes, save_path=None):
    """Plot one trace with its stored (black) and re-detected (red) spikes.

    The stored spikes are drawn one unit above the trace maximum and the
    re-detected spikes two units above it. The figure is shown interactively,
    or written to ``save_path`` when one is given, and then closed.
    """
    fig = plt.figure(figsize=(10, 5))
    top = max(trace)
    plt.plot(time_axis, trace)
    plt.eventplot(spikes, lineoffsets=top + 1, colors="black")

    redetected = detect_spiketimes(time_axis, trace)
    plt.eventplot(redetected, lineoffsets=top + 2, colors="red")

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
    plt.close(fig)


# NOTE: the helpers below are named ``test_*``; if they are ever imported into
# a pytest module they may be collected as tests (they need arguments).
def test_loading_spikes(data_provider: Controller, repro):
    """Return ``data_provider.parser.get_spiketimes(repro)``."""
    return data_provider.parser.get_spiketimes(repro)


def test_loading_traces(data_provider, repro):
    """Return ``data_provider.get_traces(repro)``."""
    return data_provider.get_traces(repro)


def test_getting_repros(data_provider: Controller):
    """Return ``data_provider.get_repros()``."""
    return data_provider.get_repros()


if __name__ == '__main__':
    main()