510 lines
18 KiB
Python
510 lines
18 KiB
Python
from nixview.util.enums import PlotterTypes
|
|
from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
|
|
from PyQt5.QtCore import pyqtSignal, Qt
|
|
import matplotlib
|
|
matplotlib.use('Qt5Agg')
|
|
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas, NavigationToolbar2QT as NavigationToolbar
|
|
from matplotlib.figure import Figure
|
|
import nixio as nix
|
|
import numpy as np
|
|
try:
|
|
import matplotlib.pyplot as plt
|
|
except ImportError as e:
|
|
print("cannot omport matplotlib", e)
|
|
from matplotlib.widgets import Slider
|
|
from IPython import embed
|
|
|
|
from nixview.util.file_handler import FileHandler
|
|
from nixview.util.dataview import DataView
|
|
from nixview.communicator import communicator
|
|
|
|
|
|
def create_label(item):
    """Build an axis label of the form ``"<label or name> [<unit>]"``.

    The item's ``label`` attribute is preferred; if it is missing, ``None``
    or empty, the item's ``name`` is used instead.  When the item has a
    ``unit`` that is not ``None`` it is appended in square brackets.

    Args:
        item: any object; the attributes ``label``, ``name`` and ``unit``
            are all optional.

    Returns:
        str: the assembled label, possibly empty.
    """
    label = getattr(item, "label", None) or ""
    if not label:
        # A name of None used to raise TypeError on string concatenation;
        # treat it like a missing name instead.
        label = getattr(item, "name", None) or ""
    unit = getattr(item, "unit", None)
    if unit is not None:
        label += " [%s]" % unit
    return label
