general working state before stim value grouping

This commit is contained in:
alexanderott 2021-03-07 12:50:30 +01:00
parent b0dfa22ef4
commit 8b10dba43b
6 changed files with 139 additions and 56 deletions

9
.gitignore vendored
View File

@ -1,10 +1,17 @@
*.dat *.dat
/data/ /data/
/temp/ /temp/
/figures/ /figures/
.idea/
__pycache__/
# Latex output files # Latex output files
*.out *.out
*.aux *.aux
*.log *.log
*.synctex.gz *.synctex.gz
*.toc *.toc
# other
*.csv
*.txt

View File

@ -6,7 +6,7 @@ from warnings import warn
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
class DataProvider: class Controller:
def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)): def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
self.data_path = data_path self.data_path = data_path
@ -28,12 +28,31 @@ class DataProvider:
self.repros_todo = self.repros self.repros_todo = self.repros
self.base_detection_parameters = base_detection_params self.base_detection_parameters = base_detection_params
self.thresholds = {}
self.sampling_interval = self.parser.get_sampling_interval() self.sampling_interval = self.parser.get_sampling_interval()
self.sorting = {} self.sorting = {}
self.thresholds = {}
self.recording_times = {} self.recording_times = {}
def save_parameters(self):
header = "trial_index,threshold,min_window,step_size\n"
for repro in self.thresholds.keys():
# TODO change '.' to self.data_path after testing:
with open("." + "/{}_thresholds.csv".format(repro), "w") as threshold_file:
threshold_file.write(header)
for i in range(len(self.sorting[repro])):
if len(self.thresholds[repro][i]) == 0:
# missing threshold info
threshold_file.write("{},{},{},{}\n".format(i, -1, -1, -1))
else:
# format and write given thresholding parameters
thresh = self.thresholds[repro][i][0]
min_window = self.thresholds[repro][i][1]
step = self.thresholds[repro][i][2]
threshold_file.write("{},{},{},{}\n".format(i, thresh, min_window, step))
print("Thresholds saved!")
# TODO save redetected spikes:
def get_repros(self): def get_repros(self):
return self.parser.get_measured_repros() return self.parser.get_measured_repros()
@ -94,6 +113,12 @@ class DataProvider:
def get_trials(self, repro, stimulus_value): def get_trials(self, repro, stimulus_value):
pass pass
def set_redetection_params(self, repro, trial_idx, params):
if repro not in self.thresholds.keys():
self.thresholds[repro] = [(),]*len(self.sorting[repro])
self.thresholds[repro][trial_idx] = params
def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before): def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
ash = np.zeros((len(traces), len(spiketimes))) ash = np.zeros((len(traces), len(spiketimes)))

View File

@ -131,7 +131,7 @@ class DatParser:
def get_baseline_spiketimes(self): def get_baseline_spiketimes(self):
# TODO change: reading from file -> detect from v1 trace # TODO change: reading from file -> detect from v1 trace
spiketimes = [] spiketimes = []
warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.") # warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.")
for metadata, key, data in Dl.iload(self.baseline_file): for metadata, key, data in Dl.iload(self.baseline_file):
spikes = np.array(data[:, 0]) / 1000 # timestamps are saved in ms -> conversion to seconds spikes = np.array(data[:, 0]) / 1000 # timestamps are saved in ms -> conversion to seconds
@ -296,7 +296,7 @@ class DatParser:
spiketimes = [] spiketimes = []
metadata_list = [] metadata_list = []
warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!") # warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!")
if spiketimes_file == "": if spiketimes_file == "":
if repro not in self.spike_files.keys(): if repro not in self.spike_files.keys():

View File

