spikeRedetector/SpikeRedetectGui.py
2021-07-02 15:06:04 +02:00

379 lines
14 KiB
Python

from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
QDoubleSpinBox, QComboBox, QSpinBox, QCheckBox, QFileDialog
from PyQt5.QtCore import pyqtSlot
import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from Controller import Controller
from redetector import detect_spiketimes
class SpikeRedetectGui(QWidget):
    """Main window for interactively re-detecting spikes trial by trial.

    The left side holds a ``PlotCanvas`` that shows the current trial of the
    selected repro together with the original spike times and the spike times
    re-detected by ``detect_spiketimes`` using the parameters of the side
    panel (threshold, minimal window size, step size). Accepted parameters are
    handed to the ``Controller``, which also saves them.
    """

    def __init__(self, data_path=None):
        """Create the window for the cell directory ``data_path``.

        Args:
            data_path: Cell directory passed to ``Controller``. If ``None``,
                a directory dialog asks the user for it.
        """
        super().__init__()
        if data_path is None:
            # NOTE(review): the dialog returns '' when cancelled; that value is
            # passed on to Controller unchanged.
            data_path = QFileDialog.getExistingDirectory(self, caption='Select a cell directory:')
        self.controller = Controller(data_path)
        self.title = 'Spike Redetection'
        self.left = 10
        self.top = 10
        # Not called ``width``/``height``: instance attributes with those names
        # would shadow the QWidget.width()/height() methods.
        self.window_width = 1500
        self.window_height = 800
        self.trial_idx = 0
        self.initUI()

    def initUI(self):
        """Build all widgets, connect their signals, plot the first repro and show the window."""
        self.setWindowTitle(self.title)
        self.setGeometry(self.left, self.top, self.window_width, self.window_height)
        middle = QHBoxLayout()
        middle.addWidget(self._create_plot_area())
        middle.addWidget(QVLine())
        middle.addWidget(self._create_side_panel())
        self.setLayout(middle)
        # init the first repro:
        self.repro_change()
        self.show()

    def _create_plot_area(self):
        """Return the frame with the matplotlib canvas and the navigation buttons below it."""
        plot_area = QFrame()
        plot_area_layout = QVBoxLayout()
        self.canvas = PlotCanvas(self)
        self.canvas.move(0, 0)
        plot_area_layout.addWidget(self.canvas)

        plot_area_buttons = QFrame()
        plot_area_buttons_layout = QHBoxLayout()
        plot_area_buttons.setLayout(plot_area_buttons_layout)
        plot_area_layout.addWidget(plot_area_buttons)
        # NOTE(review): the two stimulus-value buttons currently only set the
        # threshold to a fixed value - presumably placeholders.
        button_specs = (
            ('previous stimulus value', 'A nice button!',
             lambda: self.threshold_spinbox.setValue(1)),
            ('previous trial', 'Another nice button!',
             lambda: self.trial_change(self.trial_idx - 1)),
            ('next trial', 'Even more nice buttons!',
             lambda: self.trial_change(self.trial_idx + 1)),
            ('next stimulus value', 'But this is not a nice button!',
             lambda: self.threshold_spinbox.setValue(4)),
        )
        for text, tooltip, callback in button_specs:
            button = QPushButton(text, self)
            button.setToolTip(tooltip)
            button.clicked.connect(callback)
            button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
            plot_area_buttons_layout.addWidget(button)
        plot_area.setLayout(plot_area_layout)
        return plot_area

    def _create_side_panel(self):
        """Return the option panel: repro/stimulus/trial selection, detection parameters and actions."""
        panel = QFrame()
        panel.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
        panel.setMaximumWidth(200)
        panel_layout = QVBoxLayout()

        panel_layout.addWidget(QLabel("Repro:"))
        self.repro_box = QComboBox()
        self.repro_box.addItems(self.controller.get_repros())
        self.repro_box.currentTextChanged.connect(self.repro_change)
        panel_layout.addWidget(self.repro_box)

        panel_layout.addWidget(QLabel("Stimulus value:"))
        self.stim_val_box = QComboBox()
        panel_layout.addWidget(self.stim_val_box)

        self.grouping_checkbox = QCheckBox()
        self.grouping_checkbox.setText("Group by Stimulus Value:")
        panel_layout.addWidget(self.grouping_checkbox)

        panel_layout.addWidget(QLabel("Trial:"))
        self.trial_spinbox = QSpinBox(self)
        self.trial_spinbox.setValue(self.trial_idx)
        self.trial_spinbox.setSingleStep(1)
        # The upper bound is set per repro in repro_change().
        self.trial_spinbox.valueChanged.connect(lambda: self.trial_change(self.trial_spinbox.value()))
        panel_layout.addWidget(self.trial_spinbox)

        # Expanding filler pushes the detection parameters to the bottom.
        panel_layout.addWidget(QFill())

        panel_layout.addWidget(QLabel("Threshold:"))
        self.threshold_spinbox = QDoubleSpinBox(self)
        # NOTE(review): no range is set, so QDoubleSpinBox's default range
        # applies to the base threshold - confirm it fits the data.
        self.threshold_spinbox.setValue(self.controller.base_detection_parameters[0])
        self.threshold_spinbox.setSingleStep(0.5)
        self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed())
        panel_layout.addWidget(self.threshold_spinbox)

        panel_layout.addWidget(QLabel("Min window size:"))
        self.window_spinbox = QSpinBox(self)
        self.window_spinbox.setMaximum(2**21)
        self.window_spinbox.setValue(self.controller.base_detection_parameters[1])
        self.window_spinbox.setSingleStep(500)
        self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed())
        panel_layout.addWidget(self.window_spinbox)

        panel_layout.addWidget(QLabel("step size:"))
        self.step_spinbox = QSpinBox(self)
        self.step_spinbox.setMaximum(2 ** 21)
        self.step_spinbox.setValue(self.controller.base_detection_parameters[2])
        self.step_spinbox.setSingleStep(200)
        self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed())
        panel_layout.addWidget(self.step_spinbox)

        action_specs = (
            ('Accept!', 'Accept the threshold for current stimulus value',
             self.accept_redetection),
            ('Save thresholds!', 'Save accepted thresholds!',
             self.save_threshold_parameters),
            ('Save redetected spikes!', 'Save redetected spikes with the accepted thresholds.',
             self.save_redetected_spikes),
        )
        for text, tooltip, slot in action_specs:
            button = QPushButton(text, self)
            button.setToolTip(tooltip)
            button.clicked.connect(slot)
            panel_layout.addWidget(button)

        panel.setLayout(panel_layout)
        return panel

    def _redetection_params(self):
        """Return the current ``(threshold, min window size, step size)`` tuple."""
        return (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())

    def _replot(self):
        """Redraw the current trial of the current repro with the current detection parameters."""
        self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller,
                         self._redetection_params())

    @pyqtSlot()
    def redetection_changed(self):
        """Re-run the detection and redraw after a detection parameter changed."""
        self._replot()

    @pyqtSlot()
    def trial_change(self, new_trial_idx):
        """Switch to trial ``new_trial_idx`` of the current repro and redraw.

        Indices outside the repro's trials are reported on stdout and ignored,
        so the previous/next buttons do nothing at either end.
        """
        if new_trial_idx == self.trial_idx:
            return
        _, spiketimes, _ = self.controller.get_traces_with_spiketimes(self.repro_box.currentText())
        if not 0 <= new_trial_idx < len(spiketimes):
            print("trial_change():Out of range trial index!")
            return
        self.trial_idx = new_trial_idx
        if new_trial_idx != self.trial_spinbox.value():
            # Emits valueChanged -> trial_change(new_trial_idx), which returns
            # early because self.trial_idx is already updated.
            self.trial_spinbox.setValue(new_trial_idx)
        self._replot()

    @pyqtSlot()
    def repro_change(self):
        """Show the first trial of the newly selected repro and refill the stimulus values."""
        repro = self.repro_box.currentText()
        _, spiketimes, _ = self.controller.get_traces_with_spiketimes(repro)
        # The previous trial index may not exist in this repro, so reset it
        # *before* plotting. Signals are blocked so that adjusting the spin box
        # range/value does not trigger an extra redraw via trial_change().
        self.trial_idx = 0
        self.trial_spinbox.blockSignals(True)
        self.trial_spinbox.setMaximum(max(len(spiketimes) - 1, 0))
        self.trial_spinbox.setValue(0)
        self.trial_spinbox.blockSignals(False)
        self._replot()
        self.stim_val_box.clear()
        for val in self.controller.get_stim_values(repro):
            self.stim_val_box.addItem(str(val))

    @pyqtSlot()
    def accept_redetection(self):
        """Store the current detection parameters for the current trial and advance to the next one."""
        params = self._redetection_params()
        if not self.grouping_checkbox.isChecked():
            self.controller.set_redetection_params(self.repro_box.currentText(), self.trial_idx, params)
            self.trial_change(self.trial_idx + 1)
        else:
            # TODO set params for all trials of the current stim value
            # also choose next trial idx that has no accepted params yet and update stim_val_box
            pass

    @pyqtSlot()
    def save_threshold_parameters(self):
        """Ask for a folder and let the controller save the accepted parameters there."""
        data_path = QFileDialog.getExistingDirectory(self, directory=self.controller.data_path, caption='Select a folder to save the parameters:')
        # The dialog returns an empty string (not None) when cancelled.
        if not data_path:
            return
        self.controller.save_parameters(data_path)

    @pyqtSlot()
    def save_redetected_spikes(self):
        """Ask for a folder and let the controller save the redetected spikes there."""
        data_path = QFileDialog.getExistingDirectory(self, directory=self.controller.data_path, caption='Select a folder to save the redetected spikes:')
        # The dialog returns an empty string (not None) when cancelled.
        if not data_path:
            return
        self.controller.save_redetected_spikes(data_path)
