backend base with saving of sorting/traces/spiketimes for the repros

This commit is contained in:
alexanderott 2021-02-19 11:40:19 +01:00
parent 92d14f189b
commit 38ba40ce2f
7 changed files with 396 additions and 68 deletions

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
*.dat
/data/
/temp/
/figures/
# Latex output files
*.out
*.aux

View File

@ -6,7 +6,7 @@ import numpy as np
import pyrelacs.DataLoader as Dl
class DatParser():
class DatParser:
def __init__(self, dir_path):
self.base_path = dir_path
@ -15,18 +15,22 @@ class DatParser():
self.baseline_file = self.base_path + "/basespikes1.dat"
self.sam_file = self.base_path + "/samallspikes1.dat"
self.stimuli_file = self.base_path + "/stimuli.dat"
self.__test_data_file_existence__()
self.fi_recording_times = []
self.spike_files = {"BaselineActivity": self.baseline_file,
"FICurve": self.fi_file,
"FileStimulus": self.base_path + "/stimspikes1.dat",
"SAM": self.sam_file}
self.sampling_interval = -1
self.fi_recording_times = []
def has_sam_recordings(self):
return exists(self.sam_file)
self.spiketimes = {}
self.traces = {}
self.metadata = {}
def get_measured_repros(self):
repros = []
for metadata, key, data in Dl.iload(self.stimuli_file):
repros.extend([d["repro"] for d in metadata if "repro" in d.keys()])
repros.extend([d["RePro"] for d in metadata if "RePro" in d.keys()])
return sorted(np.unique(repros))
@ -110,15 +114,6 @@ class DatParser():
return np.array(contrasts)
def traces_available(self) -> bool:
return True
def frequencies_available(self) -> bool:
return False
def spiketimes_available(self) -> bool:
return True
def get_sampling_interval(self):
if self.sampling_interval == -1:
self.__read_sampling_interval__()
@ -131,7 +126,7 @@ class DatParser():
return self.fi_recording_times
def get_baseline_traces(self):
return self.__get_traces__("BaselineActivity")
return self.get_traces("BaselineActivity")
def get_baseline_spiketimes(self):
# TODO change: reading from file -> detect from v1 trace
@ -145,7 +140,7 @@ class DatParser():
return spiketimes
def get_fi_curve_traces(self):
return self.__get_traces__("FICurve")
return self.get_traces("FICurve")
def get_fi_frequency_traces(self):
raise NotImplementedError("Not possible in .dat data type.\n"
@ -234,7 +229,7 @@ class DatParser():
return trans_amplitudes, intensities, spiketimes
def get_sam_traces(self):
return self.__get_traces__("SAM")
return self.get_traces("SAM")
def get_sam_info(self):
contrasts = []
@ -258,7 +253,7 @@ class DatParser():
eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz
trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV
duration = float(metadata[0]["duration"][:-2]) * factor # normally saved in ms? so change it with the factor
duration = float(metadata[0]["duration"][:-2]) * factor # normaly saved in ms? so change it with the factor
contrast = float(metadata[0]["contrast"][:-1]) # in percent
delta_f = float(metadata[0]["deltaf"][:-2])
else:
@ -267,7 +262,7 @@ class DatParser():
eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz
trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV
duration = float(stimulus_dict["duration"][:-2]) * factor # normally saved in ms? so change it with the factor
duration = float(stimulus_dict["duration"][:-2]) * factor # normaly saved in ms? so change it with the factor
contrast = float(stimulus_dict["contrast"][:-1]) # in percent
delta_f = float(stimulus_dict["deltaf"][:-2])
@ -295,7 +290,32 @@ class DatParser():
return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes
def __get_traces__(self, repro):
def get_spiketimes(self, repro, spiketimes_file=""):
if repro in self.spiketimes.keys():
return self.spiketimes[repro], self.metadata[repro]
spiketimes = []
metadata_list = []
warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!")
if spiketimes_file == "":
if repro not in self.spike_files.keys():
raise NotImplementedError("No spiketime file specified in DatParser for repro {}! "
"Please specify spiketimes_file argument or add it to the spike_files dictionary in the Datparser class.".format(repro))
spiketimes_file = self.spike_files[repro]
for metadata, key, data in Dl.iload(spiketimes_file):
metadata_list.append(metadata)
spikes = np.array(data[:, 0]) / 1000 # timestamps are saved in ms -> conversion to seconds
spiketimes.append(spikes)
self.spiketimes[repro] = spiketimes
self.metadata[repro] = metadata_list
return spiketimes, metadata_list
def get_traces(self, repro, before=0, after=0):
if repro in self.traces.keys():
return self.traces[repro]
time_traces = []
v1_traces = []
eod_traces = []
@ -304,7 +324,7 @@ class DatParser():
nothing = True
for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro, before=before, after=after):
nothing = False
time_traces.append(time)
v1_traces.append(x[0])
@ -318,14 +338,9 @@ class DatParser():
warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!"
warn(warn_msg)
self.traces[repro] = traces
return traces
def __iget_traces__(self, repro):
for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
# time, v1, eod, local_eod, stimulus
yield time, x[0], x[1], x[2], x[3]
def __read_fi_recording_times__(self):
delays = []
@ -386,13 +401,3 @@ class DatParser():
"with File:" + self.base_path)
else:
self.sampling_interval = sampling_intervals[0]
def __test_data_file_existence__(self):
if not exists(self.stimuli_file):
raise FileNotFoundError(self.stimuli_file + " file doesn't exist!")
if not exists(self.fi_file):
raise FileNotFoundError(self.fi_file + " file doesn't exist!")
if not exists(self.baseline_file):
raise FileNotFoundError(self.baseline_file + " file doesn't exist!")
# if not exists(self.sam_file):
# raise RuntimeError(self.sam_file + " file doesn't exist!")

