diff --git a/.gitignore b/.gitignore index 5e37a94..7c3f129 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,17 @@ *.dat + /data/ /temp/ /figures/ +.idea/ +__pycache__/ # Latex output files *.out *.aux *.log *.synctex.gz -*.toc \ No newline at end of file +*.toc + +# other +*.csv +*.txt \ No newline at end of file diff --git a/DataProvider.py b/Controller.py similarity index 79% rename from DataProvider.py rename to Controller.py index f3c0fb1..9ac24d9 100644 --- a/DataProvider.py +++ b/Controller.py @@ -6,7 +6,7 @@ from warnings import warn import matplotlib.pyplot as plt -class DataProvider: +class Controller: def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)): self.data_path = data_path @@ -28,12 +28,31 @@ class DataProvider: self.repros_todo = self.repros self.base_detection_parameters = base_detection_params - self.thresholds = {} self.sampling_interval = self.parser.get_sampling_interval() self.sorting = {} + self.thresholds = {} self.recording_times = {} + def save_parameters(self): + header = "trial_index,threshold,min_window,step_size\n" + for repro in self.thresholds.keys(): + # TODO change '.' to self.data_path after testing: + with open("." + "/{}_thresholds.csv".format(repro), "w") as threshold_file: + threshold_file.write(header) + for i in range(len(self.sorting[repro])): + if len(self.thresholds[repro][i]) == 0: + # missing threshold info + threshold_file.write("{},{},{},{}\n".format(i, -1, -1, -1)) + else: + # format and write given thresholding parameters + thresh = self.thresholds[repro][i][0] + min_window = self.thresholds[repro][i][1] + step = self.thresholds[repro][i][2] + threshold_file.write("{},{},{},{}\n".format(i, thresh, min_window, step)) + print("Thresholds saved!") + # TODO save redetected spikes: + def get_repros(self): return self.parser.get_measured_repros() @@ -94,6 +113,12 @@ class DataProvider: def get_trials(self, repro, stimulus_value): pass + def set_redetection_params(self, repro, trial_idx, params): + if repro not in self.thresholds.keys(): + self.thresholds[repro] = [(),]*len(self.sorting[repro]) + + self.thresholds[repro][trial_idx] = params + def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before): ash = np.zeros((len(traces), len(spiketimes))) diff --git a/DatParser.py b/DatParser.py index 6737908..382d08b 100644 --- a/DatParser.py +++ b/DatParser.py @@ -131,7 +131,7 @@ class DatParser: def get_baseline_spiketimes(self): # TODO change: reading from file -> detect from v1 trace spiketimes = [] - warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.") + # warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.") for metadata, key, data in Dl.iload(self.baseline_file): spikes = np.array(data[:, 0]) / 1000 # timestamps are saved in ms -> conversion to seconds @@ -296,7 +296,7 @@ class DatParser: spiketimes = [] metadata_list = [] - warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!") + # warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!") if spiketimes_file == "": if repro not in self.spike_files.keys(): diff --git a/SpikeRedetectGui.py b/SpikeRedetectGui.py index afb8393..71766d2 100644 --- a/SpikeRedetectGui.py +++ b/SpikeRedetectGui.py @@ -1,30 +1,29 @@ from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\ QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\ - QDoubleSpinBox, QComboBox, QSpinBox + QDoubleSpinBox, QComboBox, QSpinBox, QCheckBox from PyQt5.QtCore import pyqtSlot import numpy as np from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.figure import Figure -from DataProvider import DataProvider +from Controller import Controller from redetector import detect_spiketimes class SpikeRedetectGui(QWidget): - def __init__(self, data_provider: DataProvider): + def __init__(self, controller: Controller): super().__init__() - self.data_provider = data_provider + self.controller = controller self.title = 'Spike Redetection' self.left = 10 self.top = 10 - self.width = 640 - self.height = 400 + self.width = 1500 + self.height = 800 self.trial_idx = 0 - self.initUI() def initUI(self): @@ -35,7 +34,6 @@ class SpikeRedetectGui(QWidget): middle = QHBoxLayout() # Canvas Area for matplotlib figure - plot_area = QFrame() plot_area_layout = QVBoxLayout() self.canvas = PlotCanvas(self) @@ -48,26 +46,26 @@ class SpikeRedetectGui(QWidget): plot_area_buttons.setLayout(plot_area_buttons_layout) plot_area_layout.addWidget(plot_area_buttons) - button = QPushButton('Button1', self) + button = QPushButton('previous stimulus value', self) button.setToolTip('A nice button!') button.clicked.connect(lambda: self.threshold_spinbox.setValue(1)) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) plot_area_buttons_layout.addWidget(button) - button = QPushButton('Button2', self) + button = QPushButton('previous trial', self) button.setToolTip('Another nice button!') - button.clicked.connect(lambda: self.threshold_spinbox.setValue(2)) + button.clicked.connect(lambda: self.trial_change(self.trial_idx - 1)) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) plot_area_buttons_layout.addWidget(button) - button = QPushButton('Button3', self) + button = QPushButton('next trial', self) button.setToolTip('Even more nice buttons!') - button.clicked.connect(lambda: self.threshold_spinbox.setValue(3)) + button.clicked.connect(lambda: self.trial_change(self.trial_idx + 1)) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) plot_area_buttons_layout.addWidget(button) - button = QPushButton('Button4', self) - button.setToolTip('Even more nice buttons!') + button = QPushButton('next stimulus value', self) + button.setToolTip('But this is not a nice button!') button.clicked.connect(lambda: self.threshold_spinbox.setValue(4)) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) plot_area_buttons_layout.addWidget(button) @@ -85,40 +83,43 @@ class SpikeRedetectGui(QWidget): - stim_val_label = QLabel("Stimulus value:") - self.stim_val_box = QComboBox() - repro_label = QLabel("Repro:") panel_layout.addWidget(repro_label) - self.repro_box = QComboBox() - for repro in self.data_provider.get_repros(): - self.repro_box.addItem(repro) + self.repro_box = QComboBox() + self.repro_box.addItems(self.controller.get_repros()) self.repro_box.currentTextChanged.connect(self.repro_change) + panel_layout.addWidget(self.repro_box) + + stim_val_label = QLabel("Stimulus value:") panel_layout.addWidget(stim_val_label) + + self.stim_val_box = QComboBox() panel_layout.addWidget(self.stim_val_box) - trial_label = QLabel("Trial:") - panel_layout.addWidget(trial_label) - trial_spinbox = QSpinBox(self) - trial_spinbox.setValue(1) - trial_spinbox.setSingleStep(1) - trial_spinbox.valueChanged.connect(lambda: self.trial_change(trial_spinbox.value())) - panel_layout.addWidget(trial_spinbox) + self.grouping_checkbox = QCheckBox() + self.grouping_checkbox.setText("Group by Stimulus Value:") + panel_layout.addWidget(self.grouping_checkbox) - filler = QFill(minh=200) - panel_layout.addWidget(filler) - self.status_label = QLabel("Done x/15 Stimulus Values") - panel_layout.addWidget(self.status_label) + trial_label = QLabel("Trial:") + panel_layout.addWidget(trial_label) + self.trial_spinbox = QSpinBox(self) + self.trial_spinbox.setValue(self.trial_idx) + self.trial_spinbox.setSingleStep(1) + self.trial_spinbox.valueChanged.connect(lambda: self.trial_change(self.trial_spinbox.value())) + panel_layout.addWidget(self.trial_spinbox) + + # self.status_label = QLabel("Done x/15 Stimulus Values") + # panel_layout.addWidget(self.status_label) filler = QFill() panel_layout.addWidget(filler) threshold_label = QLabel("Threshold:") panel_layout.addWidget(threshold_label) self.threshold_spinbox = QDoubleSpinBox(self) - self.threshold_spinbox.setValue(self.data_provider.base_detection_parameters[0]) + self.threshold_spinbox.setValue(self.controller.base_detection_parameters[0]) self.threshold_spinbox.setSingleStep(0.5) self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed()) panel_layout.addWidget(self.threshold_spinbox) @@ -127,7 +128,7 @@ class SpikeRedetectGui(QWidget): panel_layout.addWidget(window_label) self.window_spinbox = QSpinBox(self) self.window_spinbox.setMaximum(2**21) - self.window_spinbox.setValue(self.data_provider.base_detection_parameters[1]) + self.window_spinbox.setValue(self.controller.base_detection_parameters[1]) self.window_spinbox.setSingleStep(500) self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed()) panel_layout.addWidget(self.window_spinbox) @@ -136,43 +137,77 @@ class SpikeRedetectGui(QWidget): panel_layout.addWidget(step_label) self.step_spinbox = QSpinBox(self) self.step_spinbox.setMaximum(2 ** 21) - self.step_spinbox.setValue(self.data_provider.base_detection_parameters[2]) + self.step_spinbox.setValue(self.controller.base_detection_parameters[2]) self.step_spinbox.setSingleStep(200) self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed()) panel_layout.addWidget(self.step_spinbox) button = QPushButton('Accept!', self) button.setToolTip('Accept the threshold for current stimulus value') + button.clicked.connect(self.accept_redetection) + panel_layout.addWidget(button) + + button = QPushButton('Save thresholds!', self) + button.setToolTip('Save accepted thresholds!') + button.clicked.connect(self.save_threshold_parameters) panel_layout.addWidget(button) panel.setLayout(panel_layout) middle.addWidget(panel) self.setLayout(middle) + + # init the first repro: + self.repro_change() + self.show() @pyqtSlot() def redetection_changed(self): redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) - self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection) + self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection) @pyqtSlot() def trial_change(self, new_trial_idx): - # TODO test if in range of trials! + if new_trial_idx == self.trial_idx: + return + + traces, spiketimes, recording_times = self.controller.get_traces_with_spiketimes(self.repro_box.currentText()) + if not 0 <= new_trial_idx < len(spiketimes): + print("trial_change():Out of range trial index!") + return self.trial_idx = new_trial_idx + + if new_trial_idx != self.trial_spinbox.value(): + self.trial_spinbox.setValue(new_trial_idx) + redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) - self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection) + self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection) @pyqtSlot() def repro_change(self): repro = self.repro_box.currentText() + + # reset trials and stim values + self.trial_change(0) self.stim_val_box.clear() - for val in self.data_provider.get_stim_values(repro): + for val in self.controller.get_stim_values(repro): self.stim_val_box.addItem(str(val)) redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) - self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection) + self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection) + + @pyqtSlot() + def accept_redetection(self): + params = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) + self.controller.set_redetection_params(self.repro_box.currentText(), self.trial_idx, params) + + self.trial_change(self.trial_idx + 1) + + @pyqtSlot() + def save_threshold_parameters(self): + self.controller.save_parameters() class PlotCanvas(FigureCanvas): @@ -189,6 +224,7 @@ class PlotCanvas(FigureCanvas): QSizePolicy.Expanding) FigureCanvas.updateGeometry(self) + self.current_repro = None self.mouse_button_pressed = False self.mouse_button = "-1" @@ -249,7 +285,6 @@ class PlotCanvas(FigureCanvas): self.axes.set_ylim(new_ylimits) self.draw() - def onclick(self, event): if event.button in (1, 3): @@ -272,7 +307,7 @@ class PlotCanvas(FigureCanvas): self.mouse_button_pressed = False @pyqtSlot() - def plot(self, trial_idx, repro, data_provider: DataProvider, redetection_vars): + def plot(self, trial_idx, repro, data_provider: Controller, redetection_vars): traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro) trace = traces[trial_idx] spiketimes = spiketimes[trial_idx] @@ -289,8 +324,12 @@ class PlotCanvas(FigureCanvas): redetect = detect_spiketimes(time, trace, redetection_vars[0], redetection_vars[1], redetection_vars[2]) ax.eventplot(redetect, lineoffsets=max(trace) + 2, colors="red") ax.set_title('Trial XYZ') - ax.set_xlim(xlim) - ax.set_ylim(ylim) + + if self.current_repro == repro: + ax.set_xlim(xlim) + ax.set_ylim(ylim) + else: + self.current_repro = repro self.draw() diff --git a/main.py b/main.py index ed03aa8..9760979 100644 --- a/main.py +++ b/main.py @@ -1,17 +1,29 @@ import sys -from PyQt5.QtWidgets import QApplication +import os -from DataProvider import DataProvider +from PyQt5.QtWidgets import QApplication, QFileDialog + +from Controller import Controller from SpikeRedetectGui import SpikeRedetectGui +test_file = "../neuronModel/data/final_sam/2010-11-08-al-invivo-1" + def main(): app = QApplication(sys.argv) - data_provider = DataProvider("../neuronModel/data/final_sam/2010-11-08-al-invivo-1") + data_path = QFileDialog.getExistingDirectory(caption='Select a directory') + if os.path.isdir(data_path): + data_provider = Controller(data_path) + else: + print("Cell loading didn't work!\n Cell folder: {}".format(data_path)) + return + #data_provider = Controller(test_file) ex = SpikeRedetectGui(data_provider) sys.exit(app.exec_()) + + if __name__ == '__main__': main() diff --git a/testing.py b/testing.py index eeb330b..6b74a5d 100644 --- a/testing.py +++ b/testing.py @@ -1,6 +1,6 @@ from DatParser import DatParser -from DataProvider import DataProvider, average_spike_height +from Controller import Controller, average_spike_height import os import numpy as np @@ -18,7 +18,7 @@ def main(): if cell in failure_to_read_sam: continue cell_folder = os.path.join(DATA_FOLDER, cell) - data_provider = DataProvider(cell_folder) + data_provider = Controller(cell_folder) repros = test_getting_repros(data_provider) @@ -42,7 +42,7 @@ def main(): plt.close() -def test_loading_spikes(data_provider: DataProvider, repro): +def test_loading_spikes(data_provider: Controller, repro): return data_provider.parser.get_spiketimes(repro) @@ -50,7 +50,7 @@ def test_loading_traces(data_provider, repro): return data_provider.get_traces(repro) -def test_getting_repros(data_provider: DataProvider): +def test_getting_repros(data_provider: Controller): return data_provider.get_repros()