class PlotCanvas(FigureCanvas):
    """Matplotlib canvas that shows one trial with original and re-detected spikes.

    Mouse interaction (pixel coordinates have their origin at the bottom left
    and change when the window is resized):

    * left-button drag pans the axes so that the data point that was pressed
      stays under the cursor;
    * right-button drag zooms around the centre of the limits active at the
      press; every ``ZOOM_STRENGTH`` pixels of movement halve or double the
      axis range.
    """

    # Pixels of mouse movement that halve or double the axis range when zooming.
    ZOOM_STRENGTH = 50

    def __init__(self, parent=None, dpi=100):
        """Create a figure with a single axes and connect the mouse handlers.

        Args:
            parent: Qt parent widget.
            dpi: Resolution of the matplotlib figure.
        """
        self.fig = Figure(dpi=dpi)
        self.axes = self.fig.add_subplot(111)
        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)
        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
        # Repro shown by the last plot() call; axis limits are kept while it
        # stays the same.
        self.current_repro = None
        # Drag state: set in onclick(), cleared in release(), read in moved().
        self.mouse_button_pressed = False
        self.mouse_button = -1  # matplotlib button number (1 = left, 3 = right)
        self.mouse_position_start = (-1, -1)  # pixel coordinates of the press
        self.mouse_data_start = (-1, -1)  # data coordinates of the press
        self.start_limits = ((0, 0), (0, 0))  # (xlim, ylim) at the press
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.fig.canvas.mpl_connect('button_release_event', self.release)
        self.fig.canvas.mpl_connect('motion_notify_event', self.moved)
        self.fig.canvas.mpl_connect('scroll_event', self.scrolled)

    def scrolled(self, event):
        """Scroll handler; currently only prints the event (``event.button`` is "up"/"down")."""
        print('Scrolled: ', event)

    def moved(self, event):
        """Pan (left button) or zoom (right button) while a button is held down."""
        # Ignore plain mouse moves and moves outside the axes.
        if not self.mouse_button_pressed or event.xdata is None:
            return
        if self.mouse_button == 1:
            self._pan_to(event.x, event.y)
        elif self.mouse_button == 3:
            self._zoom_to(event.x, event.y)

    def _pan_to(self, x, y):
        """Shift the limits so the data point grabbed in onclick() lies under pixel (x, y)."""
        # The transform already includes earlier pan steps of this drag, so
        # subtracting the offset always re-aligns with the grabbed point.
        mouse_data_new = self.axes.transData.inverted().transform((x, y))
        diff_x = mouse_data_new[0] - self.mouse_data_start[0]
        diff_y = mouse_data_new[1] - self.mouse_data_start[1]
        self.axes.set_xlim(np.subtract(self.axes.get_xlim(), diff_x))
        self.axes.set_ylim(np.subtract(self.axes.get_ylim(), diff_y))
        self.draw()

    def _zoom_to(self, x, y):
        """Scale the press-time limits by the mouse displacement to pixel (x, y)."""
        factor_x = 2**((self.mouse_position_start[0] - x) / self.ZOOM_STRENGTH)
        factor_y = 2**((self.mouse_position_start[1] - y) / self.ZOOM_STRENGTH)
        self.axes.set_xlim(self._scale_limits(self.start_limits[0], factor_x))
        self.axes.set_ylim(self._scale_limits(self.start_limits[1], factor_y))
        self.draw()

    @staticmethod
    def _scale_limits(limits, factor):
        """Return ``limits`` with its length multiplied by ``factor`` and its centre kept."""
        low, high = limits
        length = high - low
        new_length = length * factor
        return (low - 0.5*new_length + 0.5*length, high + 0.5*new_length - 0.5*length)

    def onclick(self, event):
        """Remember the button, position and axis limits at the start of a drag."""
        if event.button in (1, 3):
            self.mouse_button_pressed = True
            self.mouse_button = event.button
            self.mouse_position_start = (event.x, event.y)
            self.mouse_data_start = self.axes.transData.inverted().transform(self.mouse_position_start)
            self.start_limits = (self.axes.get_xlim(), self.axes.get_ylim())

    def release(self, event):
        """End a pan/zoom drag."""
        if event.button in (1, 3):
            self.mouse_button_pressed = False

    @pyqtSlot()
    def plot(self, trial_idx, repro, data_provider: Controller, redetection_vars):
        """Draw trial ``trial_idx`` of ``repro`` with original and re-detected spikes.

        Args:
            trial_idx: Index into the traces/spike times of ``repro``.
            repro: Name of the repro, passed to the data provider.
            data_provider: Supplies ``get_traces_with_spiketimes(repro)`` and
                ``sampling_interval``.
            redetection_vars: ``(threshold, min window size, step size)``
                forwarded to ``detect_spiketimes``.

        The axis limits are kept when the same repro is plotted again and
        autoscaled when the repro changes.
        """
        traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro)
        trace = traces[trial_idx]
        trial_spiketimes = spiketimes[trial_idx]
        sampling_interval = data_provider.sampling_interval
        ax = self.axes
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.clear()
        # Sample i is plotted at i * sampling_interval - recording_times[0].
        time = np.arange(len(trace)) * sampling_interval - recording_times[0]
        trace_max = np.max(trace)
        ax.plot(time, trace)
        ax.eventplot(trial_spiketimes, lineoffsets=trace_max + 2, colors="black")
        redetect = np.asarray(detect_spiketimes(time, trace, redetection_vars[0], redetection_vars[1], redetection_vars[2]))
        ax.eventplot(redetect, lineoffsets=trace_max + 1, colors="red")
        # Invert the time axis to sample indices; rounding can yield len(trace)
        # for a spike in the last half sample, so clip to valid indices.
        redetect_idx = np.round((redetect + recording_times[0]) / sampling_interval).astype(int)
        redetect_idx = np.clip(redetect_idx, 0, len(trace) - 1)
        ax.plot(redetect, trace[redetect_idx], 'o', color='red')
        ax.set_title('Trial')
        if self.current_repro == repro:
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
        else:
            self.current_repro = repro
        self.draw()
