from nixview.data_models.tree_model import PropertyTreeItem
from nixview.util import dataview
from nixview.util.enums import PlotterTypes

from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
from PyQt5.QtCore import QObject, pyqtSignal, Qt

import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from IPython import embed

from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView
from nixview.communicator import communicator


def create_label(item):
    """Build an axis label for a NIX entity.

    Uses ``item.label`` when it exists and is non-empty, otherwise falls back
    to ``item.name`` (if present). A ``" [unit]"`` suffix is appended when
    ``item.unit`` exists and is not None.

    Args:
        item: Any object; ``label``, ``name`` and ``unit`` are only consulted
            if the attribute exists.

    Returns:
        str: The label text (may be empty).
    """
    label = ""
    if hasattr(item, "label"):
        label += (item.label if item.label is not None else "")
    if len(label) == 0 and hasattr(item, "name"):
        label += item.name
    if hasattr(item, "unit") and item.unit is not None:
        label += " [%s]" % item.unit
    return label


class MplCanvas(FigureCanvasQTAgg):
    """Qt widget hosting a matplotlib figure with a single axes (``self.axis``)."""

    # Emitted by subclasses (e.g. LinePlotter) after they have drawn their data.
    view_changed = pyqtSignal()

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axis = figure.add_subplot(111)
        super(MplCanvas, self).__init__(figure)
        self._figure = figure
        # FigureCanvasQTAgg.__init__ only accepts the figure, so attach the
        # Qt parent explicitly instead of silently ignoring it.
        if parent is not None:
            self.setParent(parent)


