backend base with saving of sorting/traces/spiketimes for the repros
commit 38ba40ce2f (parent 92d14f189b)
.gitignore (vendored) | 1
@@ -1,6 +1,7 @@
 *.dat
 /data/
 /temp/
+/figures/
 # Latex output files
 *.out
 *.aux

DatParser.py | 81
@@ -6,7 +6,7 @@ import numpy as np
 import pyrelacs.DataLoader as Dl
 
 
-class DatParser():
+class DatParser:
 
     def __init__(self, dir_path):
         self.base_path = dir_path
@@ -15,18 +15,22 @@ class DatParser():
         self.baseline_file = self.base_path + "/basespikes1.dat"
         self.sam_file = self.base_path + "/samallspikes1.dat"
         self.stimuli_file = self.base_path + "/stimuli.dat"
         self.__test_data_file_existence__()
 
-        self.fi_recording_times = []
+        self.spike_files = {"BaselineActivity": self.baseline_file,
+                            "FICurve": self.fi_file,
+                            "FileStimulus": self.base_path + "/stimspikes1.dat",
+                            "SAM": self.sam_file}
         self.sampling_interval = -1
-
-    def has_sam_recordings(self):
-        return exists(self.sam_file)
+        self.fi_recording_times = []
+
+        self.spiketimes = {}
+        self.traces = {}
+        self.metadata = {}
 
     def get_measured_repros(self):
         repros = []
         for metadata, key, data in Dl.iload(self.stimuli_file):
             repros.extend([d["repro"] for d in metadata if "repro" in d.keys()])
             repros.extend([d["RePro"] for d in metadata if "RePro" in d.keys()])
 
         return sorted(np.unique(repros))
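
Note on this hunk: spike_files maps each repro name to the file its spike times live in, and spiketimes/traces/metadata are per-repro caches filled lazily by the new get_spiketimes()/get_traces() further down. A repro missing from the mapping makes get_spiketimes() raise NotImplementedError, so additional repros have to be registered first; a hedged sketch (the repro name and file name are invented examples, not part of the commit):

    parser = DatParser("../data/some_cell")                         # hypothetical path
    parser.spike_files["ReceptiveField"] = parser.base_path + "/rfspikes1.dat"
    spikes, metadata = parser.get_spiketimes("ReceptiveField")      # now resolvable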
@@ -110,15 +114,6 @@ class DatParser():
 
         return np.array(contrasts)
 
-    def traces_available(self) -> bool:
-        return True
-
-    def frequencies_available(self) -> bool:
-        return False
-
-    def spiketimes_available(self) -> bool:
-        return True
-
     def get_sampling_interval(self):
         if self.sampling_interval == -1:
             self.__read_sampling_interval__()
@@ -131,7 +126,7 @@ class DatParser():
         return self.fi_recording_times
 
     def get_baseline_traces(self):
-        return self.__get_traces__("BaselineActivity")
+        return self.get_traces("BaselineActivity")
 
     def get_baseline_spiketimes(self):
         # TODO change: reading from file -> detect from v1 trace
@@ -145,7 +140,7 @@ class DatParser():
         return spiketimes
 
     def get_fi_curve_traces(self):
-        return self.__get_traces__("FICurve")
+        return self.get_traces("FICurve")
 
     def get_fi_frequency_traces(self):
         raise NotImplementedError("Not possible in .dat data type.\n"
@@ -234,7 +229,7 @@ class DatParser():
         return trans_amplitudes, intensities, spiketimes
 
     def get_sam_traces(self):
-        return self.__get_traces__("SAM")
+        return self.get_traces("SAM")
 
     def get_sam_info(self):
         contrasts = []
@@ -258,7 +253,7 @@ class DatParser():
             eod_freq = float(metadata[0]["EOD rate"][:-2])  # in Hz
             trans_amplitude = metadata[0]["trans. amplitude"][:-2]  # in mV
 
-            duration = float(metadata[0]["duration"][:-2]) * factor  # normally saved in ms? so change it with the factor
+            duration = float(metadata[0]["duration"][:-2]) * factor  # normaly saved in ms? so change it with the factor
             contrast = float(metadata[0]["contrast"][:-1])  # in percent
             delta_f = float(metadata[0]["deltaf"][:-2])
         else:
@@ -267,7 +262,7 @@ class DatParser():
             eod_freq = float(metadata[0]["EOD rate"][:-2])  # in Hz
             trans_amplitude = metadata[0]["trans. amplitude"][:-2]  # in mV
 
-            duration = float(stimulus_dict["duration"][:-2]) * factor  # normally saved in ms? so change it with the factor
+            duration = float(stimulus_dict["duration"][:-2]) * factor  # normaly saved in ms? so change it with the factor
             contrast = float(stimulus_dict["contrast"][:-1])  # in percent
             delta_f = float(stimulus_dict["deltaf"][:-2])
 
@@ -295,7 +290,32 @@ class DatParser():
 
         return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes
 
-    def __get_traces__(self, repro):
+    def get_spiketimes(self, repro, spiketimes_file=""):
+        if repro in self.spiketimes.keys():
+            return self.spiketimes[repro], self.metadata[repro]
+
+        spiketimes = []
+        metadata_list = []
+        warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!")
+
+        if spiketimes_file == "":
+            if repro not in self.spike_files.keys():
+                raise NotImplementedError("No spiketime file specified in DatParser for repro {}! "
+                                          "Please specify spiketimes_file argument or add it to the spike_files dictionary in the Datparser class.".format(repro))
+            spiketimes_file = self.spike_files[repro]
+
+        for metadata, key, data in Dl.iload(spiketimes_file):
+            metadata_list.append(metadata)
+            spikes = np.array(data[:, 0]) / 1000  # timestamps are saved in ms -> conversion to seconds
+            spiketimes.append(spikes)
+
+        self.spiketimes[repro] = spiketimes
+        self.metadata[repro] = metadata_list
+        return spiketimes, metadata_list
+
+    def get_traces(self, repro, before=0, after=0):
+        if repro in self.traces.keys():
+            return self.traces[repro]
         time_traces = []
         v1_traces = []
         eod_traces = []
@@ -304,7 +324,7 @@ class DatParser():
 
         nothing = True
 
-        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
+        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro, before=before, after=after):
             nothing = False
             time_traces.append(time)
             v1_traces.append(x[0])
@@ -318,14 +338,9 @@ class DatParser():
             warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!"
             warn(warn_msg)
 
+        self.traces[repro] = traces
         return traces
 
-    def __iget_traces__(self, repro):
-        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
-            # time, v1, eod, local_eod, stimulus
-            yield time, x[0], x[1], x[2], x[3]
-
     def __read_fi_recording_times__(self):
 
         delays = []
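
Note on the new caching: both loaders memoize per repro, so only the first call touches the disk, but the cache key is the repro name alone; a later get_traces() call with different before/after padding silently receives the first call's result. Also note the asymmetry in return types: get_spiketimes() returns a (spiketimes, metadata) pair, while get_traces() returns only the [time, v1, eod, local_eod, stimulus] trace bundle. A hedged usage sketch (cell path and padding values are examples):

    parser = DatParser("../data/2010-11-08-al-invivo-1")
    spikes, metadata = parser.get_spiketimes("FICurve")              # tuple return
    time, v1, eod, local_eod, stimulus = parser.get_traces("FICurve", before=0.2)
    cached = parser.get_traces("FICurve")    # served from cache, before=0.2 baked in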
@@ -386,13 +401,3 @@ class DatParser():
                                "with File:" + self.base_path)
         else:
             self.sampling_interval = sampling_intervals[0]
 
-    def __test_data_file_existence__(self):
-        if not exists(self.stimuli_file):
-            raise FileNotFoundError(self.stimuli_file + " file doesn't exist!")
-        if not exists(self.fi_file):
-            raise FileNotFoundError(self.fi_file + " file doesn't exist!")
-        if not exists(self.baseline_file):
-            raise FileNotFoundError(self.baseline_file + " file doesn't exist!")
-        # if not exists(self.sam_file):
-        #     raise RuntimeError(self.sam_file + " file doesn't exist!")
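
Note: this hunk deletes __test_data_file_existence__(), but the unchanged __init__ above still calls it. Unless the method was moved to a part of the file this diff does not show, constructing a DatParser would now fail with an AttributeError.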
DataProvider.py | 126
@@ -1,20 +1,138 @@
 from DatParser import DatParser
 
+import numpy as np
+from warnings import warn
+import matplotlib.pyplot as plt
+
 
 class DataProvider:
 
-    def __init__(self, data_path):
+    def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
         self.data_path = data_path
         # self.cell = CellData(data_path)
         self.parser = DatParser(data_path)
+        self.repros = self.get_repros()
+
+        if len(repros_todo) != 0:
+            not_available = []
+            repros_todo = list(np.unique(repros_todo))
+            for repro in repros_todo:
+                if repro not in self.repros:
+                    not_available.append(repro)
+                    print("DataProvider: Given cell hasn't measured '{}'. This repro will be skipped!".format(repro))
+            for x in not_available:
+                repros_todo.remove(x)
+
+            self.repros_todo = repros_todo
+        else:
+            self.repros_todo = self.repros
+
+        self.base_Detection_parameters = base_detection_params
+        self.thresholds = {}
+
+        self.sorting = {}
+        self.recording_times = {}
 
     def get_repros(self):
-        pass
+        return self.parser.get_measured_repros()
+
+    def get_unsorted_spiketimes(self, repro, file=""):
+        return self.parser.get_spiketimes(repro, file)
+
+    def get_unsorted_traces(self, repro, before=0, after=0):
+        return self.parser.get_traces(repro, before, after)
+
+    def get_traces_with_spiketimes(self, repro):
+        if repro in self.sorting.keys():
+            traces = self.get_unsorted_traces(repro, self.recording_times[repro][0], self.recording_times[repro][1])
+            v1_traces = traces[1]
+            spiketimes = self.get_unsorted_spiketimes(repro)
+
+            sorted_spiketimes = np.array(spiketimes)[self.sorting[repro]]
+            return v1_traces, sorted_spiketimes, self.recording_times[repro]
+
+        if repro == "FICurve":
+            recording_times = self.parser.get_recording_times()
+            before = abs(recording_times[0])
+            after = recording_times[3]
+            [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.parser.get_traces(repro, before=before, after=after)
+            spiketimes = self.get_unsorted_spiketimes(repro)
+
+        else:
+            before = 0
+            after = 0
+            [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.get_unsorted_traces(repro)
+            spiketimes = self.get_unsorted_spiketimes(repro)
+
+        if len(v1_traces) != len(spiketimes):
+            warn("get_traces_with_spiketimes():Unequal number of traces and spiketimes for repro {}!"
+                 "Returning only according to the number of traces!".format(repro))
+
+        ash = calculate_distance_matrix_traces_spikes(v1_traces, spiketimes, self.parser.get_sampling_interval(), before)
+
+        sorted_spiketimes = []
+        sorting = []
+        for i in range(len(v1_traces)):
+            best = np.argmax(ash[i, :])
+            sorting.append(best)
+            sorted_spiketimes.append(spiketimes[best])
+        self.sorting[repro] = sorting
+        self.recording_times[repro] = (before, after)
+
+        return v1_traces, sorted_spiketimes, (before, after)
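
Note on get_traces_with_spiketimes(): every `spiketimes = self.get_unsorted_spiketimes(repro)` binds the whole (spiketimes, metadata) pair that DatParser.get_spiketimes() returns, so the subsequent len() comparison and indexing operate on a 2-tuple rather than on the spike trains. If the metadata is simply unused here, the calls presumably want to unpack:

    # hedged sketch of the presumably intended unpacking, not part of the commit
    spiketimes, _metadata = self.get_unsorted_spiketimes(repro)

The matching itself is a greedy per-trace argmax over the average-spike-height matrix, so two traces can end up assigned the same spike train; if a strict one-to-one assignment is wanted, something like scipy.optimize.linear_sum_assignment on -ash would be an alternative.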
 
     def get_stim_values(self, repro):
-        pass
+        if repro == "BaselineActivity":
+            return ["None"]
+        elif repro == "FICurve":
+            # TODO other function that just provides the contrasts
+            return sorted([str(x[0]) for x in self.parser.get_fi_curve_contrasts()])
+
+        return ["repro not supported"]
 
     def get_trials(self, repro, stimulus_value):
         pass
+
+
+def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
+    ash = np.zeros((len(traces), len(spiketimes)))
+    total = len(traces) * len(spiketimes)
+    count = 0
+    for i, trace in enumerate(traces):
+        for j, spikes in enumerate(spiketimes):
+            count += 1
+            if count % 50000 == 0:
+                print("{} / {}".format(count, total))
+            if len(spikes) <= 1:
+                ash[i, j] = -np.infty
+            else:
+                ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
+
+    return ash
+
+
+def average_spike_height(spike_train: np.ndarray, v1: np.ndarray, sampling_rate, before):
+    # indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=np.int)
+    indices = (spike_train + before) / sampling_rate
+    indices = np.array(indices, dtype=np.int)
+    if len(indices) <= 1:
+        return -np.infty
+    # [v1[i] for i in indices if 0 <= i < len(v1)]
+
+    applicable_indices = indices[(indices < len(v1)) & (indices > 0)]
+    spike_values = v1[applicable_indices]
+    average_height = np.mean(spike_values)
+
+    return average_height
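
Note: np.int and np.infty are deprecated aliases (np.int was removed in NumPy 1.24), and the parameter is named sampling_rate although the code divides by it, i.e. it actually holds the sampling interval (cf. get_sampling_interval()). A hedged, NumPy-2-safe sketch of the same computation:

    def average_spike_height(spike_train, v1, sampling_interval, before):
        # convert spike times (s) plus left padding into sample indices
        indices = ((spike_train + before) / sampling_interval).astype(int)
        if len(indices) <= 1:
            return -np.inf
        applicable = indices[(indices > 0) & (indices < len(v1))]
        return np.mean(v1[applicable])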
+
+
+# SLOW:
+# def average_spike_height(spike_train, v1, sampling_rate, before):
+#     indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=np.int)
+#     if len(indices) <= 1:
+#         return -np.infty
+#     v1 = np.array(v1)
+#     spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
+#     average_height = np.mean(spike_values)
+#
+#     return average_height

SpikeRedetectGui.py
@@ -1,18 +1,19 @@
 from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
     QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
-    QDoubleSpinBox, QComboBox
+    QDoubleSpinBox, QComboBox, QSpinBox
 from PyQt5.QtCore import pyqtSlot
 
 import numpy as np
 
 from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
 from matplotlib.figure import Figure
+from DataProvider import DataProvider
 
 
 class SpikeRedetectGui(QWidget):
 
-    def __init__(self, data_provider):
+    def __init__(self, data_provider: DataProvider):
         super().__init__()
         self.data_provider = data_provider
         self.title = 'Spike Redetection'
@@ -29,36 +30,80 @@ class SpikeRedetectGui(QWidget):
         # Middle:
         middle = QHBoxLayout()
 
-        # Canvas for matplotlib figure
-        m = PlotCanvas(self, width=5, height=4)
+        # Canvas Area for matplotlib figure
+
+        plot_area = QFrame()
+        plot_area_layout = QVBoxLayout()
+        m = PlotCanvas(self)
         m.move(0, 0)
-        middle.addWidget(m)
+        plot_area_layout.addWidget(m)
+
+        # plot area buttons
+        plot_area_buttons = QFrame()
+        plot_area_buttons_layout = QHBoxLayout()
+        plot_area_buttons.setLayout(plot_area_buttons_layout)
+        plot_area_layout.addWidget(plot_area_buttons)
+
+        button = QPushButton('Button1', self)
+        button.setToolTip('A nice button!')
+        button.clicked.connect(lambda: threshold_spinbox.setValue(1))
+        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
+        plot_area_buttons_layout.addWidget(button)
+
+        button = QPushButton('Button2', self)
+        button.setToolTip('Another nice button!')
+        button.clicked.connect(lambda: threshold_spinbox.setValue(2))
+        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
+        plot_area_buttons_layout.addWidget(button)
+
+        button = QPushButton('Button3', self)
+        button.setToolTip('Even more nice buttons!')
+        button.clicked.connect(lambda: threshold_spinbox.setValue(3))
+        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
+        plot_area_buttons_layout.addWidget(button)
+
+        button = QPushButton('Button4', self)
+        button.setToolTip('Even more nice buttons!')
+        button.clicked.connect(lambda: threshold_spinbox.setValue(4))
+        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
+        plot_area_buttons_layout.addWidget(button)
+
+        plot_area.setLayout(plot_area_layout)
+
+        middle.addWidget(plot_area)
         middle.addWidget(QVLine())
 
         # Side (options) panel
         panel = QFrame()
         panel.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
         panel.setMaximumWidth(200)
         panel_layout = QVBoxLayout()
 
-        button = QPushButton('Button!', self)
-        button.setToolTip('A nice button!')
-        button.clicked.connect(lambda: threshold_spinbox.setValue(1))
-        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
-        panel_layout.addWidget(button)
-
-        stim_val_label = QLabel("Stimulus value:")
-        self.stim_val_box = QComboBox()
 
         repro_label = QLabel("Repro:")
         panel_layout.addWidget(repro_label)
 
         self.repro_box = QComboBox()
-        self.repro_box.addItem("placeholder repro")
-        panel_layout.addWidget(self.repro_box)
+        self.repro_box.currentTextChanged.connect(self.repro_change)
+        for repro in self.data_provider.get_repros():
+            self.repro_box.addItem(repro)
+
+        stim_val_label = QLabel("Stimulus value:")
+        panel_layout.addWidget(self.repro_box)
         panel_layout.addWidget(stim_val_label)
+        self.stim_val_box = QComboBox()
         self.stim_val_box.addItem("placeholder stim_value")
         panel_layout.addWidget(self.stim_val_box)
 
-        filler = QFill(maxh=20)
+        trial_label = QLabel("Trial:")
+        panel_layout.addWidget(trial_label)
+        threshold_spinbox = QSpinBox(self)
+        threshold_spinbox.setValue(1)
+        threshold_spinbox.setSingleStep(1)
+        threshold_spinbox.valueChanged.connect()
+        panel_layout.addWidget(threshold_spinbox)
+
+        filler = QFill(minh=200)
         panel_layout.addWidget(filler)
 
         self.status_label = QLabel("Done x/15 Stimulus Values")
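
Note: threshold_spinbox.valueChanged.connect() is called without a slot; PyQt5 raises a TypeError for that as soon as the panel is built, so the GUI as committed cannot start. Presumably a handler is still to be written; a minimal hedged placeholder (the slot name is invented, not part of the commit):

    threshold_spinbox.valueChanged.connect(self.threshold_changed)

    @pyqtSlot(int)
    def threshold_changed(self, value):
        print("threshold changed to", value)   # stub until redetection is wired up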
@@ -78,18 +123,25 @@ class SpikeRedetectGui(QWidget):
         button.setToolTip('Accept the threshold for current stimulus value')
         panel_layout.addWidget(button)
 
-
         panel.setLayout(panel_layout)
         middle.addWidget(panel)
 
         self.setLayout(middle)
         self.show()
 
+    @pyqtSlot()
+    def repro_change(self):
+        repro = self.repro_box.currentText()
+        self.stim_val_box.clear()
+
+        for val in self.data_provider.get_stim_values(repro):
+            self.stim_val_box.addItem(str(val))
+
 
 class PlotCanvas(FigureCanvas):
 
-    def __init__(self, parent=None, width=5, height=4, dpi=100):
-        fig = Figure(figsize=(width, height), dpi=dpi)
+    def __init__(self, parent=None, dpi=100):
+        fig = Figure(dpi=dpi)
         self.axes = fig.add_subplot(111)
 
         FigureCanvas.__init__(self, fig)
@@ -107,7 +159,7 @@ class PlotCanvas(FigureCanvas):
         data = np.sin(x*np.pi*2*mean)
         ax = self.axes
         ax.clear()
-        ax.plot(data, 'r-')
+        ax.plot(x, data, 'r-')
         ax.set_title('Sinus Example')
         self.draw()
@@ -127,7 +179,8 @@ class QVLine(QFrame):
 
 
 class QFill(QFrame):
-    def __init__(self, maxh=int(2**24)-1, maxw=int(2**24)-1):
+    def __init__(self, maxw=int(2**24)-1, maxh=int(2**24)-1, minw=0, minh=0):
         super(QFill, self).__init__()
         self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
         self.setMaximumSize(maxw, maxh)
+        self.setMinimumSize(minw, minh)

main.py | 8
@@ -1,14 +1,14 @@
 import sys
-from PyQt5.QtWidgets import QApplication, QWidget, QPushButton
+from PyQt5.QtWidgets import QApplication
 
-from spike_redetection.DataProvider import DataProvider
-from spike_redetection.SpikeRedetectGui import SpikeRedetectGui
+from DataProvider import DataProvider
+from SpikeRedetectGui import SpikeRedetectGui
 
 
 def main():
     app = QApplication(sys.argv)
-    data_provider = DataProvider("../data/final_sam/2010-11-08-al-invivo-1")
+    data_provider = DataProvider("../neuronModel/data/final_sam/2010-11-08-al-invivo-1")
     ex = SpikeRedetectGui(data_provider)
     sys.exit(app.exec_())

redetector.py (new file) | 71
@@ -0,0 +1,71 @@
+import numpy as np
+
+from thunderfish.eventdetection import detect_peaks
+
+
+def detect_spiketimes(time: np.ndarray, v1, threshold=2.0, min_length=5000, split_step=1000):
+    all_peak_indicies = detect_spike_indices_automatic_split(v1, threshold, min_length, split_step)
+
+    return time[all_peak_indicies]
+
+
+def detect_spike_indices_automatic_split(v1, threshold, min_length=5000, split_step=1000):
+    split_start = 0
+    step_size = split_step
+    break_threshold = 0.25
+    splits = []
+
+    if len(v1) < min_length:
+        splits = [(0, len(v1))]
+    else:
+        last_max = max(v1[0:min_length])
+        last_min = min(v1[0:min_length])
+        idx = min_length
+
+        while idx < len(v1):
+            if idx + step_size > len(v1):
+                splits.append((split_start, len(v1)))
+                break
+
+            max_similar = abs((max(v1[idx:idx+step_size]) - last_max) / last_max) < break_threshold
+            min_similar = abs((min(v1[idx:idx+step_size]) - last_min) / last_min) < break_threshold
+
+            if not max_similar or not min_similar:
+                # print("new split")
+                end_idx = np.argmin(v1[idx-20:idx+21]) - 20
+                splits.append((split_start, idx+end_idx))
+                split_start = idx+end_idx
+                last_max = max(v1[split_start:split_start + min_length])
+                last_min = min(v1[split_start:split_start + min_length])
+                idx = split_start + min_length
+                continue
+            else:
+                pass
+                # print("elongated!")
+
+            idx += step_size
+
+    if splits[-1][1] != len(v1):
+        splits.append((split_start, len(v1)))
+
+    # plt.plot(v1)
+    # for s in splits:
+    #     plt.plot(s, (max(v1[s[0]:s[1]]), max(v1[s[0]:s[1]])))
+
+    all_peaks = []
+    for s in splits:
+        first_index = s[0]
+        last_index = s[1]
+        std = np.std(v1[first_index:last_index])
+        peaks, _ = detect_peaks(v1[first_index:last_index], std * threshold)
+        peaks = peaks + first_index
+        # plt.plot(peaks, [np.max(v1[first_index:last_index]) for _ in peaks], 'o')
+        all_peaks.extend(peaks)
+
+    # plt.show()
+    # plt.close()
+    # all_peaks = np.array(all_peaks)
+
+    return all_peaks
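
Note on redetector.py: detect_spike_indices_automatic_split() walks the trace in split_step-sample windows and opens a new segment whenever a window's max or min drifts more than 25 % away from the extrema of the current segment's first min_length samples; detect_peaks (thunderfish) is then run per segment with `threshold` times that segment's standard deviation, which makes the detection robust against baseline and amplitude drift. min_length and split_step are in samples, not seconds. A hedged usage sketch on synthetic data (all values are invented for illustration):

    import numpy as np
    from redetector import detect_spiketimes

    dt = 1.0 / 20000.0                     # assumed 20 kHz sampling
    time = np.arange(0.0, 2.0, dt)
    v1 = np.random.normal(0.0, 0.1, len(time))
    v1[::2000] += 1.5                      # fake spikes every 0.1 s
    v1[len(v1) // 2:] += 2.0               # baseline jump that should force a split
    spike_times = detect_spiketimes(time, v1)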
testing.py (new file) | 80
@@ -0,0 +1,80 @@
+from DatParser import DatParser
+from DataProvider import DataProvider, average_spike_height
+
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+
+from redetector import detect_spiketimes
+
+DATA_FOLDER = "../neuronModel/data/final/"
+
+failure_to_read_sam = ["2012-06-27-an-invivo-1", "2012-12-13-ag-invivo-1"]
+
+
+def main():
+    for cell in sorted(os.listdir(DATA_FOLDER)):
+        if cell in failure_to_read_sam:
+            continue
+        cell_folder = os.path.join(DATA_FOLDER, cell)
+        data_provider = DataProvider(cell_folder)
+
+        repros = test_getting_repros(data_provider)
+
+        print("\n", cell)
+        for repro in repros:
+            if not repro in data_provider.parser.spike_files.keys():
+                continue
+
+            print(repro)
+            traces, spiketimes, rec_times = data_provider.get_traces_with_spiketimes(repro)
+            sampling_interval = data_provider.parser.get_sampling_interval()
+            for i in range(len(traces)):
+                time = np.arange(len(traces[i])) * sampling_interval - rec_times[0]
+                plt.figure(figsize=(10, 5))
+                plt.plot(time, traces[i])
+                plt.eventplot(spiketimes[i], lineoffsets=max(traces[i]) + 1, colors="black")
+                redetect = detect_spiketimes(time, traces[i])
+                plt.eventplot(redetect, lineoffsets=max(traces[i]) + 2, colors="red")
+                # plt.savefig("figures/best_spikes_test/" + cell + "_" + repro + str(i) + ".png")
+                plt.show()
+                plt.close()
+
+
+def test_loading_spikes(data_provider: DataProvider, repro):
+    return data_provider.parser.get_spiketimes(repro)
+
+
+def test_loading_traces(data_provider, repro):
+    return data_provider.get_traces(repro)
+
+
+def test_getting_repros(data_provider: DataProvider):
+    return data_provider.get_repros()
+
+
+# def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
+#     ash = np.zeros((len(traces), len(spiketimes)))
+#
+#     for i, trace in enumerate(traces):
+#         for j, spikes in enumerate(spiketimes):
+#             if len(spikes) <= 1:
+#                 ash[i, j] = -np.infty
+#             else:
+#                 ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
+#
+#     return ash
+#
+
+# def average_spike_height(spike_train, v1, sampling_rate, before):
+#     indices = np.array([(s - before) / sampling_rate for s in spike_train], dtype=np.int)
+#     v1 = np.array(v1)
+#     spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
+#     average_height = np.mean(spike_values)
+#
+#     return average_height
+
+
+if __name__ == '__main__':
+    main()
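
Note: test_loading_traces() calls data_provider.get_traces(), but DataProvider only defines get_unsorted_traces() (get_traces() lives on DatParser), so this helper raises an AttributeError as committed. Presumably it should read:

    def test_loading_traces(data_provider, repro):
        # hedged one-line fix, assuming the unsorted traces were intended
        return data_provider.get_unsorted_traces(repro)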