|
|
|
|
|
|
class MplCanvas(FigureCanvas):
    """Matplotlib figure canvas with a single subplot, usable as a Qt widget.

    The subplot is available as ``self.axis``.  Figure/axes enter and leave
    events as well as pick events are connected to handlers that currently
    only print debugging information.
    """

    # Signal without arguments; LinePlotter emits it after (re)plotting.
    view_changed = pyqtSignal()

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Create the figure, its single axis and connect the event handlers.

        Args:
            parent: accepted for API symmetry with Qt widgets; it is not
                forwarded to the canvas constructor.
            width: figure width passed to ``Figure(figsize=...)``.
            height: figure height passed to ``Figure(figsize=...)``.
            dpi: figure resolution passed to ``Figure``.
        """
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axis = figure.add_subplot(111)
        super(MplCanvas, self).__init__(figure)
        self._figure = figure

        connect = self._figure.canvas.mpl_connect
        connect('figure_enter_event', self.on_enter_figure)
        connect('figure_leave_event', self.on_leave_figure)
        connect('axes_enter_event', self.on_enter_axes)
        connect('axes_leave_event', self.on_leave_axes)
        connect('pick_event', self.on_pick)

    def on_enter_figure(self, event):
        """Debug handler: report that the pointer entered the figure."""
        print('enter_figure', event.canvas.figure)

    def on_leave_figure(self, event):
        """Debug handler: report that the pointer left the figure."""
        print('leave_figure', event.canvas.figure)

    def on_enter_axes(self, event):
        """Debug handler: report that the pointer entered an axes."""
        print('enter_axes', event.inaxes)

    def on_leave_axes(self, event):
        """Debug handler: report that the pointer left an axes."""
        print('leave_axes', event.inaxes)

    def on_pick(self, event):
        """Debug handler: print the label and coordinates of picked points.

        Assumes the picked artist offers ``get_data()`` (as the lines
        created with ``picker=...`` in this file do).
        """
        line = event.artist
        # Artists expose their label through get_label(); there is no
        # public ``label`` attribute, so ``line.label`` raised AttributeError.
        print(line.get_label())
        xdata, ydata = line.get_data()
        ind = event.ind
        print('on pick line:', np.array([xdata[ind], ydata[ind]]).T)
|
|
|
|
#def clear(self):
|
|
# self.clear()
|
|
|
|
#@property
|
|
#def enter_figure(self):
|
|
# return self._fig
|
|
|
|
|
|
class Plotter(MplCanvas):
    """Base class of the plotters: a canvas bound to an item and its data.

    Stores the file handler, the item to be plotted and the data view, and
    declares the view-related interface that concrete plotters implement.
    """

    def __init__(self, file_handler, item, data_view, parent=None) -> None:
        """Initialise the canvas and remember the data sources.

        Args:
            file_handler: object used by subclasses to request dimension
                and axis information.
            item: descriptor of the entity that is plotted.
            data_view: view on the item's data.
            parent: forwarded to ``MplCanvas``.
        """
        super().__init__(parent=parent)
        self._file_handler = file_handler
        self._item = item
        self._dataview = data_view

    def current_view(self):
        """Return the currently displayed view; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("current_view is not implemented on the current plotter")

    # Declared as a property so that it matches LinePlotter.is_full_view and
    # the attribute-style access in PlotScreen.on_view_changed.  As a plain
    # method, ``if plotter.is_full_view:`` would always be truthy.
    @property
    def is_full_view(self):
        """Whether the complete data range is displayed; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("is_full_view is not implemented on the current plotter")
|
|
|
|
|
|
|
|
class EventPlotter(Plotter):
    """Plots the values of a one-dimensional data array as event markers.

    All values are drawn as scatter points on a single horizontal line.

    NOTE(review): the constructor takes a data array and does not call
    ``Plotter.__init__`` (which expects file handler, item and data view),
    so instances are not initialised as canvases.  Left untouched because
    the required arguments are not available here -- confirm intent.
    """

    def __init__(self, data_array, xdim=-1):
        """Remember the array and the dimension used for the x-axis.

        Args:
            data_array: array-like object providing ``dimensions``, ``name``
                and slicing (``data_array[:]``).
            xdim: index of the dimension describing the x-axis; ``-1``
                selects dimension 0, the only one a 1-D array has.
        """
        self.array = data_array
        self.sc = None
        self.dim_count = len(data_array.dimensions)
        # The original assignment of ``self.xdim`` was disabled inside a
        # string literal, which made plot_1d fail with AttributeError.
        # ``guess_best_xdim`` is not available in this module, so fall back
        # to dimension 0 when no explicit dimension is requested.
        self.xdim = 0 if xdim == -1 else xdim

    def plot(self, axis=None):
        """Plot the array into ``axis`` or into a newly created figure.

        Args:
            axis: matplotlib axes to draw into; if ``None`` a new pyplot
                figure with one axes titled with the array name is created.

        Returns:
            The axes that was drawn into, or ``None`` if the array does not
            have exactly one dimension.
        """
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis

        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        return None

    def plot_1d(self):
        """Draw all values as scatter points at y == 1 and label the axes.

        Returns:
            The axes that was drawn into.
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        xlabel = create_label(dim)
        # Only a non-alias range dimension gets a y-label built from the
        # array itself; otherwise the y-axis stays unlabelled.
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""

        self.sc = self.axis.scatter(data, np.ones(data.shape))
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis
|
|
|
|
|
|
class CategoryPlotter(Plotter):
    """Plots one- or two-dimensional category data as (grouped) bar charts.

    NOTE(review): the constructor takes a data array and does not call
    ``Plotter.__init__`` (which expects file handler, item and data view),
    so instances are not initialised as canvases.  Left untouched because
    the required arguments are not available here -- confirm intent.
    """

    def __init__(self, data_array, xdim=-1):
        """Remember the array and the dimension that holds the categories.

        Args:
            data_array: array-like object providing ``dimensions``, ``name``
                and slicing (``data_array[:]``).
            xdim: index of the category dimension; ``-1`` falls back to
                dimension 0.
        """
        self.array = data_array
        self.bars = []
        # The original assignment of ``self.xdim`` was disabled inside a
        # string literal, so every plot call failed with AttributeError.
        # ``guess_best_xdim`` is not available in this module; dimension 0
        # is used as default -- TODO confirm this is the desired fallback
        # for 2-D data.
        self.xdim = 0 if xdim == -1 else xdim

    def plot(self, axis=None):
        """Plot the array into ``axis`` or into a newly created figure.

        Args:
            axis: matplotlib axes to draw into; if ``None`` a new pyplot
                figure with one axes titled with the array name is created.

        Returns:
            The axes that was drawn into, or ``None`` if the array has
            neither one nor two dimensions (or, for 1-D data, if its
            dimension is not a set dimension).
        """
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis

        dim_count = len(self.array.dimensions)
        if dim_count == 1:
            return self.plot_1d()
        if dim_count == 2:
            return self.plot_2d()
        return None

    def plot_1d(self):
        """Draw one bar per category of a set dimension.

        Returns:
            The axes that was drawn into, or ``None`` if the dimension is
            not a set dimension.

        Raises:
            ValueError: if the set dimension carries no labels.
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None

        categories = list(dim.labels)
        ylabel = create_label(self.array)
        if len(categories) == 0:
            raise ValueError("Cannot plot a bar chart without any labels")

        self.bars.append(self.axis.bar(range(1, len(categories) + 1), data,
                                       tick_label=categories))
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_2d(self):
        """Draw grouped bars: one group per category, one bar per series.

        Returns:
            The axes that was drawn into.
        """
        data = self.array[:]
        if self.xdim == 1:
            data = data.T
        # From here on categories always run along axis 0 of ``data`` and
        # series along axis 1, independent of ``self.xdim``.
        category_count, series_count = data.shape[0], data.shape[1]

        categories = []
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            categories = list(dim.labels)
        if len(categories) == 0:
            categories = ["Cat-%i" % i for i in range(category_count)]

        series_names = []
        dim = self.array.dimensions[1 - self.xdim]
        if dim.dimension_type == nix.DimensionType.Set:
            series_names = list(dim.labels)
        if len(series_names) == 0:
            series_names = ["Series-%i" % i for i in range(series_count)]

        bar_width = 1 / series_count * 0.75
        group_positions = np.arange(category_count)
        for i in range(series_count):
            x_values = group_positions + i * bar_width
            # Only the first rectangle of each series is kept; it serves as
            # the legend handle for that series.
            self.bars.append(self.axis.bar(x_values, data[:, i],
                                           width=bar_width,
                                           align="center")[0])
        # Bars are centre-aligned at n + i * bar_width, so the middle of a
        # group lies at n + (series_count - 1) * bar_width / 2.
        self.axis.set_xticks(group_positions +
                             (series_count - 1) * bar_width / 2)
        self.axis.set_xticklabels(categories)
        self.axis.legend(self.bars, series_names, loc=1)
        return self.axis
|
|
|
|
|
|
class ImagePlotter(Plotter):
    """Displays two- or three-dimensional array data as an image.

    NOTE(review): the constructor takes a data array and does not call
    ``Plotter.__init__`` (which expects file handler, item and data view),
    so instances are not initialised as canvases.  Left untouched because
    the required arguments are not available here -- confirm intent.
    """

    def __init__(self, data_array, xdim=-1):
        """Remember the array to be shown.

        Args:
            data_array: array-like object providing ``dimensions``, ``name``,
                ``shape`` and slicing (``data_array[:]``).
            xdim: accepted for signature parity with the other plotters;
                it is not used.
        """
        self.array = data_array
        self.image = None

    def plot(self, axis=None):
        """Plot the array into ``axis`` or into a newly created figure.

        Args:
            axis: matplotlib axes to draw into; if ``None`` a new pyplot
                figure with one axes titled with the array name is created.

        Returns:
            The axes that was drawn into, or ``None`` if the data cannot be
            displayed as an image.
        """
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis

        if dim_count == 2:
            return self.plot_2d()
        if dim_count == 3:
            return self.plot_3d()
        return None

    def plot_2d(self):
        """Show the data with ``imshow`` and label both axes.

        The image extent is taken from the first and last axis values of
        dimensions 0 and 1.

        Returns:
            The axes that was drawn into.
        """
        data = self.array[:]
        dim_x, dim_y = self.array.dimensions[0], self.array.dimensions[1]
        x = dim_x.axis(data.shape[0])
        y = dim_y.axis(data.shape[1])
        self.image = self.axis.imshow(data, extent=[x[0], x[-1], y[0], y[-1]])
        self.axis.set_xlabel(create_label(dim_x))
        self.axis.set_ylabel(create_label(dim_y))
        # A stray, incomplete ``self.axis.set`` statement (a no-op attribute
        # access) used to follow here and has been removed.
        return self.axis

    def plot_3d(self):
        """Show 3-D data as image if the third dimension has at most 3 entries.

        Returns:
            The axes that was drawn into, or ``None`` if the third
            dimension holds more than three channels.
        """
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()
|
|
|
|
|
|
class LinePlotter(Plotter):
    """Draws one- or two-dimensional data of an item as line plot(s).

    The y-values are read from the data view's buffer, the matching x-values
    (and, for 2-D data, the line labels) are requested from the file handler.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        """Set up the plotter state and connect the axis-limit callbacks.

        Args:
            file_handler: object providing ``request_dimensions`` and
                ``request_axis``.
            item: descriptor of the plotted entity; ``block_id``, ``id``,
                ``name`` and ``best_xdim`` are read from it.
            data_view: view on the item's data; ``full_shape``,
                ``current_shape``, ``fully_loaded`` and ``_buffer`` are read.
            xdim: index of the dimension used as x-axis; ``-1`` selects
                ``item.best_xdim``.
            parent: forwarded to ``Plotter``.

        Raises:
            ValueError: if ``xdim`` is larger than 2.
        """
        super().__init__(file_handler, item, data_view, parent)

        self.dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
        # Smallest / largest x-value drawn so far, see _set_xlims.
        self._min_x = None
        self._max_x = None
        # NOTE(review): ``min_x`` is never read in this file and looks like
        # a leftover; kept in case external code accesses it.
        self.min_x = None
        self.lines = []
        self.dim_count = len(self._dataview.full_shape)
        if xdim == -1:
            self.xdim = item.best_xdim
        elif xdim > 2:
            # NOTE(review): xdim == 2 passes this check although only the
            # dimensions 0 and 1 are handled below -- confirm the bound.
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim
        self.axis.callbacks.connect('xlim_changed', self.on_xlims_change)
        self.axis.callbacks.connect('ylim_changed', self.on_ylims_change)

    def on_xlims_change(self, event_ax):
        """Debug handler: print the new x-limits of the axes."""
        print("updated xlims: ", event_ax.get_xlim())

    def on_ylims_change(self, event_ax):
        """Debug handler: print the new y-limits of the axes."""
        print("updated ylims: ", event_ax.get_ylim())

    def current_view(self):
        """Return the current view description (currently always empty)."""
        cv = []
        return cv

    @property
    def is_full_view(self):
        """True if the x-limits span all drawn data and all data is loaded."""
        xlims = self.axis.get_xlim()
        spans_all = self._min_x == xlims[0] and self._max_x == xlims[-1]
        return spans_all and self._dataview.fully_loaded

    def on_zoom_in(self, new_position):
        """Debug handler: report a zoom-in request."""
        print("plotter ZOOM In!", new_position)

    def on_zoom_out(self, new_position):
        """Debug handler: report a zoom-out request."""
        print("plotter ZOOM out!", new_position)

    def plot(self, maxpoints=100000):
        """Plot up to ``maxpoints`` points along the x-dimension.

        Data with more than two dimensions is silently ignored.

        Args:
            maxpoints: upper bound of the index range that is drawn.
        """
        self.maxpoints = maxpoints
        if self.dim_count > 2:
            return
        if self.dim_count == 1:
            self.plot_array_1d()
        else:
            self.plot_array_2d()

    def __draw(self, start, end):
        """Dispatch drawing of the index range to the 1-D or 2-D routine."""
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    def _set_xlims(self, data_xmin, data_xmax):
        """Widen the remembered x-range so that it includes the given one."""
        if self._min_x is None or data_xmin < self._min_x:
            self._min_x = data_xmin
        if self._max_x is None or data_xmax > self._max_x:
            self._max_x = data_xmax

    def __draw_1d(self, start, end):
        """Draw (or update) the single line for the index range [start, end).

        The range is clamped to the currently available data.  Nothing is
        drawn if the resulting range is empty.
        """
        if start < 0:
            start = 0
        available = self._dataview.current_shape[self.xdim]
        if end > available:
            end = available

        y = self._dataview._buffer[int(start):int(end)]
        # presumably request_axis(block, array, dimension index, count,
        # start index) -- verify against FileHandler
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, len(y), start)
        if len(x) == 0:
            # Without this guard x[0] below raised IndexError for empty data.
            return
        self._set_xlims(x[0], x[-1])
        if len(self.lines) == 0:
            line, = self.axis.plot(x, y, label=self._item.name, picker=5)
            self.lines.append(line)
        else:
            self.lines[0].set_ydata(y)
            self.lines[0].set_xdata(x)

        self.axis.set_xlim([x[0], x[-1]])

    def __draw_2d(self, start, end):
        """Draw (or update) one line per entry of the non-x dimension.

        The range [start, end) along the x-dimension is clamped to the
        currently available data.  Nothing is drawn if the resulting range
        is empty.
        """
        if start < 0:
            start = 0
        available = self._dataview.current_shape[self.xdim]
        if end > available:
            end = available

        x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, int(end-start), start)
        if len(x) == 0:
            # Without this guard x[0] below raised IndexError for empty data.
            return
        line_count = self._dataview.current_shape[1 - self.xdim]
        line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1-self.xdim, line_count, 0)
        self._set_xlims(x[0], x[-1])

        first, last = int(start), int(end)
        for i, line_label in enumerate(line_labels):
            # Slice along whichever axis of the buffer is the x-dimension.
            if self.xdim == 0:
                y = self._dataview._buffer[first:last, i]
            else:
                y = self._dataview._buffer[i, first:last]

            if len(self.lines) <= i:
                new_line, = self.axis.plot(x, y, label=line_label, picker=5)
                self.lines.append(new_line)
            else:
                self.lines[i].set_ydata(y)
                self.lines[i].set_xdata(x)

        self.axis.set_xlim([x[0], x[-1]])

    def plot_array_1d(self):
        """Draw 1-D data, label the axes and emit ``view_changed``."""
        self.__draw_1d(0, self.maxpoints)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.view_changed.emit()

    def plot_array_2d(self):
        """Draw 2-D data, label the axes, add a legend, emit ``view_changed``."""
        self.__draw_2d(0, self.maxpoints)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.axis.legend(loc=1)
        self.view_changed.emit()
