from nixview.util.enums import PlotterTypes
from PyQt5.QtWidgets import (QGroupBox, QHBoxLayout, QLabel, QPushButton,
                             QSizePolicy, QSlider, QVBoxLayout, QWidget)
from PyQt5.QtCore import pyqtSignal, Qt
import matplotlib
# The backend has to be selected before the Qt5 backend classes are imported.
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import (
    FigureCanvasQTAgg as FigureCanvas,
    NavigationToolbar2QT as NavigationToolbar,
)
from matplotlib.figure import Figure
import nixio as nix
import numpy as np
try:
    import matplotlib.pyplot as plt
except ImportError as e:
    print("cannot import matplotlib, headless mode?", e)
from matplotlib.widgets import Slider
from IPython import embed

from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView
from nixview.communicator import communicator


def create_label(item):
    """Build an axis label of the form ``"<label or name> [<unit>]"``.

    The ``label`` attribute of *item* is preferred; if it is missing, None or
    empty, the ``name`` attribute is used instead. A unit suffix is appended
    only when *item* has a ``unit`` attribute that is not None.

    Args:
        item: any object; ``label``, ``name`` and ``unit`` are all optional.

    Returns:
        str: the assembled label (may be empty).
    """
    label = getattr(item, "label", None) or ""
    if not label:
        # A name of None is treated like a missing name instead of failing
        # on string concatenation.
        label = getattr(item, "name", None) or ""
    unit = getattr(item, "unit", None)
    if unit is not None:
        label += " [%s]" % unit
    return label