class Plotter(MplCanvas):
    """Base canvas for plotting an item whose data is served by a file handler.

    Args:
        file_handler: Object used by subclasses to request dimensions/axes.
        item: The tree item to be plotted.
        parent: Optional Qt parent widget.
    """

    def __init__(self, file_handler, item, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = file_handler
        self._item = item

    def show(self):
        # NOTE(review): this overrides QWidget.show() with pyplot's global
        # show(); calling it does not show this canvas widget itself.
        plt.show()


class EventPlotter(Plotter):
    """Legacy plotter drawing 1-D event data as dots along the x axis.

    NOTE(review): __init__ never calls super().__init__(), so neither Plotter
    nor the Qt canvas base is initialised; plot() draws into a pyplot figure
    or a caller-supplied axis instead of this canvas.
    """

    def __init__(self, data_array, xdim=-1):
        self.array = data_array
        self.sc = None
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            # NOTE(review): guess_best_xdim is neither defined nor imported in
            # this module; this branch raises NameError unless it is provided
            # elsewhere -- confirm.
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 1:
            raise ValueError("EventPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot into ``axis`` (or a new pyplot figure); only 1-D data is supported.

        Returns:
            The axis drawn into, or None for data with more than one dimension.
        """
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        else:
            return None

    def plot_1d(self):
        data = self.array[:]
        xlabel = create_label(self.array.dimensions[self.xdim])
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""
        # every event is drawn at y == 1; the y axis carries no information
        self.sc = self.axis.scatter(data, np.ones(data.shape))
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis


class CategoryPlotter(Plotter):
    """Legacy plotter drawing categorical (SetDimension) data as bar charts.

    NOTE(review): __init__ never calls super().__init__(); plot() draws into a
    pyplot figure or a caller-supplied axis instead of this canvas.
    """

    def __init__(self, data_array, xdim=-1):
        self.array = data_array
        self.bars = []
        if xdim == -1:
            # NOTE(review): guess_best_xdim is neither defined nor imported in
            # this module -- confirm where it is supposed to come from.
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("CategoryPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot 1-D or 2-D data; returns the axis, or None for other ranks."""
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        elif len(self.array.dimensions) == 2:
            return self.plot_2d()
        else:
            return None

    def plot_1d(self):
        """Draw one bar per category label of the x dimension.

        Returns None when the x dimension is not a SetDimension.

        Raises:
            ValueError: if the SetDimension has no labels.
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None
        categories = list(dim.labels)
        ylabel = create_label(self.array)
        if len(categories) == 0:
            raise ValueError("Cannot plot a bar chart without any labels")
        self.bars.append(self.axis.bar(range(1, len(categories)+1), data, tick_label=categories))
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_2d(self):
        """Draw grouped bars: one group per category, one bar per series.

        Labels come from SetDimensions when available; otherwise generic
        "Cat-i" / "Series-i" names are generated.
        """
        data = self.array[:]
        if self.xdim == 1:
            # normalise layout: categories along axis 0, series along axis 1
            data = data.T
        category_count, series_count = data.shape[0], data.shape[1]

        categories = []
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            categories = list(dim.labels)
        if len(categories) == 0:
            categories = ["Cat-%i" % i for i in range(category_count)]

        series_names = []
        dim = self.array.dimensions[1-self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            series_names = list(dim.labels)
        if len(series_names) == 0:
            series_names = ["Series-%i" % i for i in range(series_count)]

        bar_width = 1/series_count * 0.75
        for i in range(series_count):
            x_values = np.arange(category_count) + i * bar_width
            # keep the first rectangle of each series as its legend handle
            self.bars.append(self.axis.bar(x_values, data[:, i], width=bar_width, align="center")[0])
        # bars are centred on their x values, so a group's centre lies
        # (series_count - 1) / 2 bar widths right of its first bar
        self.axis.set_xticks(np.arange(category_count) + (series_count - 1) * bar_width/2)
        self.axis.set_xticklabels(categories)
        self.axis.legend(self.bars, series_names, loc=1)
        return self.axis


class ImagePlotter(Plotter):
    """Legacy plotter drawing 2-D data (or 3-D data with <= 3 channels) as an image.

    NOTE(review): __init__ never calls super().__init__(); ``xdim`` is accepted
    but unused.
    """

    def __init__(self, data_array, xdim=-1):
        self.array = data_array
        self.image = None

    def plot(self, axis=None):
        """Plot into ``axis`` (or a new pyplot figure); returns the axis or None."""
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if dim_count == 2:
            return self.plot_2d()
        elif dim_count == 3:
            return self.plot_3d()
        else:
            return None

    def plot_2d(self):
        data = self.array[:]
        x = self.array.dimensions[0].axis(data.shape[0])
        y = self.array.dimensions[1].axis(data.shape[1])
        xlabel = create_label(self.array.dimensions[0])
        ylabel = create_label(self.array.dimensions[1])
        # NOTE(review): imshow draws axis 0 of ``data`` vertically, while
        # dimension 0 is labelled as x here -- confirm the intended orientation.
        self.image = self.axis.imshow(data, extent=[x[0], x[-1], y[0], y[-1]])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_3d(self):
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()


class LinePlotter(Plotter):
    """Plot the 1-D or 2-D data of ``item`` as line(s) on this canvas.

    At most ``maxpoints`` samples along ``xdim`` are taken from the data
    view's buffer; 2-D data produce one line per entry of the other dimension.

    Args:
        file_handler: Used to request dimensions and axis values.
        item: Tree item providing ``block_id``, ``id``, ``name`` and ``best_xdim``.
        data_view: DataView providing ``full_shape``, ``current_shape`` and ``_buffer``.
        xdim: Index of the dimension plotted on the x axis; -1 uses ``item.best_xdim``.
        parent: Optional Qt parent widget.

    Raises:
        ValueError: if ``xdim`` is larger than 2.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, parent)
        self._dataview = data_view
        self.dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
        self.lines = []
        self.dim_count = len(self._dataview.full_shape)
        if xdim == -1:
            self.xdim = item.best_xdim
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, maxpoints=100000):
        """Draw the first ``maxpoints`` samples; data with more than 2 dimensions are ignored."""
        self.maxpoints = maxpoints
        if self.dim_count > 2:
            return
        if self.dim_count == 1:
            self.plot_array_1d()
        else:
            self.plot_array_2d()

    def __add_slider(self):
        # Not called within this module. Adds a matplotlib slider that pages
        # through the data in windows of ``maxpoints`` samples.
        steps = self._dataview.full_shape[self.xdim] / self.maxpoints
        slider_ax = self.figure.add_axes([0.15, 0.025, 0.8, 0.025])
        self.slider = Slider(slider_ax, 'Slider', 1., steps, valinit=1., valstep=0.25)
        self.slider.on_changed(self.__update)

    def __update(self, val):
        # Slider callback: show samples [(val - 1) * maxpoints, val * maxpoints).
        if len(self.lines) > 0:
            minimum = val * self.maxpoints - self.maxpoints
            start = minimum if minimum > 0 else 0
            end = val * self.maxpoints
            self.__draw(start, end)
            self.draw_idle()

    def __draw(self, start, end):
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    def __draw_1d(self, start, end):
        """Create or update the single line for buffer samples [start, end)."""
        if start < 0:
            start = 0
        if end > self._dataview.current_shape[self.xdim]:
            end = self._dataview.current_shape[self.xdim]
        y = self._dataview._buffer[int(start):int(end)]
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, len(y), start)
        if len(self.lines) == 0:
            # MplCanvas provides ``self.axis`` (there is no ``self.axes``)
            l, = self.axis.plot(x, y, label=self._item.name)
            self.lines.append(l)
        else:
            self.lines[0].set_ydata(y)
            self.lines[0].set_xdata(x)
        self.axis.set_xlim([x[0], x[-1]])

    def __draw_2d(self, start, end):
        """Create or update one line per entry of the non-x dimension for [start, end)."""
        if start < 0:
            start = 0
        if end > self._dataview.current_shape[self.xdim]:
            end = self._dataview.current_shape[self.xdim]
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, int(end-start), start)
        line_count = self._dataview.current_shape[1 - self.xdim]
        line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1-self.xdim, line_count, 0)
        for i, l in enumerate(line_labels):
            if (self.xdim == 0):
                y = self._dataview._buffer[int(start):int(end), i]
            else:
                y = self._dataview._buffer[i, int(start):int(end)]
            if len(self.lines) <= i:
                ll, = self.axis.plot(x, y, label=l)
                self.lines.append(ll)
            else:
                self.lines[i].set_ydata(y)
                self.lines[i].set_xdata(x)
        self.axis.set_xlim([x[0], x[-1]])

    def plot_array_1d(self):
        self.__draw_1d(0, self.maxpoints)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.view_changed.emit()

    def plot_array_2d(self):
        self.__draw_2d(0, self.maxpoints)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.axis.legend(loc=1)
        self.view_changed.emit()