class QHLine(QFrame):
    """Horizontal separator: a ``QFrame`` drawn as a sunken horizontal line."""

    def __init__(self):
        super().__init__()
        # Shape and shadow in a single call (equivalent to setFrameShape + setFrameShadow).
        self.setFrameStyle(QFrame.HLine | QFrame.Sunken)
class QVLine(QFrame):
    """Vertical separator: a ``QFrame`` drawn as a sunken vertical line."""

    def __init__(self):
        super().__init__()
        # Shape and shadow in a single call (equivalent to setFrameShape + setFrameShadow).
        self.setFrameStyle(QFrame.VLine | QFrame.Sunken)
class QFill(QFrame):
    """Expanding frame used as a filler to push neighbouring widgets apart.

    Args:
        maxw: Maximum width in pixels (default ``2**24 - 1``).
        maxh: Maximum height in pixels (default ``2**24 - 1``).
        minw: Minimum width in pixels.
        minh: Minimum height in pixels.
    """

    def __init__(self, maxw=(1 << 24) - 1, maxh=(1 << 24) - 1, minw=0, minh=0):
        super().__init__()
        expanding = QSizePolicy.Expanding
        self.setSizePolicy(expanding, expanding)
        self.setMaximumSize(maxw, maxh)
        self.setMinimumSize(minw, minh)