from nixview.util.enums import PlotterTypes
from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
from PyQt5.QtCore import pyqtSignal, Qt

import matplotlib
try:
    matplotlib.use('Qt5Agg')
except Exception:
    # Fall back to a non-interactive backend when Qt5Agg cannot be loaded.
    matplotlib.use("Agg")
    print("Cannot load Qt5Agg backend")
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas, NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure

import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from IPython import embed

from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView
from nixview.communicator import communicator


def create_label(item):
    """Build an axis label for *item*.

    Uses ``item.label`` if the attribute exists and is not None; if that gives
    an empty string, falls back to ``item.name``. When ``item.unit`` exists and
    is not None, `` [unit]`` is appended.
    """
    label = ""
    if hasattr(item, "label"):
        label += (item.label if item.label is not None else "")
    if len(label) == 0 and hasattr(item, "name"):
        label += item.name
    if hasattr(item, "unit") and item.unit is not None:
        label += " [%s]" % item.unit
    return label


class MplCanvas(FigureCanvas):
    """Qt widget wrapping a matplotlib Figure with a single subplot (``self.axis``).

    Subclasses emit ``view_changed`` after they have (re)drawn their data.
    Mouse enter/leave and pick events are currently only printed (debug output).
    """
    view_changed = pyqtSignal()

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axis = figure.add_subplot(111)
        super(MplCanvas, self).__init__(figure)
        if parent is not None:
            # FigureCanvasQTAgg.__init__ only accepts the figure, so the Qt
            # parent must be applied explicitly (it used to be dropped).
            self.setParent(parent)
        self._figure = figure
        self._figure.canvas.mpl_connect('figure_enter_event', self.on_enter_figure)
        self._figure.canvas.mpl_connect('figure_leave_event', self.on_leave_figure)
        self._figure.canvas.mpl_connect('axes_enter_event', self.on_enter_axes)
        self._figure.canvas.mpl_connect('axes_leave_event', self.on_leave_axes)
        self._figure.canvas.mpl_connect('pick_event', self.on_pick)

    def on_enter_figure(self, event):
        print('enter_figure', event.canvas.figure)

    def on_leave_figure(self, event):
        print('leave_figure', event.canvas.figure)

    def on_enter_axes(self, event):
        print('enter_axes', event.inaxes)

    def on_leave_axes(self, event):
        print('leave_axes', event.inaxes)

    def on_pick(self, event):
        """Print the label and the picked (x, y) points of a picked line."""
        line = event.artist
        # Line2D exposes its label via get_label(); it has no `.label` attribute.
        print(line.get_label())
        xdata, ydata = line.get_data()
        ind = event.ind
        # The data may have been set as plain sequences; convert so that
        # fancy indexing with the index array works.
        xdata = np.asarray(xdata)
        ydata = np.asarray(ydata)
        print('on pick line:', np.array([xdata[ind], ydata[ind]]).T)


