add dragging and zoom with the mouse buttons for the plot
This commit is contained in:
parent
38ba40ce2f
commit
b0dfa22ef4
@ -27,8 +27,9 @@ class DataProvider:
|
|||||||
else:
|
else:
|
||||||
self.repros_todo = self.repros
|
self.repros_todo = self.repros
|
||||||
|
|
||||||
self.base_Detection_parameters = base_detection_params
|
self.base_detection_parameters = base_detection_params
|
||||||
self.thresholds = {}
|
self.thresholds = {}
|
||||||
|
self.sampling_interval = self.parser.get_sampling_interval()
|
||||||
|
|
||||||
self.sorting = {}
|
self.sorting = {}
|
||||||
self.recording_times = {}
|
self.recording_times = {}
|
||||||
@ -46,9 +47,9 @@ class DataProvider:
|
|||||||
if repro in self.sorting.keys():
|
if repro in self.sorting.keys():
|
||||||
traces = self.get_unsorted_traces(repro, self.recording_times[repro][0], self.recording_times[repro][1])
|
traces = self.get_unsorted_traces(repro, self.recording_times[repro][0], self.recording_times[repro][1])
|
||||||
v1_traces = traces[1]
|
v1_traces = traces[1]
|
||||||
spiketimes = self.get_unsorted_spiketimes(repro)
|
spiketimes, metadata = self.get_unsorted_spiketimes(repro)
|
||||||
|
|
||||||
sorted_spiketimes = np.array(spiketimes)[self.sorting[repro]]
|
sorted_spiketimes = np.array(spiketimes, dtype=object)[self.sorting[repro]]
|
||||||
return v1_traces, sorted_spiketimes, self.recording_times[repro]
|
return v1_traces, sorted_spiketimes, self.recording_times[repro]
|
||||||
|
|
||||||
if repro == "FICurve":
|
if repro == "FICurve":
|
||||||
@ -56,13 +57,13 @@ class DataProvider:
|
|||||||
before = abs(recording_times[0])
|
before = abs(recording_times[0])
|
||||||
after = recording_times[3]
|
after = recording_times[3]
|
||||||
[time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.parser.get_traces(repro, before=before, after=after)
|
[time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.parser.get_traces(repro, before=before, after=after)
|
||||||
spiketimes = self.get_unsorted_spiketimes(repro)
|
(spiketimes, metadata) = self.get_unsorted_spiketimes(repro)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
before = 0
|
before = 0
|
||||||
after = 0
|
after = 0
|
||||||
[time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.get_unsorted_traces(repro)
|
[time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.get_unsorted_traces(repro)
|
||||||
spiketimes = self.get_unsorted_spiketimes(repro)
|
(spiketimes, metadata) = self.get_unsorted_spiketimes(repro)
|
||||||
|
|
||||||
if len(v1_traces) != len(spiketimes):
|
if len(v1_traces) != len(spiketimes):
|
||||||
warn("get_traces_with_spiketimes():Unequal number of traces and spiketimes for repro {}!"
|
warn("get_traces_with_spiketimes():Unequal number of traces and spiketimes for repro {}!"
|
||||||
|
@ -9,6 +9,7 @@ import numpy as np
|
|||||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from DataProvider import DataProvider
|
from DataProvider import DataProvider
|
||||||
|
from redetector import detect_spiketimes
|
||||||
|
|
||||||
|
|
||||||
class SpikeRedetectGui(QWidget):
|
class SpikeRedetectGui(QWidget):
|
||||||
@ -21,6 +22,9 @@ class SpikeRedetectGui(QWidget):
|
|||||||
self.top = 10
|
self.top = 10
|
||||||
self.width = 640
|
self.width = 640
|
||||||
self.height = 400
|
self.height = 400
|
||||||
|
|
||||||
|
self.trial_idx = 0
|
||||||
|
|
||||||
self.initUI()
|
self.initUI()
|
||||||
|
|
||||||
def initUI(self):
|
def initUI(self):
|
||||||
@ -34,9 +38,9 @@ class SpikeRedetectGui(QWidget):
|
|||||||
|
|
||||||
plot_area = QFrame()
|
plot_area = QFrame()
|
||||||
plot_area_layout = QVBoxLayout()
|
plot_area_layout = QVBoxLayout()
|
||||||
m = PlotCanvas(self)
|
self.canvas = PlotCanvas(self)
|
||||||
m.move(0, 0)
|
self.canvas.move(0, 0)
|
||||||
plot_area_layout.addWidget(m)
|
plot_area_layout.addWidget(self.canvas)
|
||||||
|
|
||||||
# plot area buttons
|
# plot area buttons
|
||||||
plot_area_buttons = QFrame()
|
plot_area_buttons = QFrame()
|
||||||
@ -46,25 +50,25 @@ class SpikeRedetectGui(QWidget):
|
|||||||
|
|
||||||
button = QPushButton('Button1', self)
|
button = QPushButton('Button1', self)
|
||||||
button.setToolTip('A nice button!')
|
button.setToolTip('A nice button!')
|
||||||
button.clicked.connect(lambda: threshold_spinbox.setValue(1))
|
button.clicked.connect(lambda: self.threshold_spinbox.setValue(1))
|
||||||
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||||
plot_area_buttons_layout.addWidget(button)
|
plot_area_buttons_layout.addWidget(button)
|
||||||
|
|
||||||
button = QPushButton('Button2', self)
|
button = QPushButton('Button2', self)
|
||||||
button.setToolTip('Another nice button!')
|
button.setToolTip('Another nice button!')
|
||||||
button.clicked.connect(lambda: threshold_spinbox.setValue(2))
|
button.clicked.connect(lambda: self.threshold_spinbox.setValue(2))
|
||||||
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||||
plot_area_buttons_layout.addWidget(button)
|
plot_area_buttons_layout.addWidget(button)
|
||||||
|
|
||||||
button = QPushButton('Button3', self)
|
button = QPushButton('Button3', self)
|
||||||
button.setToolTip('Even more nice buttons!')
|
button.setToolTip('Even more nice buttons!')
|
||||||
button.clicked.connect(lambda: threshold_spinbox.setValue(3))
|
button.clicked.connect(lambda: self.threshold_spinbox.setValue(3))
|
||||||
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||||
plot_area_buttons_layout.addWidget(button)
|
plot_area_buttons_layout.addWidget(button)
|
||||||
|
|
||||||
button = QPushButton('Button4', self)
|
button = QPushButton('Button4', self)
|
||||||
button.setToolTip('Even more nice buttons!')
|
button.setToolTip('Even more nice buttons!')
|
||||||
button.clicked.connect(lambda: threshold_spinbox.setValue(4))
|
button.clicked.connect(lambda: self.threshold_spinbox.setValue(4))
|
||||||
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||||
plot_area_buttons_layout.addWidget(button)
|
plot_area_buttons_layout.addWidget(button)
|
||||||
|
|
||||||
@ -87,21 +91,21 @@ class SpikeRedetectGui(QWidget):
|
|||||||
repro_label = QLabel("Repro:")
|
repro_label = QLabel("Repro:")
|
||||||
panel_layout.addWidget(repro_label)
|
panel_layout.addWidget(repro_label)
|
||||||
self.repro_box = QComboBox()
|
self.repro_box = QComboBox()
|
||||||
self.repro_box.currentTextChanged.connect(self.repro_change)
|
|
||||||
for repro in self.data_provider.get_repros():
|
for repro in self.data_provider.get_repros():
|
||||||
self.repro_box.addItem(repro)
|
self.repro_box.addItem(repro)
|
||||||
|
self.repro_box.currentTextChanged.connect(self.repro_change)
|
||||||
panel_layout.addWidget(self.repro_box)
|
panel_layout.addWidget(self.repro_box)
|
||||||
panel_layout.addWidget(stim_val_label)
|
panel_layout.addWidget(stim_val_label)
|
||||||
panel_layout.addWidget(self.stim_val_box)
|
panel_layout.addWidget(self.stim_val_box)
|
||||||
|
|
||||||
trial_label = QLabel("Trial:")
|
trial_label = QLabel("Trial:")
|
||||||
panel_layout.addWidget(trial_label)
|
panel_layout.addWidget(trial_label)
|
||||||
threshold_spinbox = QSpinBox(self)
|
trial_spinbox = QSpinBox(self)
|
||||||
threshold_spinbox.setValue(1)
|
trial_spinbox.setValue(1)
|
||||||
threshold_spinbox.setSingleStep(1)
|
trial_spinbox.setSingleStep(1)
|
||||||
threshold_spinbox.valueChanged.connect()
|
trial_spinbox.valueChanged.connect(lambda: self.trial_change(trial_spinbox.value()))
|
||||||
panel_layout.addWidget(threshold_spinbox)
|
panel_layout.addWidget(trial_spinbox)
|
||||||
|
|
||||||
filler = QFill(minh=200)
|
filler = QFill(minh=200)
|
||||||
panel_layout.addWidget(filler)
|
panel_layout.addWidget(filler)
|
||||||
@ -113,11 +117,29 @@ class SpikeRedetectGui(QWidget):
|
|||||||
|
|
||||||
threshold_label = QLabel("Threshold:")
|
threshold_label = QLabel("Threshold:")
|
||||||
panel_layout.addWidget(threshold_label)
|
panel_layout.addWidget(threshold_label)
|
||||||
threshold_spinbox = QDoubleSpinBox(self)
|
self.threshold_spinbox = QDoubleSpinBox(self)
|
||||||
threshold_spinbox.setValue(1)
|
self.threshold_spinbox.setValue(self.data_provider.base_detection_parameters[0])
|
||||||
threshold_spinbox.setSingleStep(0.5)
|
self.threshold_spinbox.setSingleStep(0.5)
|
||||||
threshold_spinbox.valueChanged.connect(lambda: m.plot(threshold_spinbox.value()))
|
self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed())
|
||||||
panel_layout.addWidget(threshold_spinbox)
|
panel_layout.addWidget(self.threshold_spinbox)
|
||||||
|
|
||||||
|
window_label = QLabel("Min window size:")
|
||||||
|
panel_layout.addWidget(window_label)
|
||||||
|
self.window_spinbox = QSpinBox(self)
|
||||||
|
self.window_spinbox.setMaximum(2**21)
|
||||||
|
self.window_spinbox.setValue(self.data_provider.base_detection_parameters[1])
|
||||||
|
self.window_spinbox.setSingleStep(500)
|
||||||
|
self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed())
|
||||||
|
panel_layout.addWidget(self.window_spinbox)
|
||||||
|
|
||||||
|
step_label = QLabel("step size:")
|
||||||
|
panel_layout.addWidget(step_label)
|
||||||
|
self.step_spinbox = QSpinBox(self)
|
||||||
|
self.step_spinbox.setMaximum(2 ** 21)
|
||||||
|
self.step_spinbox.setValue(self.data_provider.base_detection_parameters[2])
|
||||||
|
self.step_spinbox.setSingleStep(200)
|
||||||
|
self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed())
|
||||||
|
panel_layout.addWidget(self.step_spinbox)
|
||||||
|
|
||||||
button = QPushButton('Accept!', self)
|
button = QPushButton('Accept!', self)
|
||||||
button.setToolTip('Accept the threshold for current stimulus value')
|
button.setToolTip('Accept the threshold for current stimulus value')
|
||||||
@ -129,6 +151,18 @@ class SpikeRedetectGui(QWidget):
|
|||||||
self.setLayout(middle)
|
self.setLayout(middle)
|
||||||
self.show()
|
self.show()
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def redetection_changed(self):
|
||||||
|
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
|
||||||
|
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection)
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def trial_change(self, new_trial_idx):
|
||||||
|
# TODO test if in range of trials!
|
||||||
|
self.trial_idx = new_trial_idx
|
||||||
|
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
|
||||||
|
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection)
|
||||||
|
|
||||||
@pyqtSlot()
|
@pyqtSlot()
|
||||||
def repro_change(self):
|
def repro_change(self):
|
||||||
repro = self.repro_box.currentText()
|
repro = self.repro_box.currentText()
|
||||||
@ -137,30 +171,126 @@ class SpikeRedetectGui(QWidget):
|
|||||||
for val in self.data_provider.get_stim_values(repro):
|
for val in self.data_provider.get_stim_values(repro):
|
||||||
self.stim_val_box.addItem(str(val))
|
self.stim_val_box.addItem(str(val))
|
||||||
|
|
||||||
|
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
|
||||||
|
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection)
|
||||||
|
|
||||||
|
|
||||||
class PlotCanvas(FigureCanvas):
|
class PlotCanvas(FigureCanvas):
|
||||||
|
|
||||||
def __init__(self, parent=None, dpi=100):
|
def __init__(self, parent=None, dpi=100):
|
||||||
fig = Figure(dpi=dpi)
|
self.fig = Figure(dpi=dpi)
|
||||||
self.axes = fig.add_subplot(111)
|
self.axes = self.fig.add_subplot(111)
|
||||||
|
|
||||||
FigureCanvas.__init__(self, fig)
|
FigureCanvas.__init__(self, self.fig)
|
||||||
self.setParent(parent)
|
self.setParent(parent)
|
||||||
|
|
||||||
FigureCanvas.setSizePolicy(self,
|
FigureCanvas.setSizePolicy(self,
|
||||||
QSizePolicy.Expanding,
|
QSizePolicy.Expanding,
|
||||||
QSizePolicy.Expanding)
|
QSizePolicy.Expanding)
|
||||||
FigureCanvas.updateGeometry(self)
|
FigureCanvas.updateGeometry(self)
|
||||||
self.plot()
|
|
||||||
|
self.mouse_button_pressed = False
|
||||||
|
self.mouse_button = "-1"
|
||||||
|
|
||||||
|
self.mouse_position_start = (-1, -1)
|
||||||
|
self.start_limits = ((0, 0), (0, 0))
|
||||||
|
|
||||||
|
self.fig.canvas.mpl_connect('button_press_event', self.onclick)
|
||||||
|
self.fig.canvas.mpl_connect('button_release_event', self.release)
|
||||||
|
self.fig.canvas.mpl_connect('motion_notify_event', self.moved)
|
||||||
|
self.fig.canvas.mpl_connect('scroll_event', self.scrolled)
|
||||||
|
|
||||||
|
# XY position 0,0 at the bottom left!
|
||||||
|
# XY positions change when rescaling the window size
|
||||||
|
# zoom depend on percentage of total x/y length
|
||||||
|
|
||||||
|
def scrolled(self, event):
|
||||||
|
# event.button = "up" / "down"
|
||||||
|
print('Scrolled: ', event)
|
||||||
|
|
||||||
|
def moved(self, event):
|
||||||
|
if not self.mouse_button_pressed or event.xdata is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.mouse_button == 1:
|
||||||
|
new_x = event.x
|
||||||
|
new_y = event.y
|
||||||
|
|
||||||
|
# mouse_data_start = self.axes.transData.inverted().transform((self.mouse_position_start))
|
||||||
|
mouse_data_new = self.axes.transData.inverted().transform((new_x, new_y))
|
||||||
|
|
||||||
|
diff_x = mouse_data_new[0] - self.mouse_data_start[0]
|
||||||
|
diff_y = mouse_data_new[1] - self.mouse_data_start[1]
|
||||||
|
|
||||||
|
self.axes.set_xlim(self.axes.get_xlim() - diff_x)
|
||||||
|
self.axes.set_ylim(self.axes.get_ylim() - diff_y)
|
||||||
|
|
||||||
|
self.draw()
|
||||||
|
|
||||||
|
elif self.mouse_button == 3:
|
||||||
|
|
||||||
|
zoom_strength = 50 # pixels to half or double the axis limits
|
||||||
|
length_x = self.start_limits[0][1] - self.start_limits[0][0]
|
||||||
|
length_y = self.start_limits[1][1] - self.start_limits[1][0]
|
||||||
|
|
||||||
|
diff_x = self.mouse_position_start[0] - event.x
|
||||||
|
diff_y = self.mouse_position_start[1] - event.y
|
||||||
|
|
||||||
|
factor_x = 2**(diff_x / zoom_strength)
|
||||||
|
factor_y = 2**(diff_y / zoom_strength)
|
||||||
|
|
||||||
|
new_length_x = length_x * factor_x
|
||||||
|
new_length_y = length_y * factor_y
|
||||||
|
|
||||||
|
new_xlimits = (self.start_limits[0][0] - 0.5*new_length_x + 0.5*length_x, self.start_limits[0][1] + 0.5*new_length_x - 0.5*length_x)
|
||||||
|
new_ylimits = (self.start_limits[1][0] - 0.5*new_length_y + 0.5*length_y, self.start_limits[1][1] + 0.5*new_length_y - 0.5*length_y)
|
||||||
|
|
||||||
|
self.axes.set_xlim(new_xlimits)
|
||||||
|
self.axes.set_ylim(new_ylimits)
|
||||||
|
self.draw()
|
||||||
|
|
||||||
|
|
||||||
|
def onclick(self, event):
|
||||||
|
|
||||||
|
if event.button in (1, 3):
|
||||||
|
self.mouse_button_pressed = True
|
||||||
|
self.mouse_button = event.button
|
||||||
|
self.mouse_position_start = (event.x, event.y)
|
||||||
|
|
||||||
|
self.mouse_data_start = self.axes.transData.inverted().transform((self.mouse_position_start))
|
||||||
|
# print("Figure:", self.mouse_position_start, "Data:", self.mouse_data_start)
|
||||||
|
xlim = self.axes.get_xlim()
|
||||||
|
ylim = self.axes.get_ylim()
|
||||||
|
self.start_limits = (xlim, ylim)
|
||||||
|
|
||||||
|
# print('Cliched: %s click: button=%d, x=%d, y=%d, xdata=%s, ydata=%s' %
|
||||||
|
# ('double' if event.dblclick else 'single', event.button,
|
||||||
|
# event.x, event.y, event.xdata, event.ydata))
|
||||||
|
|
||||||
|
def release(self, event):
|
||||||
|
if event.button in (1, 3):
|
||||||
|
self.mouse_button_pressed = False
|
||||||
|
|
||||||
@pyqtSlot()
|
@pyqtSlot()
|
||||||
def plot(self, mean=1):
|
def plot(self, trial_idx, repro, data_provider: DataProvider, redetection_vars):
|
||||||
x = np.arange(0, 1, 0.0001)
|
traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro)
|
||||||
data = np.sin(x*np.pi*2*mean)
|
trace = traces[trial_idx]
|
||||||
|
spiketimes = spiketimes[trial_idx]
|
||||||
|
recording_times = recording_times
|
||||||
|
sampling_interval = data_provider.sampling_interval
|
||||||
|
|
||||||
ax = self.axes
|
ax = self.axes
|
||||||
|
xlim = self.axes.get_xlim()
|
||||||
|
ylim = self.axes.get_ylim()
|
||||||
ax.clear()
|
ax.clear()
|
||||||
ax.plot(x, data, 'r-')
|
time = np.arange(len(trace)) * sampling_interval - recording_times[0]
|
||||||
ax.set_title('Sinus Example')
|
ax.plot(time, trace)
|
||||||
|
ax.eventplot(spiketimes, lineoffsets=max(trace) + 1, colors="black")
|
||||||
|
redetect = detect_spiketimes(time, trace, redetection_vars[0], redetection_vars[1], redetection_vars[2])
|
||||||
|
ax.eventplot(redetect, lineoffsets=max(trace) + 2, colors="red")
|
||||||
|
ax.set_title('Trial XYZ')
|
||||||
|
ax.set_xlim(xlim)
|
||||||
|
ax.set_ylim(ylim)
|
||||||
self.draw()
|
self.draw()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user