# Source-listing header (file path, commit date, size, language):
# spikeRedetector/testing.py
# 2021-07-02 15:06:04 +02:00
#
# 86 lines
# 2.7 KiB
# Python

from DatParser import DatParser
from Controller import Controller, average_spike_height
import os
import numpy as np
import matplotlib.pyplot as plt
from redetector import detect_spiketimes
# Folder whose entries main() iterates over and treats as per-cell data folders.
DATA_FOLDER = "../neuronModel/data/final/"
# Cell folder names that main() skips entirely.
# NOTE(review): judging by the name, their SAM data fails to load — confirm.
failure_to_read_sam = [
    "2012-06-27-an-invivo-1",
    "2012-12-13-ag-invivo-1",
]
def main():
    """Plot each trace of each cell with stored and re-detected spike times.

    Walks the entries of ``DATA_FOLDER`` in sorted order. It skips the names
    listed in ``failure_to_read_sam`` and any entry that is not a directory.
    Only repros of a cell that have an entry in the parser's ``spike_files``
    are plotted. Each trace of such a repro is drawn against a time axis,
    with two event rows above it:

    - black: the stored spike times, one unit above the trace maximum;
    - red: the spike times returned by ``detect_spiketimes``, two units above.

    Each figure is shown interactively and closed afterwards.
    """
    for cell in sorted(os.listdir(DATA_FOLDER)):
        if cell in failure_to_read_sam:
            continue
        cell_folder = os.path.join(DATA_FOLDER, cell)
        # os.listdir also returns plain files; only directories hold cell data.
        if not os.path.isdir(cell_folder):
            continue
        data_provider = Controller(cell_folder)
        repros = test_getting_repros(data_provider)
        print("\n", cell)
        for repro in repros:
            # Without stored spike times there is nothing to compare against.
            if repro not in data_provider.parser.spike_files.keys():
                continue
            print(repro)
            traces, spiketimes, rec_times = data_provider.get_traces_with_spiketimes(repro)
            sampling_interval = data_provider.parser.get_sampling_interval()
            for i, trace in enumerate(traces):
                # The maximum of an empty trace is undefined; nothing to plot.
                if len(trace) == 0:
                    continue
                # NOTE(review): rec_times[0] is subtracted for every trace rather
                # than rec_times[i]; presumably a shared time offset — confirm.
                time_axis = np.arange(len(trace)) * sampling_interval - rec_times[0]
                # Computed once; the builtin max() over a NumPy array is slow.
                trace_max = np.max(trace)

                plt.figure(figsize=(10, 5))
                plt.plot(time_axis, trace)
                plt.eventplot(spiketimes[i], lineoffsets=trace_max + 1, colors="black")
                redetect = detect_spiketimes(time_axis, trace)
                plt.eventplot(redetect, lineoffsets=trace_max + 2, colors="red")
                # plt.savefig("figures/best_spikes_test/" + cell + "_" + repro + str(i) + ".png")
                plt.show()
                plt.close()
def test_loading_spikes(data_provider: Controller, repro):
    """Return the spike times that ``data_provider.parser`` holds for ``repro``.

    Despite the ``test_`` prefix, this takes arguments and returns a value,
    so it acts as a helper rather than a self-contained test.
    """
    parser = data_provider.parser
    return parser.get_spiketimes(repro)
def test_loading_traces(data_provider, repro):
return data_provider.get_traces(repro)
def test_getting_repros(data_provider: Controller):
    """Return the repros reported by ``data_provider.get_repros()``.

    main() iterates over the result and checks each element against the
    parser's ``spike_files``.
    """
    fetch_repros = data_provider.get_repros
    return fetch_repros()
# NOTE: disabled reference implementation, kept for comparison. A function
# named average_spike_height is imported from Controller at the top of this
# file. If this is re-enabled, np.infty and np.int no longer exist in current
# NumPy, so np.inf and the builtin int are used below instead.
# def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
#     ash = np.zeros((len(traces), len(spiketimes)))
#
#     for i, trace in enumerate(traces):
#         for j, spikes in enumerate(spiketimes):
#             if len(spikes) <= 1:
#                 ash[i, j] = -np.inf
#             else:
#                 ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
#
#     return ash
#
# def average_spike_height(spike_train, v1, sampling_rate, before):
#     indices = np.array([(s - before) / sampling_rate for s in spike_train], dtype=int)
#     v1 = np.array(v1)
#     spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
#     average_height = np.mean(spike_values)
#
#     return average_height
if __name__ == "__main__":
    # Run the interactive plotting loop only when executed as a script,
    # not when imported.
    main()