from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit, \
    QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame, \
    QDoubleSpinBox, QComboBox, QSpinBox, QCheckBox, QFileDialog
from PyQt5.QtCore import pyqtSlot
import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

from Controller import Controller
from redetector import detect_spiketimes


class SpikeRedetectGui(QWidget):
    """Window for inspecting spike detection and re-detecting spikes trial by trial.

    Left: a ``PlotCanvas`` showing the selected trial's trace, the original
    spike times (black) and the re-detected spike times (red), plus navigation
    buttons. Right: repro / stimulus value / trial selectors, the three
    detection parameters, and buttons to accept and save the results.
    """

    def __init__(self, data_path=None):
        """Build the window for one cell directory and show it.

        Args:
            data_path: Cell directory handed to ``Controller``. If ``None``, the
                user picks one with a directory dialog.
        """
        super().__init__()
        if data_path is None:
            # NOTE(review): the dialog returns "" when cancelled; that value is
            # passed on to Controller unchecked.
            data_path = QFileDialog.getExistingDirectory(self, caption='Select a cell directory:')
        self.controller = Controller(data_path)
        self.title = 'Spike Redetection'
        # Initial window geometry. Width/height use names that do not shadow
        # the QWidget.width()/QWidget.height() methods.
        self.left = 10
        self.top = 10
        self.initial_width = 1500
        self.initial_height = 800
        self.trial_idx = 0
        self.initUI()

    def initUI(self):
        """Create all widgets and layouts, load the first repro and show the window."""
        self.setWindowTitle(self.title)
        self.setGeometry(self.left, self.top, self.initial_width, self.initial_height)

        middle = QHBoxLayout()
        middle.addWidget(self._build_plot_area())
        middle.addWidget(QVLine())
        middle.addWidget(self._build_side_panel())
        self.setLayout(middle)

        # init the first repro:
        self.repro_change()
        self.show()

    def _build_plot_area(self):
        """Return the frame holding the plot canvas and the navigation buttons."""
        plot_area = QFrame()
        plot_area_layout = QVBoxLayout()

        # Canvas area for the matplotlib figure
        self.canvas = PlotCanvas(self)
        self.canvas.move(0, 0)
        plot_area_layout.addWidget(self.canvas)

        plot_area_buttons = QFrame()
        plot_area_buttons_layout = QHBoxLayout()
        plot_area_buttons.setLayout(plot_area_buttons_layout)
        plot_area_layout.addWidget(plot_area_buttons)

        # TODO: the stimulus-value buttons are placeholders; for now they only
        # set the threshold spin box to a fixed value.
        # (The lambdas resolve self.threshold_spinbox lazily, so it may be
        # created later in _build_side_panel.)
        button_specs = (
            ('previous stimulus value', 'A nice button!',
             lambda: self.threshold_spinbox.setValue(1)),
            ('previous trial', 'Another nice button!',
             lambda: self.trial_change(self.trial_idx - 1)),
            ('next trial', 'Even more nice buttons!',
             lambda: self.trial_change(self.trial_idx + 1)),
            ('next stimulus value', 'But this is not a nice button!',
             lambda: self.threshold_spinbox.setValue(4)),
        )
        for text, tooltip, callback in button_specs:
            button = QPushButton(text, self)
            button.setToolTip(tooltip)
            button.clicked.connect(callback)
            button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
            plot_area_buttons_layout.addWidget(button)

        plot_area.setLayout(plot_area_layout)
        return plot_area

    def _build_side_panel(self):
        """Return the options panel: selectors, detection parameters and action buttons."""
        panel = QFrame()
        panel.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
        panel.setMaximumWidth(200)
        panel_layout = QVBoxLayout()

        panel_layout.addWidget(QLabel("Repro:"))
        self.repro_box = QComboBox()
        self.repro_box.addItems(self.controller.get_repros())
        self.repro_box.currentTextChanged.connect(self.repro_change)
        panel_layout.addWidget(self.repro_box)

        panel_layout.addWidget(QLabel("Stimulus value:"))
        self.stim_val_box = QComboBox()
        panel_layout.addWidget(self.stim_val_box)

        self.grouping_checkbox = QCheckBox()
        self.grouping_checkbox.setText("Group by Stimulus Value:")
        panel_layout.addWidget(self.grouping_checkbox)

        panel_layout.addWidget(QLabel("Trial:"))
        self.trial_spinbox = QSpinBox(self)
        self.trial_spinbox.setValue(self.trial_idx)
        self.trial_spinbox.setSingleStep(1)
        self.trial_spinbox.valueChanged.connect(lambda: self.trial_change(self.trial_spinbox.value()))
        panel_layout.addWidget(self.trial_spinbox)

        # Expanding filler: pushes the detection parameters and buttons to the
        # bottom of the panel.
        panel_layout.addWidget(QFill())

        base_params = self.controller.base_detection_parameters

        panel_layout.addWidget(QLabel("Threshold:"))
        self.threshold_spinbox = QDoubleSpinBox(self)
        self.threshold_spinbox.setValue(base_params[0])
        self.threshold_spinbox.setSingleStep(0.5)
        self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed())
        panel_layout.addWidget(self.threshold_spinbox)

        panel_layout.addWidget(QLabel("Min window size:"))
        self.window_spinbox = QSpinBox(self)
        self.window_spinbox.setMaximum(2 ** 21)
        self.window_spinbox.setValue(base_params[1])
        self.window_spinbox.setSingleStep(500)
        self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed())
        panel_layout.addWidget(self.window_spinbox)

        panel_layout.addWidget(QLabel("step size:"))
        self.step_spinbox = QSpinBox(self)
        self.step_spinbox.setMaximum(2 ** 21)
        self.step_spinbox.setValue(base_params[2])
        self.step_spinbox.setSingleStep(200)
        self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed())
        panel_layout.addWidget(self.step_spinbox)

        action_specs = (
            ('Accept!', 'Accept the threshold for current stimulus value',
             self.accept_redetection),
            ('Save thresholds!', 'Save accepted thresholds!',
             self.save_threshold_parameters),
            ('Save redetected spikes!', 'Save redetected spikes with the accepted thresholds.',
             self.save_redetected_spikes),
        )
        for text, tooltip, slot in action_specs:
            button = QPushButton(text, self)
            button.setToolTip(tooltip)
            button.clicked.connect(slot)
            panel_layout.addWidget(button)

        panel.setLayout(panel_layout)
        return panel

    def _redetection_params(self):
        """Return ``(threshold, min window size, step size)`` from the spin boxes."""
        return (self.threshold_spinbox.value(),
                self.window_spinbox.value(),
                self.step_spinbox.value())

    @pyqtSlot()
    def redetection_changed(self):
        """Redraw the current trial after a detection parameter was edited."""
        self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller,
                         self._redetection_params())

    @pyqtSlot(int)
    def trial_change(self, new_trial_idx):
        """Switch to trial ``new_trial_idx`` of the current repro and redraw.

        Does nothing if the trial is already selected. Prints a message and
        leaves the selection unchanged if the index is out of range.
        """
        if new_trial_idx == self.trial_idx:
            return
        repro = self.repro_box.currentText()
        _, spiketimes, _ = self.controller.get_traces_with_spiketimes(repro)
        if not 0 <= new_trial_idx < len(spiketimes):
            print("trial_change():Out of range trial index!")
            return

        self.trial_idx = new_trial_idx
        if new_trial_idx != self.trial_spinbox.value():
            # Emits valueChanged -> trial_change(new_trial_idx), which returns
            # immediately because trial_idx already matches.
            self.trial_spinbox.setValue(new_trial_idx)
        self.canvas.plot(self.trial_idx, repro, self.controller, self._redetection_params())

    @pyqtSlot()
    def repro_change(self):
        """Load the repro selected in ``repro_box``.

        Resets the trial to 0, limits the trial spin box to the repro's trials,
        refills the stimulus-value box and redraws.
        """
        repro = self.repro_box.currentText()
        _, spiketimes, _ = self.controller.get_traces_with_spiketimes(repro)
        n_trials = len(spiketimes)

        # Reset the trial *before* plotting: the previous repro's trial index
        # may not exist in this one. Signals are blocked so the range/value
        # updates do not re-enter trial_change.
        self.trial_idx = 0
        self.trial_spinbox.blockSignals(True)
        # QSpinBox defaults to a maximum of 99; without this, trials >= 100
        # could never be selected (setValue would clamp and snap back).
        self.trial_spinbox.setRange(0, max(n_trials - 1, 0))
        self.trial_spinbox.setValue(0)
        self.trial_spinbox.blockSignals(False)

        self.stim_val_box.clear()
        self.stim_val_box.addItems([str(val) for val in self.controller.get_stim_values(repro)])

        if n_trials == 0:
            print("repro_change(): repro has no trials!")
            return
        self.canvas.plot(self.trial_idx, repro, self.controller, self._redetection_params())

    @pyqtSlot()
    def accept_redetection(self):
        """Store the current detection parameters and advance to the next trial.

        Only the ungrouped mode is implemented: the parameters are stored for the
        current trial of the current repro via ``set_redetection_params``.
        """
        params = self._redetection_params()
        if not self.grouping_checkbox.isChecked():
            self.controller.set_redetection_params(self.repro_box.currentText(), self.trial_idx, params)
            self.trial_change(self.trial_idx + 1)
        else:
            # TODO set params for all trials of the current stim value
            # also choose next trial idx that has no accepted params yet and update stim_val_box
            pass

    @pyqtSlot()
    def save_threshold_parameters(self):
        """Ask for a folder and save the accepted detection parameters there."""
        data_path = QFileDialog.getExistingDirectory(
            self, directory=self.controller.data_path,
            caption='Select a folder to save the parameters:')
        # The dialog returns an empty string (not None) when cancelled.
        if not data_path:
            return
        self.controller.save_parameters(data_path)

    @pyqtSlot()
    def save_redetected_spikes(self):
        """Ask for a folder and save the spikes re-detected with the accepted parameters."""
        data_path = QFileDialog.getExistingDirectory(
            self, directory=self.controller.data_path,
            caption='Select a folder to save the redetected spikes:')
        # The dialog returns an empty string (not None) when cancelled.
        if not data_path:
            return
        self.controller.save_redetected_spikes(data_path)