View File

@ -1,20 +1,138 @@
from DatParser import DatParser
import numpy as np
from warnings import warn
import matplotlib.pyplot as plt
class DataProvider:
    """Access layer on top of a ``DatParser`` for one recorded cell.

    Provides the measured repros, the raw traces and spike times of a repro,
    and a pairing of every V1 trace with the spike train that fits it best.
    The pairing (``self.sorting``) and the time window used for the traces
    (``self.recording_times``) are stored per repro so they are computed once.
    """

    def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
        """Open the cell at ``data_path`` and select the repros to work on.

        :param data_path: folder of the cell, handed to ``DatParser``.
        :param repros_todo: repros to process; names the cell has not
            measured are reported and dropped. Empty means all measured repros.
        :param base_detection_params: stored unchanged in
            ``self.base_Detection_parameters``; not used inside this class
            (presumably threshold, min_length, split_step of the spike
            detection - TODO confirm).
        """
        self.data_path = data_path
        self.parser = DatParser(data_path)
        self.repros = self.get_repros()

        if len(repros_todo) != 0:
            available = []
            for repro in list(np.unique(repros_todo)):
                if repro in self.repros:
                    available.append(repro)
                else:
                    print("DataProvider: Given cell hasn't measured '{}'. This repro will be skipped!".format(repro))
            self.repros_todo = available
        else:
            self.repros_todo = self.repros

        self.base_Detection_parameters = base_detection_params
        self.thresholds = {}
        self.sorting = {}  # repro -> index of the matching spike train for every trace
        self.recording_times = {}  # repro -> (before, after) window used when loading traces

    def get_repros(self):
        """Return the repros the parser reports as measured for this cell."""
        return self.parser.get_measured_repros()

    def get_unsorted_spiketimes(self, repro, file=""):
        """Return the parser's ``(spiketimes, metadata)`` tuple for ``repro``."""
        return self.parser.get_spiketimes(repro, file)

    def get_unsorted_traces(self, repro, before=0, after=0):
        """Return the parser's traces for ``repro`` with the given time window."""
        return self.parser.get_traces(repro, before, after)

    def get_traces_with_spiketimes(self, repro):
        """Return the V1 traces of ``repro`` together with their spike trains.

        The spike trains are not stored in the same order as the traces, so
        each trace is paired with the spike train that has the highest average
        trace value at its spike positions. The resulting order is cached.

        :param repro: name of the repro.
        :return: ``(v1_traces, sorted_spiketimes, (before, after))`` where
            ``sorted_spiketimes`` is a list with one spike train per trace.
        """
        if repro in self.sorting:
            before, after = self.recording_times[repro]
            v1_traces = self.get_unsorted_traces(repro, before, after)[1]
            # get_unsorted_spiketimes returns (spiketimes, metadata); only the
            # spike trains are needed here.
            spiketimes, _ = self.get_unsorted_spiketimes(repro)
            # Spike trains differ in length, so index the list directly
            # instead of converting it to a (ragged) numpy array.
            sorted_spiketimes = [spiketimes[idx] for idx in self.sorting[repro]]
            return v1_traces, sorted_spiketimes, self.recording_times[repro]

        if repro == "FICurve":
            recording_times = self.parser.get_recording_times()
            before = abs(recording_times[0])
            after = recording_times[3]
        else:
            before = 0
            after = 0

        # traces are ordered: time, v1, eod, local_eod, stimulus
        v1_traces = self.get_unsorted_traces(repro, before, after)[1]
        spiketimes, _ = self.get_unsorted_spiketimes(repro)

        if len(v1_traces) != len(spiketimes):
            warn("get_traces_with_spiketimes():Unequal number of traces and spiketimes for repro {}!"
                 "Returning only according to the number of traces!".format(repro))

        ash = calculate_distance_matrix_traces_spikes(v1_traces, spiketimes,
                                                      self.parser.get_sampling_interval(), before)

        sorting = []
        sorted_spiketimes = []
        for i in range(len(v1_traces)):
            best = np.argmax(ash[i, :])
            sorting.append(best)
            sorted_spiketimes.append(spiketimes[best])

        self.sorting[repro] = sorting
        self.recording_times[repro] = (before, after)

        return v1_traces, sorted_spiketimes, (before, after)

    def get_stim_values(self, repro):
        """Return the stimulus values of ``repro`` as a list of strings."""
        if repro == "BaselineActivity":
            return ["None"]
        elif repro == "FICurve":
            # TODO other function that just provides the contrasts
            return sorted([str(x[0]) for x in self.parser.get_fi_curve_contrasts()])

        return ["repro not supported"]

    def get_trials(self, repro, stimulus_value):
        """Not implemented yet; currently returns ``None``."""
        pass