class Plotter(MplCanvas):
    """Base class for canvases that display one item obtained via a FileHandler."""

    def __init__(self, file_handler, item, data_view, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = file_handler
        self._item = item
        self._dataview = data_view

    def current_view(self):
        raise NotImplementedError("current_view is not implemented on the current plotter")

    @property
    def is_full_view(self):
        # A property, to match LinePlotter and how PlotScreen reads it.
        raise NotImplementedError("is_full_view is not implemented on the current plotter")


class EventPlotter(Plotter):
    """Scatter plot of 1-D event data into a pyplot figure.

    NOTE(review): does not call Plotter.__init__, so the Qt/canvas state of the
    base classes is never initialised; only plot()/plot_1d() are usable.
    """

    def __init__(self, data_array, xdim=-1):
        self.array = data_array
        self.sc = None
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            # No automatic guess is available here; event data is plotted
            # along its first dimension by default.
            self.xdim = 0
        elif xdim > 1:
            raise ValueError("EventPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot into *axis*, or into a new pyplot figure when *axis* is None."""
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        else:
            return None

    def plot_1d(self):
        data = self.array[:]
        xlabel = create_label(self.array.dimensions[self.xdim])
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""
        # All events are drawn on the horizontal line y == 1.
        self.sc = self.axis.scatter(data, np.ones(data.shape))
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis


class CategoryPlotter(Plotter):
    """Bar chart of 1-D or 2-D categorical data into a pyplot figure.

    NOTE(review): does not call Plotter.__init__ (see EventPlotter).
    """

    def __init__(self, data_array, xdim=-1):
        self.array = data_array
        self.bars = []
        if xdim == -1:
            # No automatic guess is available here; categories are taken
            # from the first dimension by default.
            self.xdim = 0
        elif xdim > 2:
            raise ValueError("CategoryPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot into *axis*, or into a new pyplot figure when *axis* is None."""
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        elif len(self.array.dimensions) == 2:
            return self.plot_2d()
        else:
            return None

    def plot_1d(self):
        """Draw one bar per label of a Set dimension; returns None otherwise.

        Raises:
            ValueError: if the Set dimension has no labels.
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None
        categories = list(dim.labels)
        ylabel = create_label(self.array)
        if len(categories) == 0:
            raise ValueError("Cannot plot a bar chart without any labels")
        self.bars.append(self.axis.bar(range(1, len(categories)+1), data, tick_label=categories))
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_2d(self):
        """Draw grouped bars: one group per category, one bar per series."""
        data = self.array[:]
        if self.xdim == 1:
            # Bring the category dimension to axis 0.
            data = data.T
        category_count, series_count = data.shape[0], data.shape[1]

        categories = []
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            categories = list(dim.labels)
        if len(categories) == 0:
            categories = ["Cat-%i" % i for i in range(category_count)]

        series_names = []
        dim = self.array.dimensions[1-self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            series_names = list(dim.labels)
        if len(series_names) == 0:
            series_names = ["Series-%i" % i for i in range(series_count)]

        bar_width = 1/series_count * 0.75
        for i in range(series_count):
            x_values = np.arange(category_count) + i * bar_width
            # Keep the first rectangle of each series as its legend handle.
            self.bars.append(self.axis.bar(x_values, data[:, i], width=bar_width, align="center")[0])
        # Bars are centred at x + i * bar_width (i = 0 .. n-1), so the centre
        # of a group lies (n - 1) * bar_width / 2 to the right of x.
        self.axis.set_xticks(np.arange(category_count) + (series_count - 1) * bar_width/2)
        self.axis.set_xticklabels(categories)
        self.axis.legend(self.bars, series_names, loc=1)
        return self.axis


class ImagePlotter(Plotter):
    """imshow plot of 2-D data (or 3-D data with at most 3 channels).

    NOTE(review): does not call Plotter.__init__ (see EventPlotter).
    """

    def __init__(self, data_array, xdim=-1):
        self.array = data_array
        self.image = None

    def plot(self, axis=None):
        """Plot into *axis*, or into a new pyplot figure when *axis* is None."""
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if dim_count == 2:
            return self.plot_2d()
        elif dim_count == 3:
            return self.plot_3d()
        else:
            return None

    def plot_2d(self):
        data = self.array[:]
        x = self.array.dimensions[0].axis(data.shape[0])
        y = self.array.dimensions[1].axis(data.shape[1])
        xlabel = create_label(self.array.dimensions[0])
        ylabel = create_label(self.array.dimensions[1])
        # NOTE(review): imshow draws array axis 0 vertically, but the extent
        # maps dimension 0 to the horizontal axis — confirm intended orientation.
        self.image = self.axis.imshow(data, extent=[x[0], x[-1], y[0], y[-1]])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_3d(self):
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()


class LinePlotter(Plotter):
    """Line plot of 1-D or 2-D data read from the DataView buffer."""

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
        # Smallest / largest x value drawn so far (updated by _set_xlims).
        self._min_x = None
        self._max_x = None
        self.lines = []
        self.dim_count = len(self._dataview.full_shape)
        if xdim == -1:
            self.xdim = item.best_xdim
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim
        self.axis.callbacks.connect('xlim_changed', self.on_xlims_change)
        self.axis.callbacks.connect('ylim_changed', self.on_ylims_change)

    def on_xlims_change(self, event_ax):
        print("updated xlims: ", event_ax.get_xlim())

    def on_ylims_change(self, event_ax):
        print("updated ylims: ", event_ax.get_ylim())

    def current_view(self):
        cv = []
        return cv

    @property
    def is_full_view(self):
        """True when the x limits equal the drawn data extent and all data is loaded."""
        xlims = self.axis.get_xlim()
        full = self._min_x == xlims[0] and self._max_x == xlims[-1]
        full = full and self._dataview.fully_loaded
        return full

    def on_zoom_in(self, new_position):
        print("plotter ZOOM In!", new_position)

    def on_zoom_out(self, new_position):
        print("plotter ZOOM out!", new_position)

    def plot(self, maxpoints=100000):
        """Draw up to *maxpoints* samples; data with more than 2 dims is ignored."""
        self.maxpoints = maxpoints
        if self.dim_count > 2:
            return
        if self.dim_count == 1:
            self.plot_array_1d()
        else:
            self.plot_array_2d()

    def __draw(self, start, end):
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    def _set_xlims(self, data_xmin, data_xmax):
        """Widen the recorded data extent to include [data_xmin, data_xmax]."""
        if self._min_x is None or data_xmin < self._min_x:
            self._min_x = data_xmin
        if self._max_x is None or data_xmax > self._max_x:
            self._max_x = data_xmax

    def __draw_1d(self, start, end):
        # Clamp the requested window to the currently loaded data.
        if start < 0:
            start = 0
        if end > self._dataview.current_shape[self.xdim]:
            end = self._dataview.current_shape[self.xdim]
        y = self._dataview._buffer[int(start):int(end)]
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, len(y), start)
        if len(x) == 0:
            # Nothing loaded in this window; x[0] / x[-1] below would fail.
            return
        self._set_xlims(x[0], x[-1])
        if len(self.lines) == 0:
            l, = self.axis.plot(x, y, label=self._item.name, picker=5)
            self.lines.append(l)
        else:
            self.lines[0].set_ydata(y)
            self.lines[0].set_xdata(x)
        self.axis.set_xlim([x[0], x[-1]])

    def __draw_2d(self, start, end):
        # Clamp the requested window to the currently loaded data.
        if start < 0:
            start = 0
        if end > self._dataview.current_shape[self.xdim]:
            end = self._dataview.current_shape[self.xdim]
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, int(end-start), start)
        if len(x) == 0:
            # Nothing loaded in this window; x[0] / x[-1] below would fail.
            return
        line_count = self._dataview.current_shape[1 - self.xdim]
        line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1-self.xdim, line_count, 0)
        self._set_xlims(x[0], x[-1])
        for i, l in enumerate(line_labels):
            if (self.xdim == 0):
                y = self._dataview._buffer[int(start):int(end), i]
            else:
                y = self._dataview._buffer[i, int(start):int(end)]
            if len(self.lines) <= i:
                ll, = self.axis.plot(x, y, label=l, picker=5)
                self.lines.append(ll)
            else:
                self.lines[i].set_ydata(y)
                self.lines[i].set_xdata(x)
        self.axis.set_xlim([x[0], x[-1]])

    def plot_array_1d(self):
        self.__draw_1d(0, self.maxpoints)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.view_changed.emit()

    def plot_array_2d(self):
        self.__draw_2d(0, self.maxpoints)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.axis.legend(loc=1)
        self.view_changed.emit()


class PlotContainer(QWidget):
    """Widget that shows exactly one plotter at a time."""

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self.setLayout(QVBoxLayout())
        self.plotter = None

    def set_plotter(self, plotter):
        """Replace the currently shown plotter with *plotter*."""
        if self.plotter is not None and self.plotter is not plotter:
            self.layout().removeWidget(self.plotter)
            # removeWidget() only detaches the widget from the layout; it would
            # otherwise remain a visible child of this container and leak.
            self.plotter.deleteLater()
        self.layout().addWidget(plotter)
        self.plotter = plotter


class PlotScreen(QWidget):
    """Plot view with a plot container, pan/zoom sliders and a close button."""
    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = FileHandler()
        # Set before the slider signals are connected: the handlers read it.
        self.plotter = None

        self.setLayout(QVBoxLayout())
        self._container = PlotContainer(self)
        self.layout().addWidget(self._container)

        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)

        self._create_plot_controls()
        self.layout().addWidget(close_btn)
        self._data_view = None
        self.zoom_position = 0

    def _create_plot_controls(self):
        plot_controls = QGroupBox()
        plot_controls.setFlat(True)
        plot_controls.setLayout(QHBoxLayout())
        plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        self._zoom_slider = QSlider(Qt.Horizontal)
        self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self._zoom_slider.setFixedWidth(120)
        self._zoom_slider.setFixedHeight(20)
        self._zoom_slider.setTickPosition(QSlider.TicksBelow)
        self._zoom_slider.setSliderPosition(50)
        self._zoom_slider.setMinimum(0)
        self._zoom_slider.setMaximum(100)
        self._zoom_slider.setTickInterval(25)
        self._zoom_slider.valueChanged.connect(self.on_zoom)

        self._pan_slider = QSlider(Qt.Horizontal)
        self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._pan_slider.setFixedHeight(20)
        self._pan_slider.setTickPosition(QSlider.TicksBelow)
        self._pan_slider.setSliderPosition(0)
        self._pan_slider.setMinimum(0)
        self._pan_slider.setMaximum(100)
        self._pan_slider.setTickInterval(25)
        self._pan_slider.valueChanged.connect(self.on_pan)

        pl = QLabel("horiz. pos.:")
        pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        pl.setMaximumWidth(150)
        zl = QLabel("zoom:")
        zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        zl.setMaximumWidth(75)

        plot_controls.layout().addWidget(pl)
        plot_controls.layout().addWidget(self._pan_slider)
        plot_controls.layout().addWidget(zl)
        plot_controls.layout().addWidget(QLabel("+"))
        plot_controls.layout().addWidget(self._zoom_slider)
        plot_controls.layout().addWidget(QLabel("-"))
        self.layout().addWidget(plot_controls)

    def on_close(self):
        self.close_signal.emit()

    def on_zoom(self, new_position):
        """Forward zoom-slider moves to the plotter.

        The slider is labelled "+" at its minimum and "-" at its maximum, so
        moving it up zooms out and moving it down zooms in.
        """
        if self.plotter is None:
            # No plot yet: just remember the position.
            self.zoom_position = new_position
            return
        if self.zoom_position < new_position:
            self.plotter.on_zoom_out(new_position)
        else:
            self.plotter.on_zoom_in(new_position)
        self.zoom_position = new_position

    def on_pan(self, new_position):
        print("pan", new_position)

    def on_view_changed(self):
        if self.plotter is None:
            return
        print("view changed!")
        print(self.plotter.current_view())
        if self.plotter.is_full_view:
            self._zoom_slider.setSliderPosition(100)
        self._pan_slider.setEnabled(not self.plotter.is_full_view)

    def plot(self, item):
        """Create a plotter suited to *item* and show it.

        A ValueError while creating the DataView is reported through
        ``communicator.plot_error`` instead of being raised.
        """
        try:
            self._data_view = DataView(item, self._file_handler)
        except ValueError as e:
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return
        if self._data_view is None:
            return
        # TODO this is just a test, needs to be removed
        if item.suggested_plotter == PlotterTypes.LinePlotter:
            self.plotter = LinePlotter(self._file_handler, item, self._data_view)
            self.plotter.view_changed.connect(self.on_view_changed)
            self._container.set_plotter(self.plotter)
            self.plotter.plot(maxpoints=10000)
        self._zoom_slider.setSliderPosition(100)
        self.zoom_position = 100