backend base with saving of sorting/traces/spiketimes for the repros

This commit is contained in:
alexanderott 2021-02-19 11:40:19 +01:00
parent 92d14f189b
commit 38ba40ce2f
7 changed files with 396 additions and 68 deletions

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
*.dat *.dat
/data/ /data/
/temp/ /temp/
/figures/
# Latex output files # Latex output files
*.out *.out
*.aux *.aux

View File

@ -6,7 +6,7 @@ import numpy as np
import pyrelacs.DataLoader as Dl import pyrelacs.DataLoader as Dl
class DatParser(): class DatParser:
def __init__(self, dir_path): def __init__(self, dir_path):
self.base_path = dir_path self.base_path = dir_path
@ -15,18 +15,22 @@ class DatParser():
self.baseline_file = self.base_path + "/basespikes1.dat" self.baseline_file = self.base_path + "/basespikes1.dat"
self.sam_file = self.base_path + "/samallspikes1.dat" self.sam_file = self.base_path + "/samallspikes1.dat"
self.stimuli_file = self.base_path + "/stimuli.dat" self.stimuli_file = self.base_path + "/stimuli.dat"
self.__test_data_file_existence__() self.spike_files = {"BaselineActivity": self.baseline_file,
"FICurve": self.fi_file,
self.fi_recording_times = [] "FileStimulus": self.base_path + "/stimspikes1.dat",
"SAM": self.sam_file}
self.sampling_interval = -1 self.sampling_interval = -1
self.fi_recording_times = []
def has_sam_recordings(self): self.spiketimes = {}
return exists(self.sam_file) self.traces = {}
self.metadata = {}
def get_measured_repros(self): def get_measured_repros(self):
repros = [] repros = []
for metadata, key, data in Dl.iload(self.stimuli_file): for metadata, key, data in Dl.iload(self.stimuli_file):
repros.extend([d["repro"] for d in metadata if "repro" in d.keys()]) repros.extend([d["repro"] for d in metadata if "repro" in d.keys()])
repros.extend([d["RePro"] for d in metadata if "RePro" in d.keys()])
return sorted(np.unique(repros)) return sorted(np.unique(repros))
@ -110,15 +114,6 @@ class DatParser():
return np.array(contrasts) return np.array(contrasts)
def traces_available(self) -> bool:
return True
def frequencies_available(self) -> bool:
return False
def spiketimes_available(self) -> bool:
return True
def get_sampling_interval(self): def get_sampling_interval(self):
if self.sampling_interval == -1: if self.sampling_interval == -1:
self.__read_sampling_interval__() self.__read_sampling_interval__()
@ -131,7 +126,7 @@ class DatParser():
return self.fi_recording_times return self.fi_recording_times
def get_baseline_traces(self): def get_baseline_traces(self):
return self.__get_traces__("BaselineActivity") return self.get_traces("BaselineActivity")
def get_baseline_spiketimes(self): def get_baseline_spiketimes(self):
# TODO change: reading from file -> detect from v1 trace # TODO change: reading from file -> detect from v1 trace
@ -145,7 +140,7 @@ class DatParser():
return spiketimes return spiketimes
def get_fi_curve_traces(self): def get_fi_curve_traces(self):
return self.__get_traces__("FICurve") return self.get_traces("FICurve")
def get_fi_frequency_traces(self): def get_fi_frequency_traces(self):
raise NotImplementedError("Not possible in .dat data type.\n" raise NotImplementedError("Not possible in .dat data type.\n"
@ -234,7 +229,7 @@ class DatParser():
return trans_amplitudes, intensities, spiketimes return trans_amplitudes, intensities, spiketimes
def get_sam_traces(self): def get_sam_traces(self):
return self.__get_traces__("SAM") return self.get_traces("SAM")
def get_sam_info(self): def get_sam_info(self):
contrasts = [] contrasts = []
@ -258,7 +253,7 @@ class DatParser():
eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz
trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV
duration = float(metadata[0]["duration"][:-2]) * factor # normally saved in ms? so change it with the factor duration = float(metadata[0]["duration"][:-2]) * factor # normaly saved in ms? so change it with the factor
contrast = float(metadata[0]["contrast"][:-1]) # in percent contrast = float(metadata[0]["contrast"][:-1]) # in percent
delta_f = float(metadata[0]["deltaf"][:-2]) delta_f = float(metadata[0]["deltaf"][:-2])
else: else:
@ -267,7 +262,7 @@ class DatParser():
eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz
trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV
duration = float(stimulus_dict["duration"][:-2]) * factor # normally saved in ms? so change it with the factor duration = float(stimulus_dict["duration"][:-2]) * factor # normaly saved in ms? so change it with the factor
contrast = float(stimulus_dict["contrast"][:-1]) # in percent contrast = float(stimulus_dict["contrast"][:-1]) # in percent
delta_f = float(stimulus_dict["deltaf"][:-2]) delta_f = float(stimulus_dict["deltaf"][:-2])
@ -295,7 +290,32 @@ class DatParser():
return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes
def __get_traces__(self, repro): def get_spiketimes(self, repro, spiketimes_file=""):
if repro in self.spiketimes.keys():
return self.spiketimes[repro], self.metadata[repro]
spiketimes = []
metadata_list = []
warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!")
if spiketimes_file == "":
if repro not in self.spike_files.keys():
raise NotImplementedError("No spiketime file specified in DatParser for repro {}! "
"Please specify spiketimes_file argument or add it to the spike_files dictionary in the Datparser class.".format(repro))
spiketimes_file = self.spike_files[repro]
for metadata, key, data in Dl.iload(spiketimes_file):
metadata_list.append(metadata)
spikes = np.array(data[:, 0]) / 1000 # timestamps are saved in ms -> conversion to seconds
spiketimes.append(spikes)
self.spiketimes[repro] = spiketimes
self.metadata[repro] = metadata_list
return spiketimes, metadata_list
def get_traces(self, repro, before=0, after=0):
if repro in self.traces.keys():
return self.traces[repro]
time_traces = [] time_traces = []
v1_traces = [] v1_traces = []
eod_traces = [] eod_traces = []
@ -304,7 +324,7 @@ class DatParser():
nothing = True nothing = True
for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro): for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro, before=before, after=after):
nothing = False nothing = False
time_traces.append(time) time_traces.append(time)
v1_traces.append(x[0]) v1_traces.append(x[0])
@ -318,14 +338,9 @@ class DatParser():
warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!" warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!"
warn(warn_msg) warn(warn_msg)
self.traces[repro] = traces
return traces return traces
def __iget_traces__(self, repro):
for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
# time, v1, eod, local_eod, stimulus
yield time, x[0], x[1], x[2], x[3]
def __read_fi_recording_times__(self): def __read_fi_recording_times__(self):
delays = [] delays = []
@ -386,13 +401,3 @@ class DatParser():
"with File:" + self.base_path) "with File:" + self.base_path)
else: else:
self.sampling_interval = sampling_intervals[0] self.sampling_interval = sampling_intervals[0]
def __test_data_file_existence__(self):
if not exists(self.stimuli_file):
raise FileNotFoundError(self.stimuli_file + " file doesn't exist!")
if not exists(self.fi_file):
raise FileNotFoundError(self.fi_file + " file doesn't exist!")
if not exists(self.baseline_file):
raise FileNotFoundError(self.baseline_file + " file doesn't exist!")
# if not exists(self.sam_file):
# raise RuntimeError(self.sam_file + " file doesn't exist!")