|
|
|
|
|
|
class PlotContainer(QWidget):
    """Widget with a vertical layout that hosts the current plotter widget."""

    def __init__(self, parent=None) -> None:
        """Create the empty container.

        Args:
            parent: optional parent widget.
        """
        super().__init__(parent=parent)
        self.setLayout(QVBoxLayout())
        self.plotter = None

    def set_plotter(self, plotter):
        """Show ``plotter`` in the container, replacing the previous one.

        Args:
            plotter: the widget to display.
        """
        previous = self.plotter
        if previous is not None:
            self.layout().removeWidget(previous)
            if previous is not plotter:
                # removeWidget only takes the widget out of the layout; it
                # would remain a visible child of this container.  Detaching
                # it from its parent also hides it.
                previous.setParent(None)
        self.layout().addWidget(plotter)
        self.plotter = plotter
|
|
|
|
|
|
class PlotScreen(QWidget):
    """Screen combining the plot area, pan/zoom controls and a close button."""

    # Emitted when the close button is clicked.
    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        """Build the screen layout: plot container, controls, close button.

        Args:
            parent: optional parent widget.
        """
        super().__init__(parent=parent)
        self._file_handler = FileHandler()
        # No plotter exists until plot() is called; the slider handlers
        # must cope with that (previously an AttributeError).
        self.plotter = None

        self.setLayout(QVBoxLayout())
        self._container = PlotContainer(self)
        self.layout().addWidget(self._container)

        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)

        self._create_plot_controls()
        self.layout().addWidget(close_btn)
        self._data_view = None

        self.zoom_position = 0

    def _create_plot_controls(self):
        """Create the pan and zoom sliders and add them to the layout."""
        plot_controls = QGroupBox()
        plot_controls.setFlat(True)
        plot_controls.setLayout(QHBoxLayout())
        plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        self._zoom_slider = QSlider(Qt.Horizontal)
        self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self._zoom_slider.setFixedWidth(120)
        self._zoom_slider.setFixedHeight(20)
        self._zoom_slider.setTickPosition(QSlider.TicksBelow)
        self._zoom_slider.setMinimum(0)
        self._zoom_slider.setMaximum(100)
        self._zoom_slider.setSliderPosition(50)
        self._zoom_slider.setTickInterval(25)
        # Connected last so that the initial configuration does not
        # trigger the handler.
        self._zoom_slider.valueChanged.connect(self.on_zoom)

        self._pan_slider = QSlider(Qt.Horizontal)
        self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._pan_slider.setFixedHeight(20)
        self._pan_slider.setTickPosition(QSlider.TicksBelow)
        self._pan_slider.setMinimum(0)
        self._pan_slider.setMaximum(100)
        self._pan_slider.setSliderPosition(0)
        self._pan_slider.setTickInterval(25)
        self._pan_slider.valueChanged.connect(self.on_pan)

        pl = QLabel("horiz. pos.:")
        pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        pl.setMaximumWidth(150)
        zl = QLabel("zoom:")
        zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        zl.setMaximumWidth(75)

        controls_layout = plot_controls.layout()
        controls_layout.addWidget(pl)
        controls_layout.addWidget(self._pan_slider)
        controls_layout.addWidget(zl)
        # "+" sits at the low end of the zoom slider and "-" at the high end.
        controls_layout.addWidget(QLabel("+"))
        controls_layout.addWidget(self._zoom_slider)
        controls_layout.addWidget(QLabel("-"))
        self.layout().addWidget(plot_controls)

    def on_close(self):
        """Forward a click on the close button as ``close_signal``."""
        self.close_signal.emit()

    def on_zoom(self, new_position):
        """Translate a zoom-slider move into a zoom request on the plotter.

        The slider's high end is the zoomed-out side (it is labelled "-" and
        position 100 is used for the full view), so a growing position means
        zooming out and a shrinking one zooming in.

        Args:
            new_position: the new slider value.
        """
        if self.plotter is not None:
            if self.zoom_position < new_position:
                self.plotter.on_zoom_out(new_position)
            else:
                # Both branches used to call on_zoom_out.
                self.plotter.on_zoom_in(new_position)
        self.zoom_position = new_position

    def on_pan(self, new_position):
        """Debug handler: print the new pan-slider position."""
        print("pan", new_position)

    def on_view_changed(self):
        """Synchronise the sliders with the plotter after its view changed."""
        print("view changed!")
        print(self.plotter.current_view())
        full_view = self.plotter.is_full_view
        if full_view:
            self._zoom_slider.setSliderPosition(100)
        # Panning only makes sense if not everything is visible already.
        self._pan_slider.setEnabled(not full_view)

    def plot(self, item):
        """Create a data view for ``item`` and plot it.

        Only items whose suggested plotter is the line plotter are drawn.
        A ``ValueError`` raised while creating the data view is reported via
        ``communicator.plot_error`` and aborts plotting.

        Args:
            item: descriptor of the entity to plot; ``suggested_plotter``
                is read from it.
        """
        try:
            self._data_view = DataView(item, self._file_handler)
        except ValueError as e:
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return
        if self._data_view is None:
            return

        if item.suggested_plotter == PlotterTypes.LinePlotter:
            self.plotter = LinePlotter(self._file_handler, item, self._data_view)
            self.plotter.view_changed.connect(self.on_view_changed)
            self._container.set_plotter(self.plotter)
            self.plotter.plot(maxpoints=10000)
            self._zoom_slider.setSliderPosition(100)
            self.zoom_position = 100
|
|
self.zoom_position = 100
|
|
|