"""Qt widgets that display NIX data with matplotlib.

The module provides a matplotlib canvas embedded in Qt (``MplCanvas``), a
family of plotters built on top of it (events, categories, images, lines) and
the ``PlotScreen`` widget that selects a plotter for an item and offers zoom
and pan controls.
"""
from nixview.util.enums import NodeType, PlotterTypes
from PyQt5.QtWidgets import (QGroupBox, QHBoxLayout, QLabel, QPushButton,
                             QSizePolicy, QSlider, QVBoxLayout, QWidget)
from PyQt5.QtCore import pyqtSignal, Qt

import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import numpy as np
try:
    import matplotlib.pyplot as plt
except ImportError as e:
    print("cannot import matplotlib, headless mode?", e)
from matplotlib.widgets import Slider

from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView
from nixview.communicator import communicator


def create_label(item):
    """Build an axis label of the form ``"<label or name> [<unit>]"``.

    The item's ``label`` attribute is preferred; if it is missing, ``None`` or
    empty the ``name`` attribute is used instead. The unit is appended in
    square brackets when the item has a ``unit`` that is not ``None``.

    Args:
        item: any object; ``label``, ``name`` and ``unit`` are all optional.

    Returns:
        str: the assembled label, possibly empty.
    """
    label = getattr(item, "label", None) or ""
    if not label:
        # A missing or None name must not break the concatenation.
        label = getattr(item, "name", None) or ""
    unit = getattr(item, "unit", None)
    if unit is not None:
        label += " [%s]" % (unit,)
    return label


class MplCanvas(FigureCanvas):
    """Matplotlib Qt5Agg canvas holding one figure with a single axes.

    The axes is available as ``self.axis``. ``view_changed`` is emitted by
    subclasses after the displayed range has been redrawn.

    Args:
        parent: accepted for API compatibility; it is not forwarded to the
            canvas constructor (the widget is parented when it is added to a
            layout).
        width (float): figure width in inches.
        height (float): figure height in inches.
        dpi (int): figure resolution.
    """
    view_changed = pyqtSignal()

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axis = figure.add_subplot(111)
        super(MplCanvas, self).__init__(figure)
        self._figure = figure

        # After construction figure.canvas is this very object.
        canvas = self._figure.canvas
        canvas.mpl_connect('figure_enter_event', self.on_enter_figure)
        canvas.mpl_connect('figure_leave_event', self.on_leave_figure)
        canvas.mpl_connect('axes_enter_event', self.on_enter_axes)
        canvas.mpl_connect('axes_leave_event', self.on_leave_axes)
        canvas.mpl_connect('pick_event', self.on_pick)

    def on_enter_figure(self, event):
        """Hook called when the mouse enters the figure; currently a no-op."""
        pass

    def on_leave_figure(self, event):
        """Hook called when the mouse leaves the figure; currently a no-op."""
        pass

    def on_enter_axes(self, event):
        """Hook called when the mouse enters the axes; currently a no-op."""
        pass

    def on_leave_axes(self, event):
        """Hook called when the mouse leaves the axes; currently a no-op."""
        pass

    def on_pick(self, event):
        """Print the label and the picked data points of the picked artist.

        Works for line artists (``get_data``) as well as for scatter
        collections (``get_offsets``).
        """
        artist = event.artist
        # Artists expose their label through get_label(), not a .label attribute.
        print(artist.get_label())
        ind = getattr(event, "ind", None)
        if ind is None:
            return
        if hasattr(artist, "get_data"):
            xdata, ydata = artist.get_data()
            points = np.array([np.asarray(xdata)[ind], np.asarray(ydata)[ind]]).T
        elif hasattr(artist, "get_offsets"):
            points = np.asarray(artist.get_offsets())[ind]
        else:
            return
        print('on pick line:', points)