View File

@ -1,20 +1,138 @@
from DatParser import DatParser from DatParser import DatParser
import numpy as np
from warnings import warn
import matplotlib.pyplot as plt
class DataProvider: class DataProvider:
def __init__(self, data_path): def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
self.data_path = data_path self.data_path = data_path
# self.cell = CellData(data_path)
self.parser = DatParser(data_path) self.parser = DatParser(data_path)
self.repros = self.get_repros()
if len(repros_todo) != 0:
not_available = []
repros_todo = list(np.unique(repros_todo))
for repro in repros_todo:
if repro not in self.repros:
not_available.append(repro)
print("DataProvider: Given cell hasn't measured '{}'. This repro will be skipped!".format(repro))
for x in not_available:
repros_todo.remove(x)
self.repros_todo = repros_todo
else:
self.repros_todo = self.repros
self.base_Detection_parameters = base_detection_params
self.thresholds = {} self.thresholds = {}
self.sorting = {}
self.recording_times = {}
def get_repros(self): def get_repros(self):
pass return self.parser.get_measured_repros()
def get_unsorted_spiketimes(self, repro, file=""):
return self.parser.get_spiketimes(repro, file)
def get_unsorted_traces(self, repro, before=0, after=0):
return self.parser.get_traces(repro, before, after)
def get_traces_with_spiketimes(self, repro):
if repro in self.sorting.keys():
traces = self.get_unsorted_traces(repro, self.recording_times[repro][0], self.recording_times[repro][1])
v1_traces = traces[1]
spiketimes = self.get_unsorted_spiketimes(repro)
sorted_spiketimes = np.array(spiketimes)[self.sorting[repro]]
return v1_traces, sorted_spiketimes, self.recording_times[repro]
if repro == "FICurve":
recording_times = self.parser.get_recording_times()
before = abs(recording_times[0])
after = recording_times[3]
[time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.parser.get_traces(repro, before=before, after=after)
spiketimes = self.get_unsorted_spiketimes(repro)
else:
before = 0
after = 0
[time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.get_unsorted_traces(repro)
spiketimes = self.get_unsorted_spiketimes(repro)
if len(v1_traces) != len(spiketimes):
warn("get_traces_with_spiketimes():Unequal number of traces and spiketimes for repro {}!"
"Returning only according to the number of traces!".format(repro))
ash = calculate_distance_matrix_traces_spikes(v1_traces, spiketimes, self.parser.get_sampling_interval(), before)
sorted_spiketimes = []
sorting = []
for i in range(len(v1_traces)):
best = np.argmax(ash[i, :])
sorting.append(best)
sorted_spiketimes.append(spiketimes[best])
self.sorting[repro] = sorting
self.recording_times[repro] = (before, after)
return v1_traces, sorted_spiketimes, (before, after)
def get_stim_values(self, repro): def get_stim_values(self, repro):
pass if repro == "BaselineActivity":
return ["None"]
elif repro == "FICurve":
# TODO other function that just provides the contrasts
return sorted([str(x[0]) for x in self.parser.get_fi_curve_contrasts()])
return ["repro not supported"]
def get_trials(self, repro, stimulus_value): def get_trials(self, repro, stimulus_value):
pass pass
def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
ash = np.zeros((len(traces), len(spiketimes)))
total = len(traces) * len(spiketimes)
count = 0
for i, trace in enumerate(traces):
for j, spikes in enumerate(spiketimes):
count += 1
if count % 50000 == 0:
print("{} / {}".format(count, total))
if len(spikes) <= 1:
ash[i, j] = -np.infty
else:
ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
return ash
def average_spike_height(spike_train: np.ndarray, v1: np.ndarray, sampling_rate, before):
# indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=np.int)
indices = (spike_train + before) / sampling_rate
indices = np.array(indices, dtype=np.int)
if len(indices) <= 1:
return -np.infty
# [v1[i] for i in indices if 0 <= i < len(v1)]
applicable_indices = indices[(indices < len(v1)) & (indices > 0)]
spike_values = v1[applicable_indices]
average_height = np.mean(spike_values)
return average_height
# SLOW:
# def average_spike_height(spike_train, v1, sampling_rate, before):
# indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=np.int)
# if len(indices) <= 1:
# return -np.infty
# v1 = np.array(v1)
# spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
# average_height = np.mean(spike_values)
#
# return average_height

View File

@ -1,18 +1,19 @@
from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\ from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\ QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
QDoubleSpinBox, QComboBox QDoubleSpinBox, QComboBox, QSpinBox
from PyQt5.QtCore import pyqtSlot from PyQt5.QtCore import pyqtSlot
import numpy as np import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure from matplotlib.figure import Figure
from DataProvider import DataProvider
class SpikeRedetectGui(QWidget): class SpikeRedetectGui(QWidget):
def __init__(self, data_provider): def __init__(self, data_provider: DataProvider):
super().__init__() super().__init__()
self.data_provider = data_provider self.data_provider = data_provider
self.title = 'Spike Redetection' self.title = 'Spike Redetection'
@ -29,36 +30,80 @@ class SpikeRedetectGui(QWidget):
# Middle: # Middle:
middle = QHBoxLayout() middle = QHBoxLayout()
# Canvas for matplotlib figure # Canvas Area for matplotlib figure
m = PlotCanvas(self, width=5, height=4)
plot_area = QFrame()
plot_area_layout = QVBoxLayout()
m = PlotCanvas(self)
m.move(0, 0) m.move(0, 0)
middle.addWidget(m) plot_area_layout.addWidget(m)
# plot area buttons
plot_area_buttons = QFrame()
plot_area_buttons_layout = QHBoxLayout()
plot_area_buttons.setLayout(plot_area_buttons_layout)
plot_area_layout.addWidget(plot_area_buttons)
button = QPushButton('Button1', self)
button.setToolTip('A nice button!')
button.clicked.connect(lambda: threshold_spinbox.setValue(1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button2', self)
button.setToolTip('Another nice button!')
button.clicked.connect(lambda: threshold_spinbox.setValue(2))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button3', self)
button.setToolTip('Even more nice buttons!')
button.clicked.connect(lambda: threshold_spinbox.setValue(3))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button4', self)
button.setToolTip('Even more nice buttons!')
button.clicked.connect(lambda: threshold_spinbox.setValue(4))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
plot_area.setLayout(plot_area_layout)
middle.addWidget(plot_area)
middle.addWidget(QVLine()) middle.addWidget(QVLine())
# Side (options) panel # Side (options) panel
panel = QFrame() panel = QFrame()
panel.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
panel.setMaximumWidth(200)
panel_layout = QVBoxLayout() panel_layout = QVBoxLayout()
button = QPushButton('Button!', self)
button.setToolTip('A nice button!')
button.clicked.connect(lambda: threshold_spinbox.setValue(1)) stim_val_label = QLabel("Stimulus value:")
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) self.stim_val_box = QComboBox()
panel_layout.addWidget(button)
repro_label = QLabel("Repro:") repro_label = QLabel("Repro:")
panel_layout.addWidget(repro_label) panel_layout.addWidget(repro_label)
self.repro_box = QComboBox() self.repro_box = QComboBox()
self.repro_box.addItem("placeholder repro") self.repro_box.currentTextChanged.connect(self.repro_change)
panel_layout.addWidget(self.repro_box) for repro in self.data_provider.get_repros():
self.repro_box.addItem(repro)
stim_val_label = QLabel("Stimulus value:") panel_layout.addWidget(self.repro_box)
panel_layout.addWidget(stim_val_label) panel_layout.addWidget(stim_val_label)
self.stim_val_box = QComboBox()
self.stim_val_box.addItem("placeholder stim_value")
panel_layout.addWidget(self.stim_val_box) panel_layout.addWidget(self.stim_val_box)
filler = QFill(maxh=20) trial_label = QLabel("Trial:")
panel_layout.addWidget(trial_label)
threshold_spinbox = QSpinBox(self)
threshold_spinbox.setValue(1)
threshold_spinbox.setSingleStep(1)
threshold_spinbox.valueChanged.connect()
panel_layout.addWidget(threshold_spinbox)
filler = QFill(minh=200)
panel_layout.addWidget(filler) panel_layout.addWidget(filler)
self.status_label = QLabel("Done x/15 Stimulus Values") self.status_label = QLabel("Done x/15 Stimulus Values")
@ -78,18 +123,25 @@ class SpikeRedetectGui(QWidget):
button.setToolTip('Accept the threshold for current stimulus value') button.setToolTip('Accept the threshold for current stimulus value')
panel_layout.addWidget(button) panel_layout.addWidget(button)
panel.setLayout(panel_layout) panel.setLayout(panel_layout)
middle.addWidget(panel) middle.addWidget(panel)
self.setLayout(middle) self.setLayout(middle)
self.show() self.show()
@pyqtSlot()
def repro_change(self):
repro = self.repro_box.currentText()
self.stim_val_box.clear()
for val in self.data_provider.get_stim_values(repro):
self.stim_val_box.addItem(str(val))
class PlotCanvas(FigureCanvas): class PlotCanvas(FigureCanvas):
def __init__(self, parent=None, width=5, height=4, dpi=100): def __init__(self, parent=None, dpi=100):
fig = Figure(figsize=(width, height), dpi=dpi) fig = Figure(dpi=dpi)
self.axes = fig.add_subplot(111) self.axes = fig.add_subplot(111)
FigureCanvas.__init__(self, fig) FigureCanvas.__init__(self, fig)
@ -107,7 +159,7 @@ class PlotCanvas(FigureCanvas):
data = np.sin(x*np.pi*2*mean) data = np.sin(x*np.pi*2*mean)
ax = self.axes ax = self.axes
ax.clear() ax.clear()
ax.plot(data, 'r-') ax.plot(x, data, 'r-')
ax.set_title('Sinus Example') ax.set_title('Sinus Example')
self.draw() self.draw()
@ -127,7 +179,8 @@ class QVLine(QFrame):
class QFill(QFrame): class QFill(QFrame):
def __init__(self, maxh=int(2**24)-1, maxw=int(2**24)-1): def __init__(self, maxw=int(2**24)-1, maxh=int(2**24)-1, minw=0, minh=0):
super(QFill, self).__init__() super(QFill, self).__init__()
self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
self.setMaximumSize(maxw, maxh) self.setMaximumSize(maxw, maxh)
self.setMinimumSize(minw, minh)

View File

@ -1,14 +1,14 @@
import sys import sys
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton from PyQt5.QtWidgets import QApplication
from spike_redetection.DataProvider import DataProvider from DataProvider import DataProvider
from spike_redetection.SpikeRedetectGui import SpikeRedetectGui from SpikeRedetectGui import SpikeRedetectGui
def main(): def main():
app = QApplication(sys.argv) app = QApplication(sys.argv)
data_provider = DataProvider("../data/final_sam/2010-11-08-al-invivo-1") data_provider = DataProvider("../neuronModel/data/final_sam/2010-11-08-al-invivo-1")
ex = SpikeRedetectGui(data_provider) ex = SpikeRedetectGui(data_provider)
sys.exit(app.exec_()) sys.exit(app.exec_())

71
redetector.py Normal file
View File

@ -0,0 +1,71 @@
import numpy as np
from thunderfish.eventdetection import detect_peaks
def detect_spiketimes(time: np.ndarray, v1, threshold=2.0, min_length=5000, split_step=1000):
all_peak_indicies = detect_spike_indices_automatic_split(v1, threshold, min_length, split_step)
return time[all_peak_indicies]
def detect_spike_indices_automatic_split(v1, threshold, min_length=5000, split_step=1000):
split_start = 0
step_size = split_step
break_threshold = 0.25
splits = []
if len(v1) < min_length:
splits = [(0, len(v1))]
else:
last_max = max(v1[0:min_length])
last_min = min(v1[0:min_length])
idx = min_length
while idx < len(v1):
if idx + step_size > len(v1):
splits.append((split_start, len(v1)))
break
max_similar = abs((max(v1[idx:idx+step_size]) - last_max) / last_max) < break_threshold
min_similar = abs((min(v1[idx:idx+step_size]) - last_min) / last_min) < break_threshold
if not max_similar or not min_similar:
# print("new split")
end_idx = np.argmin(v1[idx-20:idx+21]) - 20
splits.append((split_start, idx+end_idx))
split_start = idx+end_idx
last_max = max(v1[split_start:split_start + min_length])
last_min = min(v1[split_start:split_start + min_length])
idx = split_start + min_length
continue
else:
pass
# print("elongated!")
idx += step_size
if splits[-1][1] != len(v1):
splits.append((split_start, len(v1)))
# plt.plot(v1)
# for s in splits:
# plt.plot(s, (max(v1[s[0]:s[1]]), max(v1[s[0]:s[1]])))
all_peaks = []
for s in splits:
first_index = s[0]
last_index = s[1]
std = np.std(v1[first_index:last_index])
peaks, _ = detect_peaks(v1[first_index:last_index], std * threshold)
peaks = peaks + first_index
# plt.plot(peaks, [np.max(v1[first_index:last_index]) for _ in peaks], 'o')
all_peaks.extend(peaks)
# plt.show()
# plt.close()
# all_peaks = np.array(all_peaks)
return all_peaks

80
testing.py Normal file
View File

@ -0,0 +1,80 @@
from DatParser import DatParser
from DataProvider import DataProvider, average_spike_height
import os
import numpy as np
import matplotlib.pyplot as plt
from redetector import detect_spiketimes
DATA_FOLDER = "../neuronModel/data/final/"
failure_to_read_sam = ["2012-06-27-an-invivo-1", "2012-12-13-ag-invivo-1"]
def main():
for cell in sorted(os.listdir(DATA_FOLDER)):
if cell in failure_to_read_sam:
continue
cell_folder = os.path.join(DATA_FOLDER, cell)
data_provider = DataProvider(cell_folder)
repros = test_getting_repros(data_provider)
print("\n", cell)
for repro in repros:
if not repro in data_provider.parser.spike_files.keys():
continue
print(repro)
traces, spiketimes, rec_times = data_provider.get_traces_with_spiketimes(repro)
sampling_interval = data_provider.parser.get_sampling_interval()
for i in range(len(traces)):
time = np.arange(len(traces[i])) * sampling_interval - rec_times[0]
plt.figure(figsize=(10, 5))
plt.plot(time, traces[i])
plt.eventplot(spiketimes[i], lineoffsets=max(traces[i]) + 1, colors="black")
redetect = detect_spiketimes(time, traces[i])
plt.eventplot(redetect, lineoffsets=max(traces[i]) + 2, colors="red")
# plt.savefig("figures/best_spikes_test/" + cell + "_" + repro + str(i) + ".png")
plt.show()
plt.close()
def test_loading_spikes(data_provider: DataProvider, repro):
return data_provider.parser.get_spiketimes(repro)
def test_loading_traces(data_provider, repro):
return data_provider.get_traces(repro)
def test_getting_repros(data_provider: DataProvider):
return data_provider.get_repros()
# def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
# ash = np.zeros((len(traces), len(spiketimes)))
#
# for i, trace in enumerate(traces):
# for j, spikes in enumerate(spiketimes):
# if len(spikes) <= 1:
# ash[i, j] = -np.infty
# else:
# ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
#
# return ash
#
# def average_spike_height(spike_train, v1, sampling_rate, before):
# indices = np.array([(s - before) / sampling_rate for s in spike_train], dtype=np.int)
# v1 = np.array(v1)
# spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
# average_height = np.mean(spike_values)
#
# return average_height
if __name__ == '__main__':
main()