86 lines
2.7 KiB
Python
86 lines
2.7 KiB
Python
|
|
from DatParser import DatParser
|
|
from Controller import Controller, average_spike_height
|
|
|
|
import os
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
from redetector import detect_spiketimes
|
|
|
|
# Root directory scanned by main(); every entry is treated as one cell and
# its path is handed to Controller.
DATA_FOLDER = "../neuronModel/data/final/"

# Cells that main() skips entirely.
# NOTE(review): the name suggests their SAM recordings cannot be read —
# confirm before relying on that.
failure_to_read_sam = [
    "2012-06-27-an-invivo-1",
    "2012-12-13-ag-invivo-1",
]
|
|
|
|
|
|
def main():
    """Plot each recorded trace with its stored and its re-detected spike times.

    Walks every entry of ``DATA_FOLDER`` in sorted order, skipping the cells
    listed in ``failure_to_read_sam``. For every repro of a cell that has an
    entry in ``data_provider.parser.spike_files``, each trace returned by
    ``Controller.get_traces_with_spiketimes`` is drawn in its own figure with
    two rows of event lines above it:

    * black, at ``max(trace) + 1``: the spike times delivered with the data;
    * red, at ``max(trace) + 2``: the spike times found by ``detect_spiketimes``.

    Each figure is displayed with ``plt.show()`` and closed afterwards.
    """
    print()

    for cell in sorted(os.listdir(DATA_FOLDER)):
        if cell in failure_to_read_sam:
            continue

        cell_folder = os.path.join(DATA_FOLDER, cell)
        data_provider = Controller(cell_folder)
        repros = test_getting_repros(data_provider)

        print("\n", cell)
        for repro in repros:
            # Only repros that come with a spike file can be compared.
            if repro not in data_provider.parser.spike_files:
                continue

            print(repro)
            traces, spiketimes, rec_times = data_provider.get_traces_with_spiketimes(repro)
            sampling_interval = data_provider.parser.get_sampling_interval()

            for i, trace in enumerate(traces):
                # Sample index -> time axis, shifted by rec_times[0].
                # NOTE(review): presumably the pre-stimulus recording time — confirm in Controller.
                time = np.arange(len(trace)) * sampling_interval - rec_times[0]
                # Computed once; both event rows are stacked just above the trace.
                top = max(trace)

                plt.figure(figsize=(10, 5))
                plt.plot(time, trace)
                plt.eventplot(spiketimes[i], lineoffsets=top + 1, colors="black")
                redetect = detect_spiketimes(time, trace)
                plt.eventplot(redetect, lineoffsets=top + 2, colors="red")
                # plt.savefig("figures/best_spikes_test/" + cell + "_" + repro + str(i) + ".png")
                plt.show()
                plt.close()
|
|
|
|
|
|
def test_loading_spikes(data_provider: Controller, repro):
    """Return the spike times that ``data_provider.parser`` holds for *repro*.

    Thin wrapper around ``data_provider.parser.get_spiketimes(repro)``.
    Despite the ``test_`` prefix this is a helper, not a unit test.
    """
    return data_provider.parser.get_spiketimes(repro)


# The ``test_`` prefix would make pytest collect this helper and fail on the
# non-fixture parameters; opt it out of test collection explicitly.
test_loading_spikes.__test__ = False
|
|
|
|
|
|
def test_loading_traces(data_provider, repro):
    """Return whatever ``data_provider.get_traces(repro)`` returns.

    Thin wrapper used for manual checks; despite the ``test_`` prefix it is
    a helper, not a unit test.
    """
    return data_provider.get_traces(repro)


# The ``test_`` prefix would make pytest collect this helper and fail on the
# non-fixture parameters; opt it out of test collection explicitly.
test_loading_traces.__test__ = False
|
|
|
|
|
|
def test_getting_repros(data_provider: Controller):
    """Return the repros reported by ``data_provider.get_repros()``.

    Used by main() to enumerate the repros of a cell; despite the ``test_``
    prefix this is a helper, not a unit test.
    """
    return data_provider.get_repros()


# The ``test_`` prefix would make pytest collect this helper and fail on the
# non-fixture parameter; opt it out of test collection explicitly.
test_getting_repros.__test__ = False
|
|
|
|
|
|
# def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
|
|
# ash = np.zeros((len(traces), len(spiketimes)))
|
|
#
|
|
# for i, trace in enumerate(traces):
|
|
# for j, spikes in enumerate(spiketimes):
|
|
# if len(spikes) <= 1:
|
|
# ash[i, j] = -np.inf
|
|
# else:
|
|
# ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
|
|
#
|
|
# return ash
|
|
#
|
|
|
|
# def average_spike_height(spike_train, v1, sampling_rate, before):
|
|
# indices = np.array([(s - before) / sampling_rate for s in spike_train], dtype=int)
|
|
# v1 = np.array(v1)
|
|
# spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
|
|
# average_height = np.mean(spike_values)
|
|
#
|
|
# return average_height
|
|
|
|
|
|
# Entry point: run the plotting loop only when executed as a script,
# never as a side effect of importing this module.
if __name__ == '__main__':
    main()
|