class MplCanvas(FigureCanvas):
    """Qt canvas wrapping a matplotlib figure with a single subplot.

    The subplot is available as ``self.axis``. Figure/axes enter and leave
    events as well as pick events are connected to handlers that print
    diagnostic output.
    """

    view_changed = pyqtSignal()

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Create the figure (``width`` x ``height`` inches at ``dpi``).

        NOTE(review): ``parent`` is accepted but not used by this
        initialiser; the canvas gets its parent when added to a layout.
        """
        figure = Figure(figsize=(width, height), dpi=dpi)
        axis = figure.add_subplot(111)
        super(MplCanvas, self).__init__(figure)
        # Attributes are assigned only after the Qt base class is set up.
        self.axis = axis
        self._figure = figure
        connect = self._figure.canvas.mpl_connect
        connect('figure_enter_event', self.on_enter_figure)
        connect('figure_leave_event', self.on_leave_figure)
        connect('axes_enter_event', self.on_enter_axes)
        connect('axes_leave_event', self.on_leave_axes)
        connect('pick_event', self.on_pick)

    def on_enter_figure(self, event):
        """Print a note when the pointer enters the figure."""
        print('enter_figure', event.canvas.figure)

    def on_leave_figure(self, event):
        """Print a note when the pointer leaves the figure."""
        print('leave_figure', event.canvas.figure)

    def on_enter_axes(self, event):
        """Print a note when the pointer enters an axes."""
        print('enter_axes', event.inaxes)

    def on_leave_axes(self, event):
        """Print a note when the pointer leaves an axes."""
        print('leave_axes', event.inaxes)

    def on_pick(self, event):
        """Print the label and the picked data points of the picked line."""
        line = event.artist
        # Artists expose their label through get_label(); they have no
        # public ``label`` attribute.
        print(line.get_label())
        xdata, ydata = line.get_data()
        ind = event.ind
        # asarray makes the fancy indexing work even if the line was
        # created from plain Python sequences.
        picked = np.array([np.asarray(xdata)[ind], np.asarray(ydata)[ind]]).T
        print('on pick line:', picked)


class Plotter(MplCanvas):
    """Base class of the plotters; stores file handler, item and data view.

    Subclasses are expected to provide ``current_view`` and ``is_full_view``.
    """

    def __init__(self, file_handler, item, data_view, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = file_handler
        self._item = item
        self._dataview = data_view

    def current_view(self):
        """Return the currently displayed view; must be overridden."""
        raise NotImplementedError("current_view is not implemented on the current plotter")

    @property
    def is_full_view(self):
        """Whether the full data range is shown; must be overridden.

        Declared as a property so that it matches ``LinePlotter`` and the
        attribute-style access in ``PlotScreen.on_view_changed``.
        """
        raise NotImplementedError("is_full_view is not implemented on the current plotter")


class EventPlotter(Plotter):
    """Plots the values of a 1-D data array as events on a single row.

    NOTE(review): unlike ``LinePlotter`` this class takes the data array
    directly and does not run the ``Plotter``/canvas initialiser; the axes
    is created or injected in ``plot``.
    """

    def __init__(self, data_array, xdim=-1):
        """Store *data_array*.

        Args:
            data_array: array object providing ``dimensions``, ``name`` and
                slicing.
            xdim: index of the dimension used for the x-axis; ``-1`` selects
                the default.
        """
        self.array = data_array
        self.sc = None
        self.dim_count = len(data_array.dimensions)
        # The previous automatic selection was disabled and depended on a
        # helper (guess_best_xdim) that is not visible in this module, so
        # ``xdim`` was never set and plot_1d failed with AttributeError.
        # NOTE(review): defaulting to the first dimension - confirm.
        self.xdim = 0 if xdim == -1 else xdim

    def plot(self, axis=None):
        """Plot into *axis*, or into a newly created figure if None.

        Returns:
            The axes that was drawn into, or None if the array does not have
            exactly one dimension.
        """
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        return None

    def plot_1d(self):
        """Scatter all values at y == 1 and label the axes."""
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        xlabel = create_label(dim)
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""
        self.sc = self.axis.scatter(data, np.ones(data.shape))
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis


class CategoryPlotter(Plotter):
    """Draws bar charts of 1-D or 2-D category data.

    NOTE(review): like ``EventPlotter`` this class does not run the
    ``Plotter``/canvas initialiser.
    """

    def __init__(self, data_array, xdim=-1):
        """Store *data_array*.

        Args:
            data_array: array object providing ``dimensions``, ``name`` and
                slicing.
            xdim: index of the dimension holding the categories; ``-1``
                selects the default.
        """
        self.array = data_array
        self.bars = []
        # See EventPlotter.__init__: ``xdim`` used to be left undefined.
        # NOTE(review): defaulting to the first dimension - confirm.
        self.xdim = 0 if xdim == -1 else xdim

    def plot(self, axis=None):
        """Plot into *axis*, or into a newly created figure if None.

        Returns:
            The axes that was drawn into, or None if the array has neither
            one nor two dimensions.
        """
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        dim_count = len(self.array.dimensions)
        if dim_count == 1:
            return self.plot_1d()
        if dim_count == 2:
            return self.plot_2d()
        return None

    def plot_1d(self):
        """Draw one bar per label of the (Set) dimension.

        Returns:
            The axes, or None if the dimension is not a Set dimension.

        Raises:
            ValueError: if the Set dimension has no labels.
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None
        categories = list(dim.labels)
        ylabel = create_label(self.array)
        if len(categories) == 0:
            raise ValueError("Cannot plot a bar chart without any labels")
        self.bars.append(self.axis.bar(range(1, len(categories) + 1), data,
                                       tick_label=categories))
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_2d(self):
        """Draw grouped bars: one group per category, one bar per series."""
        data = self.array[:]
        if self.xdim == 1:
            data = data.T
        # From here on categories run along axis 0 and series along axis 1,
        # whichever dimension was selected as xdim.
        category_count, series_count = data.shape[0], data.shape[1]

        categories = []
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            categories = list(dim.labels)
        if len(categories) == 0:
            categories = ["Cat-%i" % i for i in range(category_count)]

        series_names = []
        dim = self.array.dimensions[1 - self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            series_names = list(dim.labels)
        if len(series_names) == 0:
            series_names = ["Series-%i" % i for i in range(series_count)]

        bar_width = 1 / series_count * 0.75
        for i in range(series_count):
            x_values = np.arange(category_count) + i * bar_width
            # Only the first bar of each series is kept; it serves as the
            # legend handle below.
            self.bars.append(self.axis.bar(x_values, data[:, i], width=bar_width,
                                           align="center")[0])
        self.axis.set_xticks(np.arange(category_count) + series_count * bar_width / 2)
        self.axis.set_xticklabels(categories)
        self.axis.legend(self.bars, series_names, loc=1)
        return self.axis


class ImagePlotter(Plotter):
    """Shows 2-D data (or 3-D data with up to three channels) as an image.

    NOTE(review): like ``EventPlotter`` this class does not run the
    ``Plotter``/canvas initialiser.
    """

    def __init__(self, data_array, xdim=-1):
        """Store *data_array*; ``xdim`` is accepted but not used."""
        self.array = data_array
        self.image = None

    def plot(self, axis=None):
        """Plot into *axis*, or into a newly created figure if None.

        Returns:
            The axes that was drawn into, or None if the data cannot be
            shown as an image.
        """
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if dim_count == 2:
            return self.plot_2d()
        if dim_count == 3:
            return self.plot_3d()
        return None

    def plot_2d(self):
        """Show the data with ``imshow``; extent and labels come from the
        first two dimensions."""
        data = self.array[:]
        x = self.array.dimensions[0].axis(data.shape[0])
        y = self.array.dimensions[1].axis(data.shape[1])
        xlabel = create_label(self.array.dimensions[0])
        ylabel = create_label(self.array.dimensions[1])
        self.image = self.axis.imshow(data, extent=[x[0], x[-1], y[0], y[-1]])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_3d(self):
        """Show 3-D data if the third dimension has at most three entries.

        Returns:
            The axes, or None (after printing a message) otherwise.
        """
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()


class LinePlotter(Plotter):
    """Line plot of 1-D or 2-D data read from the data view's buffer."""

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        """Set up the canvas and select the x dimension.

        Args:
            file_handler: provides ``request_dimensions`` and ``request_axis``.
            item: descriptor of the data to plot; ``block_id``, ``id``,
                ``name`` and ``best_xdim`` are read from it.
            data_view: provides ``full_shape``, ``current_shape``,
                ``fully_loaded`` and the data buffer.
            xdim: dimension used for the x-axis; ``-1`` uses
                ``item.best_xdim``.
            parent: forwarded to the base class.

        Raises:
            ValueError: if ``xdim`` is larger than 2.
        """
        super().__init__(file_handler, item, data_view, parent)
        self.dimensions = self._file_handler.request_dimensions(self._item.block_id,
                                                                self._item.id)
        # Smallest/largest x value drawn so far, see _set_xlims.
        self._min_x = None
        self._max_x = None
        # NOTE(review): ``min_x`` is not read anywhere in this module; kept
        # in case external code relies on the attribute.
        self.min_x = None
        self.lines = []
        self.dim_count = len(self._dataview.full_shape)
        if xdim == -1:
            self.xdim = item.best_xdim
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim
        self.axis.callbacks.connect('xlim_changed', self.on_xlims_change)
        self.axis.callbacks.connect('ylim_changed', self.on_ylims_change)

    def on_xlims_change(self, event_ax):
        """Print the new x limits."""
        print("updated xlims: ", event_ax.get_xlim())

    def on_ylims_change(self, event_ax):
        """Print the new y limits."""
        print("updated ylims: ", event_ax.get_ylim())

    def current_view(self):
        """Return the current view; currently always an empty list."""
        cv = []
        return cv

    @property
    def is_full_view(self):
        """True if the x limits equal the extreme x values drawn so far and
        the data view reports being fully loaded."""
        xlims = self.axis.get_xlim()
        full = self._min_x == xlims[0] and self._max_x == xlims[-1]
        return full and self._dataview.fully_loaded

    def on_zoom_in(self, new_position):
        """Print a note; zooming in is not implemented yet."""
        print("plotter ZOOM In!", new_position)

    def on_zoom_out(self, new_position):
        """Print a note; zooming out is not implemented yet."""
        print("plotter ZOOM out!", new_position)

    def plot(self, maxpoints=100000):
        """Draw up to *maxpoints* points along the x dimension.

        Data with more than two dimensions is silently ignored.
        """
        self.maxpoints = maxpoints
        if self.dim_count > 2:
            return
        if self.dim_count == 1:
            self.plot_array_1d()
        else:
            self.plot_array_2d()

    def __draw(self, start, end):
        """Dispatch to the 1-D or 2-D drawing routine."""
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    def _set_xlims(self, data_xmin, data_xmax):
        """Widen the remembered x range to include the given values."""
        if self._min_x is None or data_xmin < self._min_x:
            self._min_x = data_xmin
        if self._max_x is None or data_xmax > self._max_x:
            self._max_x = data_xmax

    def __draw_1d(self, start, end):
        """Draw buffer entries ``[start, end)``; the range is clipped to the
        currently loaded shape."""
        if start < 0:
            start = 0
        if end > self._dataview.current_shape[self.xdim]:
            end = self._dataview.current_shape[self.xdim]

        y = self._dataview._buffer[int(start):int(end)]
        # NOTE(review): arguments after the ids look like (dimension index,
        # count, start index) - confirm against FileHandler.request_axis.
        x = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                            0, len(y), start)
        if len(x) == 0:
            # Nothing loaded in the requested range; x[0] would fail.
            return
        self._set_xlims(x[0], x[-1])

        if len(self.lines) == 0:
            l, = self.axis.plot(x, y, label=self._item.name, picker=5)
            self.lines.append(l)
        else:
            self.lines[0].set_ydata(y)
            self.lines[0].set_xdata(x)
        self.axis.set_xlim([x[0], x[-1]])

    def __draw_2d(self, start, end):
        """Draw one line per entry of the non-x dimension for the buffer
        range ``[start, end)`` along the x dimension."""
        if start < 0:
            start = 0
        if end > self._dataview.current_shape[self.xdim]:
            end = self._dataview.current_shape[self.xdim]

        x = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                            self.xdim, int(end - start), start)
        line_count = self._dataview.current_shape[1 - self.xdim]
        line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                                      1 - self.xdim, line_count, 0)
        if len(x) == 0:
            # Nothing loaded in the requested range; x[0] would fail.
            return
        self._set_xlims(x[0], x[-1])

        for i, l in enumerate(line_labels):
            if self.xdim == 0:
                y = self._dataview._buffer[int(start):int(end), i]
            else:
                y = self._dataview._buffer[i, int(start):int(end)]

            if len(self.lines) <= i:
                ll, = self.axis.plot(x, y, label=l, picker=5)
                self.lines.append(ll)
            else:
                self.lines[i].set_ydata(y)
                self.lines[i].set_xdata(x)
        self.axis.set_xlim([x[0], x[-1]])

    def plot_array_1d(self):
        """Draw the first ``maxpoints`` points, label the axes and emit
        ``view_changed``."""
        self.__draw_1d(0, self.maxpoints)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.view_changed.emit()

    def plot_array_2d(self):
        """Like ``plot_array_1d`` for 2-D data; also adds a legend."""
        self.__draw_2d(0, self.maxpoints)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.axis.legend(loc=1)
        self.view_changed.emit()