def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
    """Score how well every spike train fits every trace.

    Entry ``[i, j]`` is the average value of trace ``i`` at the spike
    positions of spike train ``j`` (see ``average_spike_height``); a higher
    value means a better fit. Spike trains with at most one spike get ``-inf``
    so they are never selected as best match.

    :param traces: sequence of voltage traces.
    :param spiketimes: sequence of spike trains (spike times per trial).
    :param sampling_rate: time between two samples of a trace; despite the
        name it is used as a divisor that converts times to sample indices.
    :param before: time offset added to the spike times before conversion.
    :return: array of shape ``(len(traces), len(spiketimes))``.
    """
    ash = np.zeros((len(traces), len(spiketimes)))
    total = len(traces) * len(spiketimes)
    count = 0
    for i, trace in enumerate(traces):
        for j, spikes in enumerate(spiketimes):
            count += 1
            # progress output for large numbers of combinations
            if count % 50000 == 0:
                print("{} / {}".format(count, total))

            if len(spikes) <= 1:
                # np.inf instead of the alias np.infty, which NumPy 2.0 removed
                ash[i, j] = -np.inf
            else:
                ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)

    return ash
def average_spike_height(spike_train: np.ndarray, v1: np.ndarray, sampling_rate, before):
    """Return the mean value of ``v1`` at the positions of the given spikes.

    Spike times are shifted by ``before`` and divided by ``sampling_rate`` to
    obtain sample indices; indices outside ``0 <= i < len(v1)`` are ignored.

    :param spike_train: spike times of one trial.
    :param v1: voltage trace.
    :param sampling_rate: time between two samples of ``v1`` (used as divisor).
    :param before: time offset added to every spike time.
    :return: mean trace value at the spike positions, or ``-inf`` if the train
        has at most one spike or no spike falls inside the trace.
    """
    if len(spike_train) <= 1:
        return -np.inf

    v1 = np.asarray(v1)
    # floor (not truncation) so that spikes slightly before the trace start
    # map to a negative index and are excluded instead of landing on sample 0
    indices = np.floor((np.asarray(spike_train) + before) / sampling_rate).astype(int)

    applicable_indices = indices[(indices >= 0) & (indices < len(v1))]
    if len(applicable_indices) == 0:
        # avoid the NaN of an empty mean, which argmax would treat as maximum
        return -np.inf

    return np.mean(v1[applicable_indices])