@ -1,30 +1,29 @@
from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\ from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\ QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
QDoubleSpinBox, QComboBox, QSpinBox QDoubleSpinBox, QComboBox, QSpinBox, QCheckBox
from PyQt5.QtCore import pyqtSlot from PyQt5.QtCore import pyqtSlot
import numpy as np import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure from matplotlib.figure import Figure
from DataProvider import DataProvider from Controller import Controller
from redetector import detect_spiketimes from redetector import detect_spiketimes
class SpikeRedetectGui(QWidget): class SpikeRedetectGui(QWidget):
def __init__(self, data_provider: DataProvider): def __init__(self, controller: Controller):
super().__init__() super().__init__()
self.data_provider = data_provider self.controller = controller
self.title = 'Spike Redetection' self.title = 'Spike Redetection'
self.left = 10 self.left = 10
self.top = 10 self.top = 10
self.width = 640 self.width = 1500
self.height = 400 self.height = 800
self.trial_idx = 0 self.trial_idx = 0
self.initUI() self.initUI()
def initUI(self): def initUI(self):
@ -35,7 +34,6 @@ class SpikeRedetectGui(QWidget):
middle = QHBoxLayout() middle = QHBoxLayout()
# Canvas Area for matplotlib figure # Canvas Area for matplotlib figure
plot_area = QFrame() plot_area = QFrame()
plot_area_layout = QVBoxLayout() plot_area_layout = QVBoxLayout()
self.canvas = PlotCanvas(self) self.canvas = PlotCanvas(self)
@ -48,26 +46,26 @@ class SpikeRedetectGui(QWidget):
plot_area_buttons.setLayout(plot_area_buttons_layout) plot_area_buttons.setLayout(plot_area_buttons_layout)
plot_area_layout.addWidget(plot_area_buttons) plot_area_layout.addWidget(plot_area_buttons)
button = QPushButton('Button1', self) button = QPushButton('previous stimulus value', self)
button.setToolTip('A nice button!') button.setToolTip('A nice button!')
button.clicked.connect(lambda: self.threshold_spinbox.setValue(1)) button.clicked.connect(lambda: self.threshold_spinbox.setValue(1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button) plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button2', self) button = QPushButton('previous trial', self)
button.setToolTip('Another nice button!') button.setToolTip('Another nice button!')
button.clicked.connect(lambda: self.threshold_spinbox.setValue(2)) button.clicked.connect(lambda: self.trial_change(self.trial_idx - 1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button) plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button3', self) button = QPushButton('next trial', self)
button.setToolTip('Even more nice buttons!') button.setToolTip('Even more nice buttons!')
button.clicked.connect(lambda: self.threshold_spinbox.setValue(3)) button.clicked.connect(lambda: self.trial_change(self.trial_idx + 1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button) plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button4', self) button = QPushButton('next stimulus value', self)
button.setToolTip('Even more nice buttons!') button.setToolTip('But this is not a nice button!')
button.clicked.connect(lambda: self.threshold_spinbox.setValue(4)) button.clicked.connect(lambda: self.threshold_spinbox.setValue(4))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button) plot_area_buttons_layout.addWidget(button)
@ -85,40 +83,43 @@ class SpikeRedetectGui(QWidget):
stim_val_label = QLabel("Stimulus value:")
self.stim_val_box = QComboBox()
repro_label = QLabel("Repro:") repro_label = QLabel("Repro:")
panel_layout.addWidget(repro_label) panel_layout.addWidget(repro_label)
self.repro_box = QComboBox()
for repro in self.data_provider.get_repros(): self.repro_box = QComboBox()
self.repro_box.addItem(repro) self.repro_box.addItems(self.controller.get_repros())
self.repro_box.currentTextChanged.connect(self.repro_change) self.repro_box.currentTextChanged.connect(self.repro_change)
panel_layout.addWidget(self.repro_box) panel_layout.addWidget(self.repro_box)
stim_val_label = QLabel("Stimulus value:")
panel_layout.addWidget(stim_val_label) panel_layout.addWidget(stim_val_label)
self.stim_val_box = QComboBox()
panel_layout.addWidget(self.stim_val_box) panel_layout.addWidget(self.stim_val_box)
trial_label = QLabel("Trial:") self.grouping_checkbox = QCheckBox()
panel_layout.addWidget(trial_label) self.grouping_checkbox.setText("Group by Stimulus Value:")
trial_spinbox = QSpinBox(self) panel_layout.addWidget(self.grouping_checkbox)
trial_spinbox.setValue(1)
trial_spinbox.setSingleStep(1)
trial_spinbox.valueChanged.connect(lambda: self.trial_change(trial_spinbox.value()))
panel_layout.addWidget(trial_spinbox)
filler = QFill(minh=200)
panel_layout.addWidget(filler)
self.status_label = QLabel("Done x/15 Stimulus Values") trial_label = QLabel("Trial:")
panel_layout.addWidget(self.status_label) panel_layout.addWidget(trial_label)
self.trial_spinbox = QSpinBox(self)
self.trial_spinbox.setValue(self.trial_idx)
self.trial_spinbox.setSingleStep(1)
self.trial_spinbox.valueChanged.connect(lambda: self.trial_change(self.trial_spinbox.value()))
panel_layout.addWidget(self.trial_spinbox)
# self.status_label = QLabel("Done x/15 Stimulus Values")
# panel_layout.addWidget(self.status_label)
filler = QFill() filler = QFill()
panel_layout.addWidget(filler) panel_layout.addWidget(filler)
threshold_label = QLabel("Threshold:") threshold_label = QLabel("Threshold:")
panel_layout.addWidget(threshold_label) panel_layout.addWidget(threshold_label)
self.threshold_spinbox = QDoubleSpinBox(self) self.threshold_spinbox = QDoubleSpinBox(self)
self.threshold_spinbox.setValue(self.data_provider.base_detection_parameters[0]) self.threshold_spinbox.setValue(self.controller.base_detection_parameters[0])
self.threshold_spinbox.setSingleStep(0.5) self.threshold_spinbox.setSingleStep(0.5)
self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed()) self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed())
panel_layout.addWidget(self.threshold_spinbox) panel_layout.addWidget(self.threshold_spinbox)
@ -127,7 +128,7 @@ class SpikeRedetectGui(QWidget):
panel_layout.addWidget(window_label) panel_layout.addWidget(window_label)
self.window_spinbox = QSpinBox(self) self.window_spinbox = QSpinBox(self)
self.window_spinbox.setMaximum(2**21) self.window_spinbox.setMaximum(2**21)
self.window_spinbox.setValue(self.data_provider.base_detection_parameters[1]) self.window_spinbox.setValue(self.controller.base_detection_parameters[1])
self.window_spinbox.setSingleStep(500) self.window_spinbox.setSingleStep(500)
self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed()) self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed())
panel_layout.addWidget(self.window_spinbox) panel_layout.addWidget(self.window_spinbox)
@ -136,43 +137,77 @@ class SpikeRedetectGui(QWidget):
panel_layout.addWidget(step_label) panel_layout.addWidget(step_label)
self.step_spinbox = QSpinBox(self) self.step_spinbox = QSpinBox(self)
self.step_spinbox.setMaximum(2 ** 21) self.step_spinbox.setMaximum(2 ** 21)
self.step_spinbox.setValue(self.data_provider.base_detection_parameters[2]) self.step_spinbox.setValue(self.controller.base_detection_parameters[2])
self.step_spinbox.setSingleStep(200) self.step_spinbox.setSingleStep(200)
self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed()) self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed())
panel_layout.addWidget(self.step_spinbox) panel_layout.addWidget(self.step_spinbox)
button = QPushButton('Accept!', self) button = QPushButton('Accept!', self)
button.setToolTip('Accept the threshold for current stimulus value') button.setToolTip('Accept the threshold for current stimulus value')
button.clicked.connect(self.accept_redetection)
panel_layout.addWidget(button)
button = QPushButton('Save thresholds!', self)
button.setToolTip('Save accepted thresholds!')
button.clicked.connect(self.save_threshold_parameters)
panel_layout.addWidget(button) panel_layout.addWidget(button)
panel.setLayout(panel_layout) panel.setLayout(panel_layout)
middle.addWidget(panel) middle.addWidget(panel)
self.setLayout(middle) self.setLayout(middle)
# init the first repro:
self.repro_change()
self.show() self.show()
@pyqtSlot() @pyqtSlot()
def redetection_changed(self): def redetection_changed(self):
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection) self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection)
@pyqtSlot() @pyqtSlot()
def trial_change(self, new_trial_idx): def trial_change(self, new_trial_idx):
# TODO test if in range of trials! if new_trial_idx == self.trial_idx:
return
traces, spiketimes, recording_times = self.controller.get_traces_with_spiketimes(self.repro_box.currentText())
if not 0 <= new_trial_idx < len(spiketimes):
print("trial_change():Out of range trial index!")
return
self.trial_idx = new_trial_idx self.trial_idx = new_trial_idx
if new_trial_idx != self.trial_spinbox.value():
self.trial_spinbox.setValue(new_trial_idx)
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection) self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection)
@pyqtSlot() @pyqtSlot()
def repro_change(self): def repro_change(self):
repro = self.repro_box.currentText() repro = self.repro_box.currentText()
# reset trials and stim values
self.trial_change(0)
self.stim_val_box.clear() self.stim_val_box.clear()
for val in self.data_provider.get_stim_values(repro): for val in self.controller.get_stim_values(repro):
self.stim_val_box.addItem(str(val)) self.stim_val_box.addItem(str(val))
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection) self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection)
@pyqtSlot()
def accept_redetection(self):
params = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
self.controller.set_redetection_params(self.repro_box.currentText(), self.trial_idx, params)
self.trial_change(self.trial_idx + 1)
@pyqtSlot()
def save_threshold_parameters(self):
self.controller.save_parameters()
class PlotCanvas(FigureCanvas): class PlotCanvas(FigureCanvas):
@ -189,6 +224,7 @@ class PlotCanvas(FigureCanvas):
QSizePolicy.Expanding) QSizePolicy.Expanding)
FigureCanvas.updateGeometry(self) FigureCanvas.updateGeometry(self)
self.current_repro = None
self.mouse_button_pressed = False self.mouse_button_pressed = False
self.mouse_button = "-1" self.mouse_button = "-1"
@ -249,7 +285,6 @@ class PlotCanvas(FigureCanvas):
self.axes.set_ylim(new_ylimits) self.axes.set_ylim(new_ylimits)
self.draw() self.draw()
def onclick(self, event): def onclick(self, event):
if event.button in (1, 3): if event.button in (1, 3):
@ -272,7 +307,7 @@ class PlotCanvas(FigureCanvas):
self.mouse_button_pressed = False self.mouse_button_pressed = False
@pyqtSlot() @pyqtSlot()
def plot(self, trial_idx, repro, data_provider: DataProvider, redetection_vars): def plot(self, trial_idx, repro, data_provider: Controller, redetection_vars):
traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro) traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro)
trace = traces[trial_idx] trace = traces[trial_idx]
spiketimes = spiketimes[trial_idx] spiketimes = spiketimes[trial_idx]
@ -289,8 +324,12 @@ class PlotCanvas(FigureCanvas):
redetect = detect_spiketimes(time, trace, redetection_vars[0], redetection_vars[1], redetection_vars[2]) redetect = detect_spiketimes(time, trace, redetection_vars[0], redetection_vars[1], redetection_vars[2])
ax.eventplot(redetect, lineoffsets=max(trace) + 2, colors="red") ax.eventplot(redetect, lineoffsets=max(trace) + 2, colors="red")
ax.set_title('Trial XYZ') ax.set_title('Trial XYZ')
ax.set_xlim(xlim)
ax.set_ylim(ylim) if self.current_repro == repro:
ax.set_xlim(xlim)
ax.set_ylim(ylim)
else:
self.current_repro = repro
self.draw() self.draw()