class Plotter(MplCanvas):
    """Abstract base class for the visual display of data.

    Inheriting classes that support panning need to implement the
    ``is_full_view`` and ``can_pan_horizontally`` properties.

    Args:
        file_handler: object used to request axes and dimensions of the item.
        item: descriptor of the entity that is plotted.
        data_view: the view providing the (possibly partially loaded) data.
        parent: optional parent widget, passed on to ``MplCanvas``.
    """

    def __init__(self, file_handler, item, data_view, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = file_handler
        self._item = item
        self._dataview = data_view

    @property
    def is_full_view(self):
        raise NotImplementedError("is_full_view is not implemented on the current plotter")

    @property
    def can_pan_horizontally(self):
        raise NotImplementedError("can_pan_horizontally is not implemented on the current plotter")


class EventPlotter(Plotter):
    """Plots one-dimensional event data as a row of markers at y == 1.

    Args:
        file_handler: object used to request axes and dimensions of the item.
        item: descriptor of the entity that is plotted.
        data_view: the view providing the data.
        xdim (int): dimension used as x-axis; -1 selects ``item.best_xdim``.
        parent: optional parent widget.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dim_count = len(self._dataview.full_shape)
        self.xdim = self._item.best_xdim if xdim == -1 else xdim

        # Index range of the data that has been displayed so far; needed by
        # is_full_view.
        self._data_xmin = 0
        self._data_xmax = self._dataview.current_shape[self.xdim]
        # Index range of the complete data set.
        self._abs_xmin = 0
        self._abs_xmax = self._dataview.full_shape[self.xdim]
        # Index range that is currently shown.
        self._view_xmin = 0
        self._view_xmax = 0
        self._zoom_level = 0
        self._segment_length = 0
        self.sc = None

    @property
    def is_full_view(self):
        """True if the current view spans the whole displayed data range."""
        return (self._data_xmin == self._view_xmin
                and self._data_xmax == self._view_xmax)

    @property
    def can_pan_horizontally(self):
        return self.can_pan_left or self.can_pan_right

    @property
    def can_pan_left(self):
        return self._view_xmin > self._abs_xmin

    @property
    def can_pan_right(self):
        return self._view_xmax < self._abs_xmax

    @property
    def horizontal_pan_position(self):
        """Right edge of the view as a fraction of the full data length."""
        if self._abs_xmax == 0:
            return 0.0
        return self._view_xmax / self._abs_xmax

    def horizontal_pan_to_position(self, new_position, zoomlevel):
        """Show the segment whose right edge is at ``new_position``.

        Args:
            new_position (float): right edge of the view as a fraction (0..1)
                of the full data length.
            zoomlevel (float): length of the view as a fraction (0..1) of the
                full data length.
        """
        new_xmax = int(min(np.ceil(new_position * self._abs_xmax), self._abs_xmax))
        segment_length = zoomlevel * self._abs_xmax
        start = max(0, new_xmax - segment_length)
        # Load more data only while the requested right edge lies beyond what
        # has been loaded so far.
        while (not self._dataview.fully_loaded
               and new_xmax > self._dataview.current_shape[self.xdim]):
            self._dataview.request_more()
        self.plot(start, zoomlevel)

    def plot(self, start=0, zoomlevel=1.0):
        """Plot the segment starting at index ``start``.

        Args:
            start: start index of the segment.
            zoomlevel (float): segment length as a fraction of the full data
                length; values above 1 are clamped to 1.

        Returns:
            The axes for 1-d data, None otherwise.
        """
        if zoomlevel > 1:
            zoomlevel = 1.0
        self._segment_length = zoomlevel * self._abs_xmax
        self._zoom_level = zoomlevel
        if self.dim_count == 1:
            return self.plot_1d(start)
        return None  # FIXME 2D events?

    def plot_1d(self, start=0, zoomlevel=1.0):
        """Draw all events and limit the x-axis to the requested segment.

        Args:
            start: start index of the segment; negative values are treated as 0.
            zoomlevel: unused, kept for interface compatibility; the segment
                length set by ``plot`` is used.

        Returns:
            The axes the events were drawn on.
        """
        start = max(start, 0)
        end = min(start + self._segment_length,
                  self._dataview.current_shape[self.xdim])
        segment = self._dataview._buffer[int(start):int(end)]
        x_values = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                                   0, len(segment), int(start))
        data = self._dataview._buffer[:]
        dimensions = self._file_handler.request_dimensions(self._item.block_id,
                                                           self._item.id)
        xlabel = create_label(dimensions[self.xdim])

        self._view_xmin, self._view_xmax = start, end

        if self.sc is None:
            self.sc = self.axis.scatter(data, np.ones(data.shape))
            self.sc.set_pickradius(5)
        else:
            # scatter returns a PathCollection; its points are updated with
            # set_offsets (there is no set_data on collections).
            self.sc.set_offsets(np.column_stack((data, np.ones(data.shape))))
            self.figure.canvas.draw_idle()

        if x_values is not None and len(x_values) > 0:
            self.axis.set_xlim([x_values[0], x_values[-1]])
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        return self.axis


class CategoryPlotter(Plotter):
    """Plots category data as a bar chart (1-d) or grouped bar chart (2-d).

    Args:
        file_handler: object used to request axes of the item.
        item: descriptor of the entity that is plotted.
        data_view: the view providing the data.
        parent: optional parent widget.
    """

    def __init__(self, file_handler, item, data_view, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dim_count = len(self._dataview.full_shape)
        self._xdim = self._item.best_xdim
        self.bars = []

    def plot(self):
        """Plot 1-d or 2-d data; returns None for any other dimensionality."""
        dim_count = len(self._dataview.full_shape)
        if dim_count == 1:
            return self.plot_1d()
        if dim_count == 2:
            return self.plot_2d()
        return None

    def plot_1d(self):
        """Draw one bar per category and return the axes."""
        xdim = self._item.best_xdim
        category_count = self._item.shape[xdim]
        categories = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                                     xdim, category_count)
        if categories is None or len(categories) == 0:
            # No category names available: generate one per data entry.
            categories = ["Cat-%i" % i for i in range(category_count)]
        ylabel = create_label(self._item)
        self.bars.append(self.axis.bar(range(1, len(categories) + 1),
                                       self._dataview._buffer,
                                       tick_label=categories))
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_2d(self):
        """Draw grouped bars: one group per category, one bar per series.

        Returns:
            The axes the bars were drawn on.
        """
        xdim = self._item.best_xdim
        series_dim = 1 - xdim
        data = self._dataview._buffer
        if xdim == 1:
            # Make categories run along the first axis.
            data = data.T

        categories = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                                     xdim, self._item.shape[xdim])
        if categories is None or len(categories) == 0:
            categories = ["Cat-%i" % i for i in range(self._item.shape[xdim])]

        series = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                                 series_dim, self._item.shape[series_dim])
        if series is None or len(series) == 0:
            series = ["Series-%i" % i for i in range(self._item.shape[series_dim])]

        ylabel = create_label(self._item)
        positions = np.arange(len(categories))
        if len(series) > 0:
            bar_width = 1 / len(series) * 0.75
            for i in range(len(series)):
                bars = self.axis.bar(positions + i * bar_width, data[:, i],
                                     width=bar_width, align="center")
                # One rectangle per series is enough as a legend handle.
                self.bars.append(bars[0])
            # Bars are centre-aligned, so the middle of a group lies at
            # (n - 1) * width / 2 right of the first bar's position.
            self.axis.set_xticks(positions + (len(series) - 1) * bar_width / 2)
            self.axis.set_xticklabels(categories)
            self.axis.legend(self.bars, series, loc=1)
        self.axis.set_ylabel(ylabel)
        return self.axis


class ImagePlotter(Plotter):
    """Plots 2-d data, or 3-d data with up to three channels, as an image.

    Args:
        file_handler: object used to request axes and dimensions of the item.
        item: descriptor of the entity that is plotted.
        data_view: the view providing the data.
        xdim: unused, kept for a signature parallel to the other plotters.
        parent: optional parent widget.

    Raises:
        ValueError: if the data has more than three dimensions.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self._dim_count = len(self._dataview.full_shape)
        if self._dim_count > 3:
            raise ValueError("ImagePlotter cannot plot data with more than 3 dimensions! Shape of %s is %s "
                             % (self._item.name, str(self._dataview.full_shape)))
        self.image = None

    def plot(self):
        """Load the complete data and display it.

        Returns:
            The axes, or None if the data is neither 2-d nor plottable 3-d.
        """
        # An image can only be shown once all of its data is available.
        while not self._dataview.fully_loaded:
            self._dataview.request_more()
        self.axis.set_title(self._item.name)
        if self._dim_count == 2:
            return self.plot_2d()
        if self._dim_count == 3:
            return self.plot_3d()
        return None

    def plot_2d(self):
        """Show the buffer with imshow, labelled with the first two dimensions."""
        data = self._dataview._buffer[:]
        if len(data.shape) == 3:
            # Channel data is converted to 8-bit values before display.
            data = data.astype("uint8")
        x_values = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                                   0, self._dataview.full_shape[0], 0)
        y_values = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                                   1, self._dataview.full_shape[1], 0)
        dimensions = self._file_handler.request_dimensions(self._item.block_id,
                                                           self._item.id)
        xlabel = create_label(dimensions[0])
        ylabel = create_label(dimensions[1])
        self.image = self.axis.imshow(data, extent=[x_values[0], x_values[-1],
                                                    y_values[0], y_values[-1]])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_3d(self):
        """Show 3-d data as an image if the third dimension has at most 3 entries."""
        if self._dataview.full_shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()