class PlotContainer(QWidget):
    """Widget holding the currently active plotter in a vertical layout."""

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self.setLayout(QVBoxLayout())
        self.plotter = None

    def set_plotter(self, plotter):
        """Replace the displayed plotter with *plotter*."""
        previous = self.plotter
        if previous is not None and previous is not plotter:
            self.layout().removeWidget(previous)
            # removeWidget only takes the widget out of the layout; it would
            # stay a visible child of this container, so detach it as well.
            previous.setParent(None)
        self.layout().addWidget(plotter)
        self.plotter = plotter


class PlotScreen(QWidget):
    """Screen with a plot container, pan/zoom sliders and a close button."""

    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = FileHandler()
        # Defined up front so the slider callbacks can test for it before
        # anything has been plotted.
        self.plotter = None
        self._data_view = None
        self.zoom_position = 0

        self.setLayout(QVBoxLayout())
        self._container = PlotContainer(self)
        self.layout().addWidget(self._container)

        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)

        # Layout order: plot container, controls, close button.
        self._create_plot_controls()
        self.layout().addWidget(close_btn)

    def _create_plot_controls(self):
        """Create the pan and zoom sliders and add them to the layout."""
        plot_controls = QGroupBox()
        plot_controls.setFlat(True)
        plot_controls.setLayout(QHBoxLayout())
        plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        self._zoom_slider = QSlider(Qt.Horizontal)
        self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self._zoom_slider.setFixedWidth(120)
        self._zoom_slider.setFixedHeight(20)
        self._zoom_slider.setTickPosition(QSlider.TicksBelow)
        self._zoom_slider.setSliderPosition(50)
        self._zoom_slider.setMinimum(0)
        self._zoom_slider.setMaximum(100)
        self._zoom_slider.setTickInterval(25)
        self._zoom_slider.valueChanged.connect(self.on_zoom)

        self._pan_slider = QSlider(Qt.Horizontal)
        self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._pan_slider.setFixedHeight(20)
        self._pan_slider.setTickPosition(QSlider.TicksBelow)
        self._pan_slider.setSliderPosition(0)
        self._pan_slider.setMinimum(0)
        self._pan_slider.setMaximum(100)
        self._pan_slider.setTickInterval(25)
        self._pan_slider.valueChanged.connect(self.on_pan)

        pl = QLabel("horiz. pos.:")
        pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        pl.setMaximumWidth(150)

        zl = QLabel("zoom:")
        zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        zl.setMaximumWidth(75)

        controls_layout = plot_controls.layout()
        controls_layout.addWidget(pl)
        controls_layout.addWidget(self._pan_slider)
        controls_layout.addWidget(zl)
        # "+" sits at the low end of the zoom slider, "-" at the high end.
        controls_layout.addWidget(QLabel("+"))
        controls_layout.addWidget(self._zoom_slider)
        controls_layout.addWidget(QLabel("-"))

        self.layout().addWidget(plot_controls)

    def on_close(self):
        """Emit ``close_signal``."""
        self.close_signal.emit()

    def on_zoom(self, new_position):
        """Forward a zoom slider change to the plotter.

        A larger slider value (towards the "-" label) zooms out, a smaller
        one zooms in.
        """
        if self.plotter is not None:
            if self.zoom_position < new_position:
                self.plotter.on_zoom_out(new_position)
            else:
                self.plotter.on_zoom_in(new_position)
        self.zoom_position = new_position

    def on_pan(self, new_position):
        """Print the new pan position; panning is not implemented yet."""
        print("pan", new_position)

    def on_view_changed(self):
        """Sync the sliders after the plotter signalled a view change."""
        print("view changed!")
        if self.plotter is None:
            return
        print(self.plotter.current_view())
        full_view = self.plotter.is_full_view
        if full_view:
            self._zoom_slider.setSliderPosition(100)
        # Panning is pointless while everything is visible.
        self._pan_slider.setEnabled(not full_view)

    def plot(self, item):
        """Create a data view for *item* and plot it.

        A ``ValueError`` from the data view is reported through
        ``communicator.plot_error``. Only items whose suggested plotter is
        ``PlotterTypes.LinePlotter`` are drawn.
        """
        try:
            self._data_view = DataView(item, self._file_handler)
        except ValueError as e:
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return
        if self._data_view is None:
            return

        if item.suggested_plotter == PlotterTypes.LinePlotter:
            self.plotter = LinePlotter(self._file_handler, item, self._data_view)
            self.plotter.view_changed.connect(self.on_view_changed)
            self._container.set_plotter(self.plotter)
            self.plotter.plot(maxpoints=10000)
        self._zoom_slider.setSliderPosition(100)
        self.zoom_position = 100