# SLOW:
# def average_spike_height(spike_train, v1, sampling_rate, before):
# indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=np.int)
# if len(indices) <= 1:
# return -np.infty
# v1 = np.array(v1)
# spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
# average_height = np.mean(spike_values)
#
# return average_height

View File

@ -1,18 +1,19 @@
from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
QDoubleSpinBox, QComboBox
QDoubleSpinBox, QComboBox, QSpinBox
from PyQt5.QtCore import pyqtSlot
import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from DataProvider import DataProvider
class SpikeRedetectGui(QWidget):
def __init__(self, data_provider):
def __init__(self, data_provider: DataProvider):
super().__init__()
self.data_provider = data_provider
self.title = 'Spike Redetection'
@ -29,36 +30,80 @@ class SpikeRedetectGui(QWidget):
# Middle:
middle = QHBoxLayout()
# Canvas for matplotlib figure
m = PlotCanvas(self, width=5, height=4)
# Canvas Area for matplotlib figure
plot_area = QFrame()
plot_area_layout = QVBoxLayout()
m = PlotCanvas(self)
m.move(0, 0)
middle.addWidget(m)
plot_area_layout.addWidget(m)
# plot area buttons
plot_area_buttons = QFrame()
plot_area_buttons_layout = QHBoxLayout()
plot_area_buttons.setLayout(plot_area_buttons_layout)
plot_area_layout.addWidget(plot_area_buttons)
button = QPushButton('Button1', self)
button.setToolTip('A nice button!')
button.clicked.connect(lambda: threshold_spinbox.setValue(1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button2', self)
button.setToolTip('Another nice button!')
button.clicked.connect(lambda: threshold_spinbox.setValue(2))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button3', self)
button.setToolTip('Even more nice buttons!')
button.clicked.connect(lambda: threshold_spinbox.setValue(3))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button4', self)
button.setToolTip('Even more nice buttons!')
button.clicked.connect(lambda: threshold_spinbox.setValue(4))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
plot_area.setLayout(plot_area_layout)
middle.addWidget(plot_area)
middle.addWidget(QVLine())
# Side (options) panel
panel = QFrame()
panel.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
panel.setMaximumWidth(200)
panel_layout = QVBoxLayout()
button = QPushButton('Button!', self)
button.setToolTip('A nice button!')
button.clicked.connect(lambda: threshold_spinbox.setValue(1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
panel_layout.addWidget(button)
stim_val_label = QLabel("Stimulus value:")
self.stim_val_box = QComboBox()
repro_label = QLabel("Repro:")
panel_layout.addWidget(repro_label)
self.repro_box = QComboBox()
self.repro_box.addItem("placeholder repro")
panel_layout.addWidget(self.repro_box)
self.repro_box.currentTextChanged.connect(self.repro_change)
for repro in self.data_provider.get_repros():
self.repro_box.addItem(repro)
stim_val_label = QLabel("Stimulus value:")
panel_layout.addWidget(self.repro_box)
panel_layout.addWidget(stim_val_label)
self.stim_val_box = QComboBox()
self.stim_val_box.addItem("placeholder stim_value")
panel_layout.addWidget(self.stim_val_box)
filler = QFill(maxh=20)
trial_label = QLabel("Trial:")
panel_layout.addWidget(trial_label)
threshold_spinbox = QSpinBox(self)
threshold_spinbox.setValue(1)
threshold_spinbox.setSingleStep(1)
threshold_spinbox.valueChanged.connect()
panel_layout.addWidget(threshold_spinbox)
filler = QFill(minh=200)
panel_layout.addWidget(filler)
self.status_label = QLabel("Done x/15 Stimulus Values")
@ -78,18 +123,25 @@ class SpikeRedetectGui(QWidget):
button.setToolTip('Accept the threshold for current stimulus value')
panel_layout.addWidget(button)
panel.setLayout(panel_layout)
middle.addWidget(panel)
self.setLayout(middle)
self.show()
@pyqtSlot()
def repro_change(self):
    """Refill the stimulus-value box with the values of the selected repro."""
    selected_repro = self.repro_box.currentText()
    stim_values = self.data_provider.get_stim_values(selected_repro)

    self.stim_val_box.clear()
    for stim_value in stim_values:
        self.stim_val_box.addItem(str(stim_value))
class PlotCanvas(FigureCanvas):
def __init__(self, parent=None, width=5, height=4, dpi=100):
fig = Figure(figsize=(width, height), dpi=dpi)
def __init__(self, parent=None, dpi=100):
fig = Figure(dpi=dpi)
self.axes = fig.add_subplot(111)
FigureCanvas.__init__(self, fig)
@ -107,7 +159,7 @@ class PlotCanvas(FigureCanvas):
data = np.sin(x*np.pi*2*mean)
ax = self.axes
ax.clear()
ax.plot(data, 'r-')
ax.plot(x, data, 'r-')
ax.set_title('Sinus Example')
self.draw()
@ -127,7 +179,8 @@ class QVLine(QFrame):
class QFill(QFrame):
def __init__(self, maxh=int(2**24)-1, maxw=int(2**24)-1):
def __init__(self, maxw=int(2**24)-1, maxh=int(2**24)-1, minw=0, minh=0):
super(QFill, self).__init__()
self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
self.setMaximumSize(maxw, maxh)
self.setMinimumSize(minw, minh)

View File

@ -1,14 +1,14 @@
import sys
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton
from PyQt5.QtWidgets import QApplication
from spike_redetection.DataProvider import DataProvider
from spike_redetection.SpikeRedetectGui import SpikeRedetectGui
from DataProvider import DataProvider
from SpikeRedetectGui import SpikeRedetectGui
def main():
    """Start the Qt application with the spike redetection GUI for one cell."""
    app = QApplication(sys.argv)

    data_provider = DataProvider("../neuronModel/data/final_sam/2010-11-08-al-invivo-1")

    # keep a reference to the window so it is not garbage collected
    ex = SpikeRedetectGui(data_provider)
    sys.exit(app.exec_())

71
redetector.py Normal file
View File

@ -0,0 +1,71 @@
import numpy as np
from thunderfish.eventdetection import detect_peaks
def detect_spiketimes(time: np.ndarray, v1, threshold=2.0, min_length=5000, split_step=1000):
    """Return the entries of ``time`` at which spikes are detected in ``v1``.

    The detection itself is done by ``detect_spike_indices_automatic_split``;
    the indices it returns are used to index ``time``.
    """
    spike_indices = detect_spike_indices_automatic_split(v1, threshold, min_length, split_step)
    spike_times = time[spike_indices]
    return spike_times
def detect_spike_indices_automatic_split(v1, threshold, min_length=5000, split_step=1000):
    """Detect spike indices in ``v1`` with a threshold adapted per segment.

    The trace is cut into segments: the first ``min_length`` samples of a
    segment define its reference maximum and minimum, and the segment is
    extended in windows of ``split_step`` samples as long as the maximum and
    minimum of each window stay within 25 % of the references. Otherwise a new
    segment starts at the lowest sample within 20 samples around the window
    start. Peaks are then detected in every segment with a threshold of
    ``threshold`` times the segment's standard deviation.

    :param v1: voltage trace.
    :param threshold: factor applied to the standard deviation of a segment.
    :param min_length: minimal number of samples of a segment.
    :param split_step: number of samples by which a segment is extended.
    :return: list of the indices of all detected peaks in ``v1``.
    """
    step_size = split_step
    break_threshold = 0.25  # relative change of max/min that starts a new segment
    trace_length = len(v1)

    split_start = 0
    splits = []

    if trace_length < min_length:
        splits = [(0, trace_length)]
    else:
        last_max = max(v1[0:min_length])
        last_min = min(v1[0:min_length])
        idx = min_length

        while idx < trace_length:
            if idx + step_size > trace_length:
                # less than one full window left: close the current segment
                splits.append((split_start, trace_length))
                break

            max_similar = abs((max(v1[idx:idx + step_size]) - last_max) / last_max) < break_threshold
            min_similar = abs((min(v1[idx:idx + step_size]) - last_min) / last_min) < break_threshold

            if not max_similar or not min_similar:
                # cut at the lowest sample near idx (offset relative to idx)
                end_idx = np.argmin(v1[idx - 20:idx + 21]) - 20
                splits.append((split_start, idx + end_idx))
                split_start = idx + end_idx
                last_max = max(v1[split_start:split_start + min_length])
                last_min = min(v1[split_start:split_start + min_length])
                idx = split_start + min_length
                continue

            idx += step_size

        # The scan can end without any recorded split (e.g. the trace length
        # is exactly min_length plus a multiple of split_step), so guard the
        # access to the last split before closing the final segment.
        if not splits or splits[-1][1] != trace_length:
            splits.append((split_start, trace_length))

    all_peaks = []
    for first_index, last_index in splits:
        segment = v1[first_index:last_index]
        std = np.std(segment)
        peaks, _ = detect_peaks(segment, std * threshold)
        # peak indices are relative to the segment -> shift to trace indices
        all_peaks.extend(peaks + first_index)

    return all_peaks

80
testing.py Normal file
View File

@ -0,0 +1,80 @@
from DatParser import DatParser
from DataProvider import DataProvider, average_spike_height
import os
import numpy as np
import matplotlib.pyplot as plt
from redetector import detect_spiketimes
DATA_FOLDER = "../neuronModel/data/final/"
failure_to_read_sam = ["2012-06-27-an-invivo-1", "2012-12-13-ag-invivo-1"]
def main():
    """Plot every trace of every cell with its stored and redetected spikes.

    Cells listed in ``failure_to_read_sam`` and repros without a registered
    spike file are skipped. One figure is shown per trace: the trace itself,
    the stored spike times (black) and the redetected spike times (red).
    """
    for cell in sorted(os.listdir(DATA_FOLDER)):
        if cell in failure_to_read_sam:
            continue

        data_provider = DataProvider(os.path.join(DATA_FOLDER, cell))
        repros = test_getting_repros(data_provider)
        print("\n", cell)

        for repro in repros:
            if repro not in data_provider.parser.spike_files:
                continue
            print(repro)

            traces, spiketimes, rec_times = data_provider.get_traces_with_spiketimes(repro)
            sampling_interval = data_provider.parser.get_sampling_interval()

            for trial, trace in enumerate(traces):
                # time axis starts at -before so spike times and trace align
                time = np.arange(len(trace)) * sampling_interval - rec_times[0]
                trace_top = max(trace)

                plt.figure(figsize=(10, 5))
                plt.plot(time, trace)
                plt.eventplot(spiketimes[trial], lineoffsets=trace_top + 1, colors="black")

                redetect = detect_spiketimes(time, trace)
                plt.eventplot(redetect, lineoffsets=trace_top + 2, colors="red")

                # plt.savefig("figures/best_spikes_test/" + cell + "_" + repro + str(i) + ".png")
                plt.show()
                plt.close()
def test_loading_spikes(data_provider: DataProvider, repro):
    """Return what the provider's parser loads as spike times for ``repro``."""
    parser = data_provider.parser
    return parser.get_spiketimes(repro)
def test_loading_traces(data_provider, repro):
    """Return the unsorted traces the provider loads for ``repro``.

    ``DataProvider`` exposes its traces through ``get_unsorted_traces``; it
    has no ``get_traces`` method.
    """
    return data_provider.get_unsorted_traces(repro)
def test_getting_repros(data_provider: DataProvider):
    """Return the repros the given provider reports for its cell."""
    measured_repros = data_provider.get_repros()
    return measured_repros
# def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
# ash = np.zeros((len(traces), len(spiketimes)))
#
# for i, trace in enumerate(traces):
# for j, spikes in enumerate(spiketimes):
# if len(spikes) <= 1:
# ash[i, j] = -np.infty
# else:
# ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
#
# return ash
#
# def average_spike_height(spike_train, v1, sampling_rate, before):
# indices = np.array([(s - before) / sampling_rate for s in spike_train], dtype=np.int)
# v1 = np.array(v1)
# spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
# average_height = np.mean(spike_values)
#
# return average_height
if __name__ == '__main__':
main()