class LinePlotter(Plotter):
    """Shows 1-d data as a single line or 2-d data as multiple lines.

    Only a segment of the data is displayed at a time; its position and
    length are controlled through ``plot`` and ``horizontal_pan_to_position``.

    Args:
        file_handler: object used to request axes and dimensions of the item.
        item: descriptor of the entity that is plotted.
        data_view: the view providing the (possibly partially loaded) data.
        xdim (int): dimension used as x-axis; -1 selects ``item.best_xdim``.
        parent: optional parent widget.

    Raises:
        ValueError: if ``xdim`` is larger than 2.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dimensions = self._file_handler.request_dimensions(self._item.block_id,
                                                                self._item.id)
        self.lines = []
        self.dim_count = len(self._dataview.full_shape)
        if xdim == -1:
            self.xdim = item.best_xdim
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

        # Index range of the data that has been displayed so far.
        self._data_xmin = 0
        self._data_xmax = self._dataview.current_shape[self.xdim]
        # Index range of the complete data set.
        self._abs_xmin = 0
        self._abs_xmax = self._dataview.full_shape[self.xdim]
        # Index range that is currently shown.
        self._view_xmin = 0
        self._view_xmax = 0
        self._zoom_level = 0
        self._segment_length = 0

    @property
    def is_full_view(self):
        """True if the current view spans the whole displayed data range."""
        return (self._data_xmin == self._view_xmin
                and self._data_xmax == self._view_xmax)

    @property
    def can_pan_horizontally(self):
        return self.can_pan_left or self.can_pan_right

    @property
    def can_pan_left(self):
        return self._view_xmin > self._abs_xmin

    @property
    def can_pan_right(self):
        return self._view_xmax < self._abs_xmax

    @property
    def horizontal_pan_position(self):
        """Right edge of the view as a fraction of the full data length."""
        if self._abs_xmax == 0:
            return 0.0
        return self._view_xmax / self._abs_xmax

    def horizontal_pan_to_position(self, new_position, zoomlevel):
        """Show the segment whose right edge is at ``new_position``.

        Args:
            new_position (float): right edge of the view as a fraction (0..1)
                of the full data length.
            zoomlevel (float): length of the view as a fraction (0..1) of the
                full data length.
        """
        new_xmax = int(min(np.ceil(new_position * self._abs_xmax), self._abs_xmax))
        segment_length = zoomlevel * self._abs_xmax
        start = max(0, new_xmax - segment_length)
        # Load more data only while the requested right edge lies beyond what
        # has been loaded so far.
        while (not self._dataview.fully_loaded
               and new_xmax > self._dataview.current_shape[self.xdim]):
            self._dataview.request_more()
        self.plot(start, zoomlevel)

    def on_zoom_in(self, new_position):
        print("plotter ZOOM In!", new_position)

    def on_zoom_out(self, new_position):
        print("plotter ZOOM out!", new_position)

    def plot(self, start=0, zoomlevel=1.0):
        """Plot the segment starting at index ``start``.

        Data with more than two dimensions is not plotted.

        Args:
            start: start index of the segment.
            zoomlevel (float): segment length as a fraction of the full data
                length; values above 1 are clamped to 1.
        """
        if zoomlevel > 1:
            zoomlevel = 1.0
        self._segment_length = zoomlevel * self._abs_xmax
        self._zoom_level = zoomlevel

        if self.dim_count > 2:
            return
        if self.dim_count == 1:
            self.plot_array_1d(start)
        else:
            self.plot_array_2d(start)

    def _update_abs_extremes(self, display_x_min, display_xmax):
        """Extend the stored displayed-data index range to include the given one."""
        if self._data_xmin is None or display_x_min < self._data_xmin:
            self._data_xmin = display_x_min
        if self._data_xmax is None or display_xmax > self._data_xmax:
            self._data_xmax = display_xmax

    def _update_current_view(self, current_xmin, current_xmax):
        """Remember the index range that is currently shown."""
        self._view_xmax = current_xmax
        self._view_xmin = current_xmin

    def __draw_1d(self, start, end):
        """Draw the 1-d data from start to end index.

        Args:
            start (int): start index in the data
            end (int): end index in the data
        """
        start = max(start, 0)
        end = min(end, self._dataview.current_shape[self.xdim])
        y_values = self._dataview._buffer[int(start):int(end)]
        x_values = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                                   0, len(y_values), int(start))
        self._update_abs_extremes(start, end)
        self._update_current_view(start, end)

        if len(y_values) == 0 or x_values is None or len(x_values) == 0:
            # Nothing to draw; min/max and axis limits are undefined.
            return

        if not self.lines:
            line, = self.axis.plot(x_values, y_values, label=self._item.name)
            line.set_pickradius(5)
            self.lines.append(line)
        else:
            self.lines[-1].set_data(x_values[:len(y_values)], y_values)

        self.axis.set_ylim([np.min(y_values), np.max(y_values)])
        self.axis.set_xlim([x_values[0], x_values[-1]])
        self.figure.canvas.draw_idle()

    def __draw_2d(self, start, end):
        """Draw one line per entry of the non-x dimension from start to end index.

        Args:
            start (int): start index along the x dimension
            end (int): end index along the x dimension
        """
        start = int(max(start, 0))
        end = int(min(end, self._dataview.current_shape[self.xdim]))
        # Request exactly as many axis values as the data slice below holds.
        x = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                            self.xdim, end - start, start)
        line_count = self._dataview.current_shape[1 - self.xdim]
        line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id,
                                                      1 - self.xdim, line_count, 0)
        # Ranges are tracked in index units, as in the 1-d case.
        self._update_abs_extremes(start, end)
        self._update_current_view(start, end)

        if x is None or len(x) == 0:
            return

        for i, line_label in enumerate(line_labels):
            if self.xdim == 0:
                y = self._dataview._buffer[start:end, i]
            else:
                y = self._dataview._buffer[i, start:end]

            if len(self.lines) <= i:
                line, = self.axis.plot(x, y, label=line_label)
                line.set_pickradius(5)
                self.lines.append(line)
            else:
                self.lines[i].set_data(x, y)

        self.axis.set_xlim([x[0], x[-1]])
        self.figure.canvas.draw_idle()

    def plot_array_1d(self, start=0):
        """Draw the 1-d segment starting at ``start`` and emit ``view_changed``."""
        self.__draw_1d(start, start + self._segment_length)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.view_changed.emit()

    def plot_array_2d(self, start=0):
        """Draw the 2-d segment starting at ``start`` and emit ``view_changed``."""
        self.__draw_2d(start, start + self._segment_length)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.axis.legend(loc=1)
        self.view_changed.emit()


class PlotContainer(QWidget):
    """Widget that shows at most one plotter at a time."""

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self.setLayout(QVBoxLayout())
        self.plotter = None

    def set_plotter(self, plotter):
        """Replace the displayed plotter.

        Args:
            plotter: the new plotter widget, or None to show nothing.
        """
        if plotter is self.plotter:
            return
        if self.plotter is not None:
            self.layout().removeWidget(self.plotter)
            # removeWidget leaves the widget a visible child of this
            # container; detaching it also hides it.
            self.plotter.setParent(None)
        if plotter is not None:
            self.layout().addWidget(plotter)
        self.plotter = plotter


class PlotScreen(QWidget):
    """Screen with a plot area, pan and zoom sliders and a close button.

    ``close_signal`` is emitted when the close button is clicked.
    """
    close_signal = pyqtSignal()

    # Number of steps of the pan and zoom sliders.
    _SLIDER_STEPS = 1000
    # Number of data points a line plot initially shows.
    _INITIAL_VISIBLE_POINTS = 50000

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = FileHandler()

        self.setLayout(QVBoxLayout())
        self._container = PlotContainer(self)
        self.layout().addWidget(self._container)

        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)

        self._create_plot_controls()
        self.layout().addWidget(close_btn)
        self._data_view = None
        # True while a slider is moved by code rather than by the user.
        self._software_slide = False
        self.plotter = None

    def _create_plot_controls(self):
        """Create the pan and zoom sliders and add them to the layout."""
        plot_controls = QGroupBox()
        plot_controls.setFlat(True)
        plot_controls.setLayout(QHBoxLayout())
        plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        self._zoom_slider = QSlider(Qt.Horizontal)
        self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self._zoom_slider.setFixedWidth(120)
        self._zoom_slider.setFixedHeight(20)
        self._zoom_slider.setTickPosition(QSlider.TicksBelow)
        # The range has to be set first, otherwise the position is clamped to
        # the default range.
        self._zoom_slider.setMinimum(1)
        self._zoom_slider.setMaximum(self._SLIDER_STEPS)
        self._zoom_slider.setSliderPosition(500)
        self._zoom_slider.setTickInterval(250)
        self._zoom_slider.valueChanged.connect(self.on_zoom)

        self._pan_slider = QSlider(Qt.Horizontal)
        self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._pan_slider.setFixedHeight(20)
        self._pan_slider.setTickPosition(QSlider.TicksBelow)
        self._pan_slider.setMinimum(1)
        self._pan_slider.setMaximum(self._SLIDER_STEPS)
        self._pan_slider.setSliderPosition(0)
        self._pan_slider.setTickInterval(250)
        self._pan_slider.valueChanged.connect(self.on_pan)

        pl = QLabel("horiz. pos.:")
        pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        pl.setMaximumWidth(150)
        zl = QLabel("zoom:")
        zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        zl.setMaximumWidth(75)

        controls_layout = plot_controls.layout()
        controls_layout.addWidget(pl)
        controls_layout.addWidget(self._pan_slider)
        controls_layout.addWidget(zl)
        controls_layout.addWidget(QLabel("+"))
        controls_layout.addWidget(self._zoom_slider)
        controls_layout.addWidget(QLabel("-"))
        self.layout().addWidget(plot_controls)

    def on_close(self):
        """Emit ``close_signal``."""
        self.close_signal.emit()

    def on_zoom(self, new_position):
        """Redraw with the new zoom level unless the slider was moved by code."""
        if self.plotter is None:
            return
        if self._software_slide:
            self._software_slide = False
            return
        self.plotter.horizontal_pan_to_position(self._pan_slider.sliderPosition() / 1000,
                                                self._zoom_slider.sliderPosition() / 1000)

    def on_pan(self, new_position):
        """Redraw at the new pan position unless the slider was moved by code."""
        if self._software_slide:
            self._software_slide = False
            return
        if self.plotter is None:
            # The slider can be moved before anything has been plotted.
            return
        self.plotter.horizontal_pan_to_position(new_position / 1000.,
                                                self._zoom_slider.sliderPosition() / 1000.)

    def on_view_changed(self):
        """Enable the pan slider only if the plotter can be panned."""
        if self.plotter is None:
            return
        self._pan_slider.setEnabled(self.plotter.can_pan_horizontally)

    def _initial_zoom_position(self, data_length):
        """Zoom slider position at which about _INITIAL_VISIBLE_POINTS are shown."""
        if data_length <= 0:
            return self._SLIDER_STEPS
        position = int(round(self._SLIDER_STEPS * self._INITIAL_VISIBLE_POINTS
                             / data_length))
        # Keep the position inside the slider range; position 0 would result
        # in an empty plot.
        return max(1, min(self._SLIDER_STEPS, position))

    def plot_data_array(self, item):
        """Create the plotter suggested by ``item`` and plot its data.

        Emits ``communicator.plot_error`` if no data view can be created for
        the item. For unknown plotter types the plot area is cleared.
        """
        try:
            self._data_view = DataView(item, self._file_handler)
        except ValueError as e:
            print("ping!")
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return

        if item.suggested_plotter == PlotterTypes.LinePlotter:
            self._zoom_slider.setEnabled(True)
            self._pan_slider.setEnabled(True)
            zoom_slider_position = self._initial_zoom_position(
                self._data_view.full_shape[item.best_xdim])
            # Suppress the replot the slider would otherwise trigger.
            self._software_slide = True
            try:
                self._zoom_slider.setSliderPosition(zoom_slider_position)
                self.plotter = LinePlotter(self._file_handler, item, self._data_view)
                self.plotter.view_changed.connect(self.on_view_changed)
                self._container.set_plotter(self.plotter)
                self.plotter.plot(zoomlevel=zoom_slider_position / 1000.)
            finally:
                self._software_slide = False
            return

        # Plotters that always show the complete data; pan and zoom are off.
        static_plotters = {
            PlotterTypes.ImagePlotter: ImagePlotter,
            PlotterTypes.EventPlotter: EventPlotter,
            PlotterTypes.CategoryPlotter: CategoryPlotter,
        }
        plotter_class = static_plotters.get(item.suggested_plotter)
        if plotter_class is None:
            self.plotter = None
            self._container.set_plotter(None)
            return

        self._zoom_slider.setEnabled(False)
        self._pan_slider.setEnabled(False)
        self.plotter = plotter_class(self._file_handler, item, self._data_view)
        self._container.set_plotter(self.plotter)
        self.plotter.plot()

    def plot_tag(self, item):
        """Plotting of tags is not implemented yet."""
        pass

    def plot(self, item):
        """Plot ``item`` according to its entity type.

        Emits ``communicator.plot_error`` for entity types that cannot be
        plotted.
        """
        if item.entity_type == NodeType.DataArray:
            self.plot_data_array(item)
        elif item.entity_type in (NodeType.MultiTag, NodeType.Tag):
            self.plot_tag(item)
        else:
            communicator.plot_error.emit("error in plotscreen.plot cannot plot entity of type%s"
                                         % item.entity_type)
            return