class PlotContainer(QWidget):
    """Widget that shows exactly one plotter canvas at a time."""

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self.setLayout(QVBoxLayout())
        self.plotter = None

    def set_plotter(self, plotter):
        """Replace the currently shown plotter with ``plotter``."""
        if plotter is self.plotter:
            return
        if self.plotter is not None:
            # removeWidget() only detaches the widget from the layout; it would
            # stay a child of this container, so schedule it for deletion.
            self.layout().removeWidget(self.plotter)
            self.plotter.deleteLater()
        self.layout().addWidget(plotter)
        self.plotter = plotter


class PlotScreen(QWidget):
    """Screen with a plot area, pan/zoom controls and a close button."""

    # Emitted when the close button is clicked.
    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = FileHandler()
        self.setLayout(QVBoxLayout())
        self._container = PlotContainer(self)
        self.layout().addWidget(self._container)

        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)
        self._create_plot_controls()
        self.layout().addWidget(close_btn)
        self._data_view = None
        self.plotter = None

    def _create_plot_controls(self):
        """Build the row with the horizontal-position and zoom sliders."""
        plot_controls = QGroupBox()
        plot_controls.setFlat(True)
        plot_controls.setLayout(QHBoxLayout())
        plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        self._zoom_slider = QSlider(Qt.Horizontal)
        self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self._zoom_slider.setFixedWidth(120)
        self._zoom_slider.setFixedHeight(20)
        self._zoom_slider.setTickPosition(QSlider.TicksBelow)
        self._zoom_slider.setSliderPosition(50)
        self._zoom_slider.setMinimum(0)
        self._zoom_slider.setMaximum(100)
        self._zoom_slider.setTickInterval(25)
        self._zoom_slider.valueChanged.connect(self.on_zoom)

        self._pan_slider = QSlider(Qt.Horizontal)
        self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._pan_slider.setFixedHeight(20)
        self._pan_slider.setTickPosition(QSlider.TicksBelow)
        self._pan_slider.setSliderPosition(0)
        self._pan_slider.setMinimum(0)
        self._pan_slider.setMaximum(100)
        self._pan_slider.setTickInterval(25)
        self._pan_slider.valueChanged.connect(self.on_pan)

        pl = QLabel("horiz. pos.:")
        pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        pl.setMaximumWidth(150)
        zl = QLabel("zoom:")
        zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        zl.setMaximumWidth(75)

        plot_controls.layout().addWidget(pl)
        plot_controls.layout().addWidget(self._pan_slider)
        plot_controls.layout().addWidget(zl)
        plot_controls.layout().addWidget(QLabel("+"))
        plot_controls.layout().addWidget(self._zoom_slider)
        plot_controls.layout().addWidget(QLabel("-"))
        self.layout().addWidget(plot_controls)

    def on_close(self):
        self.close_signal.emit()

    def on_zoom(self, new_position):
        print("zoom", new_position)

    def on_pan(self, new_position):
        print("pan", new_position)

    def on_view_changed(self):
        print("view changed!")

    def plot(self, item):
        """Load ``item`` through a DataView and plot it with its suggested plotter.

        Errors raised as ValueError by DataView are reported via
        ``communicator.plot_error``. Only PlotterTypes.LinePlotter is handled.
        """
        try:
            self._data_view = DataView(item, self._file_handler)
        except ValueError as e:
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return
        if self._data_view is None:
            return
        while not self._data_view.fully_loaded:
            self._data_view.request_more()  # TODO this is just a test, needs to be removed
        if item.suggested_plotter == PlotterTypes.LinePlotter:
            self.plotter = LinePlotter(self._file_handler, item, self._data_view)
            self.plotter.view_changed.connect(self.on_view_changed)
            self._container.set_plotter(self.plotter)
            self.plotter.plot()