from PyQt5.QtWidgets import QHBoxLayout, QPushButton, QVBoxLayout, QWidget
from PyQt5.QtCore import QObject, pyqtSignal, Qt
import matplotlib
matplotlib.use('Qt5Agg')  # must run before the Qt backend / pyplot imports below
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import nixio as nix

# NOTE(review): ``guess_best_xdim`` is used by several plotters below but is
# not defined in this view; it is assumed to be provided elsewhere in the module.


def create_label(entity):
    """Build an axis label such as ``"name [unit]"`` for a nix entity.

    Uses ``entity.label`` when present and non-empty, otherwise falls back to
    ``entity.name``; appends ``" [unit]"`` when the entity has a non-None unit.
    """
    label = ""
    if hasattr(entity, "label"):
        label += (entity.label if entity.label is not None else "")
    if len(label) == 0 and hasattr(entity, "name"):
        label += entity.name
    if hasattr(entity, "unit") and entity.unit is not None:
        label += " [%s]" % entity.unit
    return label


def _set_labels(dim):
    """Return the labels of a (presumably Set) dimension as a list, [] if None."""
    labels = dim.labels
    return [] if labels is None else list(labels)


class Plotter(object):
    """Common base class for the data-array plotters."""

    def show(self):
        """Show all open pyplot figures."""
        plt.show()


