diff --git a/DataProvider.py b/DataProvider.py index 8aa02d0..f3c0fb1 100644 --- a/DataProvider.py +++ b/DataProvider.py @@ -27,8 +27,9 @@ class DataProvider: else: self.repros_todo = self.repros - self.base_Detection_parameters = base_detection_params + self.base_detection_parameters = base_detection_params self.thresholds = {} + self.sampling_interval = self.parser.get_sampling_interval() self.sorting = {} self.recording_times = {} @@ -46,9 +47,9 @@ class DataProvider: if repro in self.sorting.keys(): traces = self.get_unsorted_traces(repro, self.recording_times[repro][0], self.recording_times[repro][1]) v1_traces = traces[1] - spiketimes = self.get_unsorted_spiketimes(repro) + spiketimes, metadata = self.get_unsorted_spiketimes(repro) - sorted_spiketimes = np.array(spiketimes)[self.sorting[repro]] + sorted_spiketimes = np.array(spiketimes, dtype=object)[self.sorting[repro]] return v1_traces, sorted_spiketimes, self.recording_times[repro] if repro == "FICurve": @@ -56,13 +57,13 @@ class DataProvider: before = abs(recording_times[0]) after = recording_times[3] [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.parser.get_traces(repro, before=before, after=after) - spiketimes = self.get_unsorted_spiketimes(repro) + (spiketimes, metadata) = self.get_unsorted_spiketimes(repro) else: before = 0 after = 0 [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.get_unsorted_traces(repro) - spiketimes = self.get_unsorted_spiketimes(repro) + (spiketimes, metadata) = self.get_unsorted_spiketimes(repro) if len(v1_traces) != len(spiketimes): warn("get_traces_with_spiketimes():Unequal number of traces and spiketimes for repro {}!" diff --git a/SpikeRedetectGui.py b/SpikeRedetectGui.py index fa65eb6..afb8393 100644 --- a/SpikeRedetectGui.py +++ b/SpikeRedetectGui.py @@ -9,6 +9,7 @@ import numpy as np from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.figure import Figure from DataProvider import DataProvider +from redetector import detect_spiketimes class SpikeRedetectGui(QWidget): @@ -21,6 +22,9 @@ class SpikeRedetectGui(QWidget): self.top = 10 self.width = 640 self.height = 400 + + self.trial_idx = 0 + self.initUI() def initUI(self): @@ -34,9 +38,9 @@ class SpikeRedetectGui(QWidget): plot_area = QFrame() plot_area_layout = QVBoxLayout() - m = PlotCanvas(self) - m.move(0, 0) - plot_area_layout.addWidget(m) + self.canvas = PlotCanvas(self) + self.canvas.move(0, 0) + plot_area_layout.addWidget(self.canvas) # plot area buttons plot_area_buttons = QFrame() @@ -46,25 +50,25 @@ class SpikeRedetectGui(QWidget): button = QPushButton('Button1', self) button.setToolTip('A nice button!') - button.clicked.connect(lambda: threshold_spinbox.setValue(1)) + button.clicked.connect(lambda: self.threshold_spinbox.setValue(1)) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) plot_area_buttons_layout.addWidget(button) button = QPushButton('Button2', self) button.setToolTip('Another nice button!') - button.clicked.connect(lambda: threshold_spinbox.setValue(2)) + button.clicked.connect(lambda: self.threshold_spinbox.setValue(2)) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) plot_area_buttons_layout.addWidget(button) button = QPushButton('Button3', self) button.setToolTip('Even more nice buttons!') - button.clicked.connect(lambda: threshold_spinbox.setValue(3)) + button.clicked.connect(lambda: self.threshold_spinbox.setValue(3)) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) plot_area_buttons_layout.addWidget(button) button = QPushButton('Button4', self) button.setToolTip('Even more nice buttons!') - button.clicked.connect(lambda: threshold_spinbox.setValue(4)) + button.clicked.connect(lambda: self.threshold_spinbox.setValue(4)) button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) plot_area_buttons_layout.addWidget(button) @@ -87,21 +91,21 @@ class SpikeRedetectGui(QWidget): repro_label = QLabel("Repro:") panel_layout.addWidget(repro_label) self.repro_box = QComboBox() - self.repro_box.currentTextChanged.connect(self.repro_change) + for repro in self.data_provider.get_repros(): self.repro_box.addItem(repro) - + self.repro_box.currentTextChanged.connect(self.repro_change) panel_layout.addWidget(self.repro_box) panel_layout.addWidget(stim_val_label) panel_layout.addWidget(self.stim_val_box) trial_label = QLabel("Trial:") panel_layout.addWidget(trial_label) - threshold_spinbox = QSpinBox(self) - threshold_spinbox.setValue(1) - threshold_spinbox.setSingleStep(1) - threshold_spinbox.valueChanged.connect() - panel_layout.addWidget(threshold_spinbox) + trial_spinbox = QSpinBox(self) + trial_spinbox.setValue(1) + trial_spinbox.setSingleStep(1) + trial_spinbox.valueChanged.connect(lambda: self.trial_change(trial_spinbox.value())) + panel_layout.addWidget(trial_spinbox) filler = QFill(minh=200) panel_layout.addWidget(filler) @@ -113,11 +117,29 @@ class SpikeRedetectGui(QWidget): threshold_label = QLabel("Threshold:") panel_layout.addWidget(threshold_label) - threshold_spinbox = QDoubleSpinBox(self) - threshold_spinbox.setValue(1) - threshold_spinbox.setSingleStep(0.5) - threshold_spinbox.valueChanged.connect(lambda: m.plot(threshold_spinbox.value())) - panel_layout.addWidget(threshold_spinbox) + self.threshold_spinbox = QDoubleSpinBox(self) + self.threshold_spinbox.setValue(self.data_provider.base_detection_parameters[0]) + self.threshold_spinbox.setSingleStep(0.5) + self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed()) + panel_layout.addWidget(self.threshold_spinbox) + + window_label = QLabel("Min window size:") + panel_layout.addWidget(window_label) + self.window_spinbox = QSpinBox(self) + self.window_spinbox.setMaximum(2**21) + self.window_spinbox.setValue(self.data_provider.base_detection_parameters[1]) + self.window_spinbox.setSingleStep(500) + self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed()) + panel_layout.addWidget(self.window_spinbox) + + step_label = QLabel("step size:") + panel_layout.addWidget(step_label) + self.step_spinbox = QSpinBox(self) + self.step_spinbox.setMaximum(2 ** 21) + self.step_spinbox.setValue(self.data_provider.base_detection_parameters[2]) + self.step_spinbox.setSingleStep(200) + self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed()) + panel_layout.addWidget(self.step_spinbox) button = QPushButton('Accept!', self) button.setToolTip('Accept the threshold for current stimulus value') @@ -129,6 +151,18 @@ class SpikeRedetectGui(QWidget): self.setLayout(middle) self.show() + @pyqtSlot() + def redetection_changed(self): + redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) + self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection) + + @pyqtSlot() + def trial_change(self, new_trial_idx): + # TODO test if in range of trials! + self.trial_idx = new_trial_idx + redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) + self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection) + @pyqtSlot() def repro_change(self): repro = self.repro_box.currentText() @@ -137,30 +171,126 @@ class SpikeRedetectGui(QWidget): for val in self.data_provider.get_stim_values(repro): self.stim_val_box.addItem(str(val)) + redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value()) + self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection) + class PlotCanvas(FigureCanvas): def __init__(self, parent=None, dpi=100): - fig = Figure(dpi=dpi) - self.axes = fig.add_subplot(111) + self.fig = Figure(dpi=dpi) + self.axes = self.fig.add_subplot(111) - FigureCanvas.__init__(self, fig) + FigureCanvas.__init__(self, self.fig) self.setParent(parent) FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding) FigureCanvas.updateGeometry(self) - self.plot() + + self.mouse_button_pressed = False + self.mouse_button = "-1" + + self.mouse_position_start = (-1, -1) + self.start_limits = ((0, 0), (0, 0)) + + self.fig.canvas.mpl_connect('button_press_event', self.onclick) + self.fig.canvas.mpl_connect('button_release_event', self.release) + self.fig.canvas.mpl_connect('motion_notify_event', self.moved) + self.fig.canvas.mpl_connect('scroll_event', self.scrolled) + + # XY position 0,0 at the bottom left! + # XY positions change when rescaling the window size + # zoom depend on percentage of total x/y length + + def scrolled(self, event): + # event.button = "up" / "down" + print('Scrolled: ', event) + + def moved(self, event): + if not self.mouse_button_pressed or event.xdata is None: + return + + if self.mouse_button == 1: + new_x = event.x + new_y = event.y + + # mouse_data_start = self.axes.transData.inverted().transform((self.mouse_position_start)) + mouse_data_new = self.axes.transData.inverted().transform((new_x, new_y)) + + diff_x = mouse_data_new[0] - self.mouse_data_start[0] + diff_y = mouse_data_new[1] - self.mouse_data_start[1] + + self.axes.set_xlim(self.axes.get_xlim() - diff_x) + self.axes.set_ylim(self.axes.get_ylim() - diff_y) + + self.draw() + + elif self.mouse_button == 3: + + zoom_strength = 50 # pixels to half or double the axis limits + length_x = self.start_limits[0][1] - self.start_limits[0][0] + length_y = self.start_limits[1][1] - self.start_limits[1][0] + + diff_x = self.mouse_position_start[0] - event.x + diff_y = self.mouse_position_start[1] - event.y + + factor_x = 2**(diff_x / zoom_strength) + factor_y = 2**(diff_y / zoom_strength) + + new_length_x = length_x * factor_x + new_length_y = length_y * factor_y + + new_xlimits = (self.start_limits[0][0] - 0.5*new_length_x + 0.5*length_x, self.start_limits[0][1] + 0.5*new_length_x - 0.5*length_x) + new_ylimits = (self.start_limits[1][0] - 0.5*new_length_y + 0.5*length_y, self.start_limits[1][1] + 0.5*new_length_y - 0.5*length_y) + + self.axes.set_xlim(new_xlimits) + self.axes.set_ylim(new_ylimits) + self.draw() + + + def onclick(self, event): + + if event.button in (1, 3): + self.mouse_button_pressed = True + self.mouse_button = event.button + self.mouse_position_start = (event.x, event.y) + + self.mouse_data_start = self.axes.transData.inverted().transform((self.mouse_position_start)) + # print("Figure:", self.mouse_position_start, "Data:", self.mouse_data_start) + xlim = self.axes.get_xlim() + ylim = self.axes.get_ylim() + self.start_limits = (xlim, ylim) + + # print('Cliched: %s click: button=%d, x=%d, y=%d, xdata=%s, ydata=%s' % + # ('double' if event.dblclick else 'single', event.button, + # event.x, event.y, event.xdata, event.ydata)) + + def release(self, event): + if event.button in (1, 3): + self.mouse_button_pressed = False @pyqtSlot() - def plot(self, mean=1): - x = np.arange(0, 1, 0.0001) - data = np.sin(x*np.pi*2*mean) + def plot(self, trial_idx, repro, data_provider: DataProvider, redetection_vars): + traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro) + trace = traces[trial_idx] + spiketimes = spiketimes[trial_idx] + recording_times = recording_times + sampling_interval = data_provider.sampling_interval + ax = self.axes + xlim = self.axes.get_xlim() + ylim = self.axes.get_ylim() ax.clear() - ax.plot(x, data, 'r-') - ax.set_title('Sinus Example') + time = np.arange(len(trace)) * sampling_interval - recording_times[0] + ax.plot(time, trace) + ax.eventplot(spiketimes, lineoffsets=max(trace) + 1, colors="black") + redetect = detect_spiketimes(time, trace, redetection_vars[0], redetection_vars[1], redetection_vars[2]) + ax.eventplot(redetect, lineoffsets=max(trace) + 2, colors="red") + ax.set_title('Trial XYZ') + ax.set_xlim(xlim) + ax.set_ylim(ylim) self.draw()