diff --git a/nixview/ui/plotscreen.py b/nixview/ui/plotscreen.py index f661a5c..0e06dd7 100644 --- a/nixview/ui/plotscreen.py +++ b/nixview/ui/plotscreen.py @@ -3,7 +3,7 @@ from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePo from PyQt5.QtCore import pyqtSignal, Qt import matplotlib matplotlib.use('Qt5Agg') -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas, NavigationToolbar2QT as NavigationToolbar from matplotlib.figure import Figure import nixio as nix import numpy as np @@ -27,7 +27,7 @@ def create_label(item): return label -class MplCanvas(FigureCanvasQTAgg): +class MplCanvas(FigureCanvas): view_changed = pyqtSignal() def __init__(self, parent=None, width=5, height=4, dpi=100): @@ -35,12 +35,45 @@ class MplCanvas(FigureCanvasQTAgg): self.axis = figure.add_subplot(111) super(MplCanvas, self).__init__(figure) self._figure = figure + self._figure.canvas.mpl_connect('figure_enter_event', self.on_enter_figure) + self._figure.canvas.mpl_connect('figure_leave_event', self.on_leave_figure) + self._figure.canvas.mpl_connect('axes_enter_event', self.on_enter_axes) + self._figure.canvas.mpl_connect('axes_leave_event', self.on_leave_axes) + self._figure.canvas.mpl_connect('pick_event', self.on_pick) + + + def on_enter_figure(self, event): + print('enter_figure', event.canvas.figure) + # event.canvas.figure.patch.set_facecolor('red') + # event.canvas.draw() + + def on_leave_figure(self, event): + print('leave_figure', event.canvas.figure) + # event.canvas.figure.patch.set_facecolor('grey') + # event.canvas.draw() + + def on_enter_axes(self, event): + print('enter_axes', event.inaxes) + # event.inaxes.patch.set_facecolor('yellow') + # event.canvas.draw() + + def on_leave_axes(self, event): + print('leave_axes', event.inaxes) + # event.inaxes.patch.set_facecolor('white') + # event.canvas.draw() + + def on_pick(self, event): + line = event.artist + print(line.label) + xdata, ydata = line.get_data() + ind = event.ind + print('on pick line:', np.array([xdata[ind], ydata[ind]]).T) #def clear(self): # self.clear() #@property - #def figure(self): + #def enter_figure(self): # return self._fig @@ -51,10 +84,12 @@ class Plotter(MplCanvas): self._item = item self._dataview = data_view - def current_view(self): raise NotImplementedError("current_view is not implemented on the current plotter") - + + def is_full_view(self): + raise NotImplementedError("is_full_view is not implemented on the current plotter") + class EventPlotter(Plotter): @@ -63,6 +98,7 @@ class EventPlotter(Plotter): self.array = data_array self.sc = None self.dim_count = len(data_array.dimensions) + """ if xdim == -1: self.xdim = guess_best_xdim(self.array) elif xdim > 1: @@ -70,6 +106,7 @@ class EventPlotter(Plotter): "Cannot plot that kind of data") else: self.xdim = xdim + """ def plot(self, axis=None): if axis is None: @@ -106,6 +143,7 @@ class CategoryPlotter(Plotter): def __init__(self, data_array, xdim=-1): self.array = data_array self.bars = [] + """ if xdim == -1: self.xdim = guess_best_xdim(self.array) elif xdim > 2: @@ -113,6 +151,7 @@ class CategoryPlotter(Plotter): "Cannot plot that kind of data") else: self.xdim = xdim + """ def plot(self, axis=None): if axis is None: @@ -132,6 +171,7 @@ class CategoryPlotter(Plotter): def plot_1d(self): data = self.array[:] dim = self.array.dimensions[self.xdim] + categories = None if dim.dimension_type == nix.DimensionType.Set: categories = list(dim.labels) else: @@ -150,7 +190,7 @@ class CategoryPlotter(Plotter): data = self.array[:] if self.xdim == 1: data = data.T - + categories = None dim = self.array.dimensions[self.xdim] if dim.dimension_type == nix.DimensionType.Set: categories = list(dim.labels) @@ -158,6 +198,7 @@ class CategoryPlotter(Plotter): categories = ["Cat-%i" % i for i in range(data.shape[self.xdim])] dim = self.array.dimensions[1-self.xdim] + series_names = [] if dim.dimension_type == nix.DimensionType.Set: series_names = list(dim.labels) if len(series_names) == 0: @@ -223,8 +264,14 @@ class LinePlotter(Plotter): def __init__(self, file_handler, item, data_view, xdim=-1, parent=None): super().__init__(file_handler, item, data_view, parent) - + #self.canvas = FigureCanvas(self.figure) + #self.toolbar = NavigationToolbar(self.canvas, self) + self.dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id) + self._min_x = None + self._max_x = None + self._min_x = None + self.min_x = None self.lines = [] self.dim_count = len(self._dataview.full_shape) if xdim == -1: @@ -234,7 +281,33 @@ class LinePlotter(Plotter): "Cannot plot that kind of data") else: self.xdim = xdim + self.axis.callbacks.connect('xlim_changed', self.on_xlims_change) + self.axis.callbacks.connect('ylim_changed', self.on_ylims_change) + + def on_xlims_change(self, event_ax): + print("updated xlims: ", event_ax.get_xlim()) + def on_ylims_change(self, event_ax): + + print("updated ylims: ", event_ax.get_ylim()) + + def current_view(self): + cv = [] + return cv + + @property + def is_full_view(self): + xlims = self.axis.get_xlim() + full = self._min_x == xlims[0] and self._max_x == xlims[-1] + full = full and self._dataview.fully_loaded + return full + + def on_zoom_in(self, new_position): + print("plotter ZOOM In!", new_position) + + def on_zoom_out(self, new_position): + print("plotter ZOOM out!", new_position) + def plot(self, maxpoints=100000): self.maxpoints = maxpoints if self.dim_count > 2: @@ -244,27 +317,18 @@ class LinePlotter(Plotter): else: self.plot_array_2d() - def __add_slider(self): - steps = self.array.shape[self.xdim] / self.maxpoints - slider_ax = self.fig.add_axes([0.15, 0.025, 0.8, 0.025]) - self.slider = Slider(slider_ax, 'Slider', 1., steps, valinit=1., - valstep=0.25) - self.slider.on_changed(self.__update) - - def __update(self, val): - if len(self.lines) > 0: - minimum = val * self.maxpoints - self.maxpoints - start = minimum if minimum > 0 else 0 - end = val * self.maxpoints - self.__draw(start, end) - self.fig.canvas.draw_idle() - def __draw(self, start, end): if self.dim_count == 1: self.__draw_1d(start, end) else: self.__draw_2d(start, end) + def _set_xlims(self, data_xmin, data_xmax): + if self._min_x is None or data_xmin < self._min_x: + self._min_x = data_xmin + if self._max_x is None or data_xmax > self._max_x: + self._max_x = data_xmax + def __draw_1d(self, start, end): if start < 0: start = 0 @@ -273,9 +337,9 @@ class LinePlotter(Plotter): y = self._dataview._buffer[int(start):int(end)] x = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, len(y), start) - + self._set_xlims(x[0], x[-1]) if len(self.lines) == 0: - l, = self.axis.plot(x, y, label=self._item.name) + l, = self.axis.plot(x, y, label=self._item.name, picker=5) self.lines.append(l) else: self.lines[0].set_ydata(y) @@ -292,6 +356,7 @@ class LinePlotter(Plotter): x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, int(end-start), start) line_count = self._dataview.current_shape[1 - self.xdim] line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1-self.xdim, line_count, 0) + self._set_xlims(x[0], x[-1]) for i, l in enumerate(line_labels): if (self.xdim == 0): @@ -300,7 +365,7 @@ class LinePlotter(Plotter): y = self._dataview._buffer[i, int(start):int(end)] if len(self.lines) <= i: - ll, = self.axis.plot(x, y, label=l) + ll, = self.axis.plot(x, y, label=l, picker=5) self.lines.append(ll) else: self.lines[i].set_ydata(y) @@ -356,6 +421,8 @@ class PlotScreen(QWidget): self._create_plot_controls() self.layout().addWidget(close_btn) self._data_view = None + + self.zoom_position = 0 def _create_plot_controls(self): plot_controls = QGroupBox() @@ -401,14 +468,23 @@ class PlotScreen(QWidget): self.close_signal.emit() def on_zoom(self, new_position): - print("zoom", new_position) + if self.zoom_position < new_position: + self.plotter.on_zoom_out(new_position) + else: + self.plotter.on_zoom_out(new_position) + self.zoom_position = new_position def on_pan(self, new_position): + print("pan", new_position) def on_view_changed(self): print("view changed!") print(self.plotter.current_view()) + if self.plotter.is_full_view: + self._zoom_slider.setSliderPosition(100) + self._pan_slider.setEnabled(not self.plotter.is_full_view) + def plot(self, item): try: @@ -418,11 +494,13 @@ class PlotScreen(QWidget): return if self._data_view is None: return - while not self._data_view.fully_loaded: - self._data_view.request_more() # TODO this is just a test, needs to be removed + #while not self._data_view.fully_loaded: + # self._data_view.request_more() # TODO this is just a test, needs to be removed if item.suggested_plotter == PlotterTypes.LinePlotter: self.plotter = LinePlotter(self._file_handler, item, self._data_view) self.plotter.view_changed.connect(self.on_view_changed) self._container.set_plotter(self.plotter) - self.plotter.plot() + self.plotter.plot(maxpoints=10000) + self._zoom_slider.setSliderPosition(100) + self.zoom_position = 100