class PlotCanvas(FigureCanvas):
    """Matplotlib canvas showing one trial with original and re-detected spikes.

    Dragging with the left mouse button pans the axes. Dragging with the right
    button zooms relative to the limits at press time.
    """

    def __init__(self, parent=None, dpi=100):
        self.fig = Figure(dpi=dpi)
        self.axes = self.fig.add_subplot(111)

        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # Repro drawn last; axis limits are kept while it stays the same.
        self.current_repro = None

        # Mouse-drag state, set in onclick() and used in moved().
        self.mouse_button_pressed = False
        self.mouse_button = "-1"
        self.mouse_position_start = (-1, -1)  # pixels
        self.mouse_data_start = (-1, -1)  # data coordinates
        self.start_limits = ((0, 0), (0, 0))  # (xlim, ylim) at press time

        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.fig.canvas.mpl_connect('button_release_event', self.release)
        self.fig.canvas.mpl_connect('motion_notify_event', self.moved)
        self.fig.canvas.mpl_connect('scroll_event', self.scrolled)

    # XY position 0,0 at the bottom left!
    # XY positions change when rescaling the window size
    # zoom depend on percentage of total x/y length

    def scrolled(self, event):
        """Scroll handler; currently only prints the event (event.button is "up"/"down")."""
        print('Scrolled: ', event)

    @staticmethod
    def _zoomed_limits(limits, factor):
        """Return ``limits`` scaled by ``factor`` symmetrically around their centre."""
        length = limits[1] - limits[0]
        new_length = length * factor
        return (limits[0] - 0.5 * new_length + 0.5 * length,
                limits[1] + 0.5 * new_length - 0.5 * length)

    def moved(self, event):
        """Pan (left button) or zoom (right button) while the mouse is dragged."""
        if not self.mouse_button_pressed or event.xdata is None:
            return

        if self.mouse_button == 1:
            # Shift the limits so the data point grabbed at press time follows
            # the cursor.
            mouse_data_new = self.axes.transData.inverted().transform((event.x, event.y))
            diff_x = mouse_data_new[0] - self.mouse_data_start[0]
            diff_y = mouse_data_new[1] - self.mouse_data_start[1]
            xmin, xmax = self.axes.get_xlim()
            ymin, ymax = self.axes.get_ylim()
            self.axes.set_xlim(xmin - diff_x, xmax - diff_x)
            self.axes.set_ylim(ymin - diff_y, ymax - diff_y)
            self.draw()
        elif self.mouse_button == 3:
            zoom_strength = 50  # pixels of drag to halve or double the axis length
            factor_x = 2 ** ((self.mouse_position_start[0] - event.x) / zoom_strength)
            factor_y = 2 ** ((self.mouse_position_start[1] - event.y) / zoom_strength)
            self.axes.set_xlim(self._zoomed_limits(self.start_limits[0], factor_x))
            self.axes.set_ylim(self._zoomed_limits(self.start_limits[1], factor_y))
            self.draw()

    def onclick(self, event):
        """Record the drag start (pixels, data coords, limits) for buttons 1 and 3."""
        if event.button in (1, 3):
            self.mouse_button_pressed = True
            self.mouse_button = event.button
            self.mouse_position_start = (event.x, event.y)
            self.mouse_data_start = self.axes.transData.inverted().transform(self.mouse_position_start)
            self.start_limits = (self.axes.get_xlim(), self.axes.get_ylim())

    def release(self, event):
        """End the drag started by button 1 or 3."""
        if event.button in (1, 3):
            self.mouse_button_pressed = False

    def plot(self, trial_idx, repro, data_provider: Controller, redetection_vars):
        """Draw one trial of ``repro`` with original and re-detected spike times.

        Args:
            trial_idx: Index of the trial within the repro.
            repro: Repro name passed to ``data_provider.get_traces_with_spiketimes``.
            data_provider: Source of traces, spike times, recording times and
                ``sampling_interval``.
            redetection_vars: ``(threshold, min window size, step size)``,
                forwarded positionally to ``detect_spiketimes``.

        The current axis limits are kept if ``repro`` equals the repro of the
        previous call, so zoom/pan survives switching trials.
        """
        traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro)
        trace = traces[trial_idx]
        trial_spiketimes = spiketimes[trial_idx]
        sampling_interval = data_provider.sampling_interval
        # NOTE(review): recording_times[0] is the time offset for every trial,
        # not recording_times[trial_idx]; confirm this is intended.
        offset = recording_times[0]

        ax = self.axes
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.clear()

        time_axis = np.arange(len(trace)) * sampling_interval - offset
        trace_max = np.max(trace)
        ax.plot(time_axis, trace)
        # Original spikes (black) drawn one unit above the re-detected ones (red).
        ax.eventplot(trial_spiketimes, lineoffsets=trace_max + 2, colors="black")

        redetect = np.asarray(detect_spiketimes(time_axis, trace, redetection_vars[0],
                                                redetection_vars[1], redetection_vars[2]))
        ax.eventplot(redetect, lineoffsets=trace_max + 1, colors="red")

        # Mark each re-detected spike on the trace: invert time_axis to get the
        # sample index, clipped so times rounding past either end stay in bounds.
        sample_idx = np.round((redetect + offset) / sampling_interval).astype(int)
        sample_idx = np.clip(sample_idx, 0, len(trace) - 1)
        ax.plot(redetect, trace[sample_idx], 'o', color='red')
        ax.set_title('Trial')

        if self.current_repro == repro:
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
        else:
            self.current_repro = repro

        self.draw()


class QHLine(QFrame):
    """Horizontal sunken separator line."""

    def __init__(self):
        super(QHLine, self).__init__()
        self.setFrameShape(QFrame.HLine)
        self.setFrameShadow(QFrame.Sunken)


class QVLine(QFrame):
    """Vertical sunken separator line."""

    def __init__(self):
        super(QVLine, self).__init__()
        self.setFrameShape(QFrame.VLine)
        self.setFrameShadow(QFrame.Sunken)


class QFill(QFrame):
    """Empty, expanding frame used as a spacer in layouts.

    Args:
        maxw, maxh: Maximum size in pixels (default 2**24 - 1).
        minw, minh: Minimum size in pixels (default 0).
    """

    def __init__(self, maxw=int(2**24)-1, maxh=int(2**24)-1, minw=0, minh=0):
        super(QFill, self).__init__()
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.setMaximumSize(maxw, maxh)
        self.setMinimumSize(minw, minh)