class EventPlotter(Plotter):
    """Plot 1-D event data as a row of markers along the x dimension."""

    def __init__(self, data_array, xdim=-1):
        """
        :param data_array: the array to plot (presumably a nixio DataArray).
        :param xdim: index of the dimension used as x axis; -1 lets
            ``guess_best_xdim`` choose.
        :raises ValueError: if ``xdim`` is larger than 1.
        """
        self.array = data_array
        self.sc = None
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 1:
            raise ValueError("EventPlotter: xdim is larger than 1! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot into ``axis``, or into a new figure if ``axis`` is None.

        :returns: the axis used, or None if the data is not 1-D.
        """
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        return None

    def plot_1d(self):
        """Scatter every value of the array at y == 1."""
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        xlabel = create_label(dim)
        # Only a non-alias range dimension gets the array itself as y label.
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""
        self.sc = self.axis.scatter(data, np.ones(data.shape))
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis


class CategoryPlotter(Plotter):
    """Plot 1-D or 2-D categorical data as (grouped) bar charts."""

    def __init__(self, data_array, xdim=-1):
        """
        :param data_array: the array to plot (presumably a nixio DataArray).
        :param xdim: index of the category dimension; -1 lets
            ``guess_best_xdim`` choose.
        :raises ValueError: if ``xdim`` is larger than 2.
        """
        self.array = data_array
        self.bars = []
        if xdim == -1:
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("CategoryPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot into ``axis`` (or a new figure); None for unsupported ranks."""
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        dim_count = len(self.array.dimensions)
        if dim_count == 1:
            return self.plot_1d()
        if dim_count == 2:
            return self.plot_2d()
        return None

    def plot_1d(self):
        """Draw one bar per category; None if the x dimension is not a Set.

        :raises ValueError: if there is nothing to plot (no categories).
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None
        categories = _set_labels(dim)
        if len(categories) == 0:
            # Unlabelled set dimension: fall back to generated names.
            categories = ["Cat-%i" % i for i in range(len(data))]
        ylabel = create_label(self.array)
        if len(categories) == 0:
            raise ValueError("Cannot plot a bar chart without any labels")
        self.bars.append(self.axis.bar(range(1, len(categories) + 1), data,
                                       tick_label=categories))
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_2d(self):
        """Draw a grouped bar chart: one group per category, one bar per series."""
        data = self.array[:]
        if self.xdim == 1:
            # Bring the category axis first: rows = categories, cols = series.
            data = data.T
        n_categories, n_series = data.shape[0], data.shape[1]

        categories = []
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            categories = _set_labels(dim)
        if len(categories) == 0:
            categories = ["Cat-%i" % i for i in range(n_categories)]

        series_names = []
        dim = self.array.dimensions[1 - self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            series_names = _set_labels(dim)
        if len(series_names) == 0:
            series_names = ["Series-%i" % i for i in range(n_series)]

        bar_width = 1 / n_series * 0.75
        handles = []
        for i in range(n_series):
            x_values = np.arange(n_categories) + i * bar_width
            handles.append(self.axis.bar(x_values, data[:, i], width=bar_width,
                                         align="center")[0])
        self.bars.extend(handles)
        # Bars are centred at x + i*w (i = 0..n-1), so the group centre is
        # x + (n-1)*w/2.
        self.axis.set_xticks(np.arange(n_categories)
                             + (n_series - 1) * bar_width / 2)
        self.axis.set_xticklabels(categories)
        self.axis.legend(handles, series_names, loc=1)
        return self.axis


class ImagePlotter(Plotter):
    """Plot 2-D data (or 3-D data with up to 3 channels) as an image."""

    def __init__(self, data_array, xdim=-1):
        """
        :param data_array: the array to plot (presumably a nixio DataArray).
        :param xdim: accepted for interface parity with the other plotters;
            not used.
        """
        self.array = data_array
        self.image = None

    def plot(self, axis=None):
        """Plot into ``axis`` (or a new figure); None for unsupported ranks."""
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if dim_count == 2:
            return self.plot_2d()
        if dim_count == 3:
            return self.plot_3d()
        return None

    def plot_2d(self):
        """Show the data with dimension 0 along x and dimension 1 along y."""
        data = self.array[:]
        x = self.array.dimensions[0].axis(data.shape[0])
        y = self.array.dimensions[1].axis(data.shape[1])
        xlabel = create_label(self.array.dimensions[0])
        ylabel = create_label(self.array.dimensions[1])
        # imshow draws the first array axis vertically; swap the first two
        # axes (keeping any trailing channel axis) so dim 0 runs along x, and
        # use origin="lower" so index 0 sits at y[0] as the extent states.
        image_data = np.swapaxes(data, 0, 1)
        self.image = self.axis.imshow(image_data, origin="lower",
                                      extent=[x[0], x[-1], y[0], y[-1]])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_3d(self):
        """Plot 3-D data as a colour image; None if there are > 3 channels."""
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()


class LinePlotter(Plotter):
    """Plot 1-D or 2-D data as lines, showing at most ``maxpoints`` samples.

    When plotting into a new figure a slider is added to scroll the visible
    window through the data.
    """

    def __init__(self, data_array, xdim=-1):
        """
        :param data_array: the array to plot (presumably a nixio DataArray).
        :param xdim: index of the x dimension; -1 lets ``guess_best_xdim``
            choose.
        :raises ValueError: if ``xdim`` is larger than 2.
        """
        self.array = data_array
        self.lines = []
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim
        self.fig = None
        self.axis = None
        self.slider = None

    def plot(self, axis=None, maxpoints=100000):
        """Plot the first ``maxpoints`` samples into ``axis`` (or a new figure).

        :returns: the axis used, or None for data with more than 2 dimensions.
        """
        self.maxpoints = maxpoints
        # Lines belong to a specific axis; start fresh so a second call does
        # not update the lines of a previous plot.
        self.lines = []
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
            self.__add_slider()
        else:
            self.fig = axis.figure
            self.axis = axis
        dim_count = len(self.array.dimensions)
        if dim_count > 2:
            return None
        if dim_count == 1:
            return self.plot_array_1d()
        return self.plot_array_2d()

    def __add_slider(self):
        """Add a slider that selects which ``maxpoints`` window is shown."""
        steps = self.array.shape[self.xdim] / self.maxpoints
        if steps <= 1.:
            # All data fits into the first window; a slider would have
            # valmax <= valmin.
            self.slider = None
            return
        slider_ax = self.fig.add_axes([0.15, 0.025, 0.8, 0.025])
        self.slider = Slider(slider_ax, 'Slider', 1., steps, valinit=1.,
                             valstep=0.25)
        self.slider.on_changed(self.__update)

    def __update(self, val):
        """Slider callback: show the window ending at ``val * maxpoints``."""
        if len(self.lines) > 0:
            start = max(val * self.maxpoints - self.maxpoints, 0)
            end = val * self.maxpoints
            self.__draw(start, end)
            # Rescale y to the newly visible data (x is set explicitly).
            self.axis.relim()
            self.axis.autoscale_view(scalex=False)
            self.fig.canvas.draw_idle()

    def __draw(self, start, end):
        """Dispatch drawing of the [start, end) window by data rank."""
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    def __window(self, start, end):
        """Return [start, end) as integer indices clipped to the x dimension."""
        length = self.array.shape[self.xdim]
        return max(int(start), 0), min(int(end), length)

    def __draw_1d(self, start, end):
        """Draw (or update) the single line for the [start, end) window."""
        start, end = self.__window(start, end)
        y = self.array[start:end]
        dim = self.array.dimensions[self.xdim]
        x = np.asarray(dim.axis(len(y), start))
        if len(self.lines) == 0:
            line, = self.axis.plot(x, y, label=self.array.name)
            self.lines.append(line)
        else:
            self.lines[0].set_data(x, y)
        if len(x) > 0:
            self.axis.set_xlim([x[0], x[-1]])

    def __draw_2d(self, start, end):
        """Draw (or update) one line per series for the [start, end) window."""
        start, end = self.__window(start, end)
        x_dimension = self.array.dimensions[self.xdim]
        # Same integer window as the slices below, so x and y lengths agree.
        x = np.asarray(x_dimension.axis(end - start, start))
        y_dimension = self.array.dimensions[1 - self.xdim]
        # Only set dimensions carry per-entry labels; others get indices.
        labels = getattr(y_dimension, "labels", None)
        if labels is None or len(labels) == 0:
            labels = [str(i) for i in range(self.array.shape[1 - self.xdim])]
        for i, label in enumerate(labels):
            if self.xdim == 0:
                y = self.array[start:end, i]
            else:
                y = self.array[i, start:end]
            if len(self.lines) <= i:
                line, = self.axis.plot(x, y, label=label)
                self.lines.append(line)
            else:
                self.lines[i].set_data(x, y)
        if len(x) > 0:
            self.axis.set_xlim([x[0], x[-1]])

    def plot_array_1d(self):
        """Draw the first window of 1-D data and label the axes."""
        self.__draw_1d(0, self.maxpoints)
        xlabel = create_label(self.array.dimensions[self.xdim])
        ylabel = create_label(self.array)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_array_2d(self):
        """Draw the first window of 2-D data, label the axes, add a legend."""
        self.__draw_2d(0, self.maxpoints)
        xlabel = create_label(self.array.dimensions[self.xdim])
        ylabel = create_label(self.array)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.axis.legend(loc=1)
        return self.axis


class MplCanvas(FigureCanvasQTAgg):
    """Qt widget wrapping a matplotlib figure with a single subplot."""

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """
        :param parent: optional Qt parent widget.
        :param width: figure width in inches.
        :param height: figure height in inches.
        :param dpi: figure resolution.
        """
        fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = fig.add_subplot(111)
        super(MplCanvas, self).__init__(fig)
        if parent is not None:
            self.setParent(parent)


class PlotScreen(QWidget):
    """Widget showing a plot canvas and a "close" button.

    Emits ``close_signal`` when the close button is clicked.
    """

    close_signal = pyqtSignal()

    def __init__(self, parent) -> None:
        super().__init__(parent=parent)
        self.canvas = MplCanvas(self, width=5, height=4, dpi=100)
        # Placeholder demo data until ``plot`` is implemented.
        self.canvas.axes.plot([0, 1, 2, 3, 4], [10, 1, 20, 3, 40])
        self.setLayout(QVBoxLayout())
        self.layout().addWidget(self.canvas)
        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)
        self.layout().addWidget(close_btn)

    def on_close(self):
        """Forward the close-button click as ``close_signal``."""
        self.close_signal.emit()

    def plot(self, item):
        """Placeholder: currently only prints the requested item."""
        print("plot!", item)