18
main.py
View File

@ -1,17 +1,29 @@
import sys import sys
from PyQt5.QtWidgets import QApplication import os
from DataProvider import DataProvider from PyQt5.QtWidgets import QApplication, QFileDialog
from Controller import Controller
from SpikeRedetectGui import SpikeRedetectGui from SpikeRedetectGui import SpikeRedetectGui
test_file = "../neuronModel/data/final_sam/2010-11-08-al-invivo-1"
def main(): def main():
app = QApplication(sys.argv) app = QApplication(sys.argv)
data_provider = DataProvider("../neuronModel/data/final_sam/2010-11-08-al-invivo-1") data_path = QFileDialog.getExistingDirectory(caption='Select a directory')
if os.path.isdir(data_path):
data_provider = Controller(data_path)
else:
print("Cell loading didn't work!\n Cell folder: {}".format(data_path))
return
#data_provider = Controller(test_file)
ex = SpikeRedetectGui(data_provider) ex = SpikeRedetectGui(data_provider)
sys.exit(app.exec_()) sys.exit(app.exec_())
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,6 +1,6 @@
from DatParser import DatParser from DatParser import DatParser
from DataProvider import DataProvider, average_spike_height from Controller import Controller, average_spike_height
import os import os
import numpy as np import numpy as np
@ -18,7 +18,7 @@ def main():
if cell in failure_to_read_sam: if cell in failure_to_read_sam:
continue continue
cell_folder = os.path.join(DATA_FOLDER, cell) cell_folder = os.path.join(DATA_FOLDER, cell)
data_provider = DataProvider(cell_folder) data_provider = Controller(cell_folder)
repros = test_getting_repros(data_provider) repros = test_getting_repros(data_provider)
@ -42,7 +42,7 @@ def main():
plt.close() plt.close()
def test_loading_spikes(data_provider: DataProvider, repro): def test_loading_spikes(data_provider: Controller, repro):
return data_provider.parser.get_spiketimes(repro) return data_provider.parser.get_spiketimes(repro)
@ -50,7 +50,7 @@ def test_loading_traces(data_provider, repro):
return data_provider.get_traces(repro) return data_provider.get_traces(repro)
def test_getting_repros(data_provider: DataProvider): def test_getting_repros(data_provider: Controller):
return data_provider.get_repros() return data_provider.get_repros()