nixview-python/nixview/ui/plotscreen.py

609 lines
23 KiB
Python

from nixview.util.enums import NodeType, PlotterTypes
from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
from PyQt5.QtCore import pyqtSignal, Qt
import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import numpy as np
try:
import matplotlib.pyplot as plt
except ImportError as e:
print("cannot import matplotlib, headless mode?", e)
from matplotlib.widgets import Slider
from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView
from nixview.communicator import communicator
def create_label(item):
    """Build an axis label of the form ``"<label or name> [<unit>]"``.

    The ``label`` attribute is preferred; when it is missing, ``None`` or
    empty, the ``name`` attribute is used instead. A ``unit`` attribute that
    is not ``None`` is appended in square brackets.

    Args:
        item: any object; the attributes ``label``, ``name`` and ``unit`` are
            all optional and may be ``None``.

    Returns:
        str: the assembled label, possibly an empty string.
    """
    label = getattr(item, "label", None) or ""
    if not label:
        # A missing or None name must not break label creation.
        label = getattr(item, "name", None) or ""
    unit = getattr(item, "unit", None)
    if unit is not None:
        label += " [%s]" % unit
    return label
class MplCanvas(FigureCanvas):
    """Qt canvas (matplotlib ``FigureCanvasQTAgg``) holding one figure with a single axes.

    The axes is available as ``self.axis``. The canvas registers handlers for
    the figure/axes enter and leave events and for pick events.

    Args:
        parent: accepted for API compatibility; it is not used by this class.
        width (float): figure width passed to ``Figure(figsize=...)``.
        height (float): figure height passed to ``Figure(figsize=...)``.
        dpi (int): figure resolution.
    """
    # Emitted by subclasses after the displayed data range has changed.
    view_changed = pyqtSignal()

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axis = figure.add_subplot(111)
        super(MplCanvas, self).__init__(figure)
        self._figure = figure
        self._figure.canvas.mpl_connect('figure_enter_event', self.on_enter_figure)
        self._figure.canvas.mpl_connect('figure_leave_event', self.on_leave_figure)
        self._figure.canvas.mpl_connect('axes_enter_event', self.on_enter_axes)
        self._figure.canvas.mpl_connect('axes_leave_event', self.on_leave_axes)
        self._figure.canvas.mpl_connect('pick_event', self.on_pick)

    def on_enter_figure(self, event):
        """Handle the mouse entering the figure (currently a no-op)."""
        pass

    def on_leave_figure(self, event):
        """Handle the mouse leaving the figure (currently a no-op)."""
        pass

    def on_enter_axes(self, event):
        """Handle the mouse entering the axes (currently a no-op)."""
        pass

    def on_leave_axes(self, event):
        """Handle the mouse leaving the axes (currently a no-op)."""
        pass

    def on_pick(self, event):
        """Print the label and the coordinates of the picked data points.

        Works for line artists (``get_data``) as well as for collections such
        as scatter plots (``get_offsets``).

        Args:
            event: matplotlib pick event; ``event.artist`` is the picked
                artist and ``event.ind`` the indices of the picked points.
        """
        artist = event.artist
        # Artists expose their label through get_label(); there is no
        # public ``label`` attribute.
        print(artist.get_label())
        ind = event.ind
        if hasattr(artist, "get_data"):
            xdata, ydata = artist.get_data()
            # get_data may hand back the original (list) input, convert
            # before fancy-indexing.
            xdata = np.asarray(xdata)
            ydata = np.asarray(ydata)
            print('on pick line:', np.array([xdata[ind], ydata[ind]]).T)
        elif hasattr(artist, "get_offsets"):
            print('on pick line:', np.asarray(artist.get_offsets())[ind])
class Plotter(MplCanvas):
    """Abstract base class for the visual display of data (plotting).

    Plotter extends MplCanvas. Subclasses that support panning must override
    the ``is_full_view`` and ``can_pan_horizontally`` properties; the base
    implementations raise ``NotImplementedError``.

    Args:
        file_handler: object used by subclasses to request axis values and
            dimension descriptors.
        item: descriptor of the entity to plot.
        data_view: data buffer/view of the entity to plot.
        parent: forwarded to ``MplCanvas``.
    """

    def __init__(self, file_handler, item, data_view, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = file_handler
        self._item = item
        self._dataview = data_view

    @property
    def is_full_view(self):
        """Whether the whole data range is displayed; must be overridden."""
        raise NotImplementedError("is_full_view is not implemented on the current plotter")

    @property
    def can_pan_horizontally(self):
        """Whether the view can be moved horizontally; must be overridden."""
        raise NotImplementedError("can_pan_horizontally is not implemented on the current plotter")
class EventPlotter(Plotter):
    """Plotter that displays 1-D event data as a row of scatter points.

    Args:
        file_handler: used to request axis values and dimension descriptors.
        item: descriptor of the data array; ``best_xdim``, ``block_id``,
            ``id`` and ``name`` are read.
        data_view: data view providing ``full_shape``, ``current_shape``,
            ``fully_loaded``, ``request_more()`` and ``_buffer``.
        xdim (int): dimension used as x-axis; ``-1`` selects ``item.best_xdim``.
        parent: forwarded to the base class.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dim_count = len(self._dataview.full_shape)
        self.xdim = self._item.best_xdim if xdim == -1 else xdim
        # Index range of the data available so far; needed by is_full_view.
        self._data_xmin = 0
        self._data_xmax = self._dataview.current_shape[self.xdim]
        self._abs_xmin = 0
        self._abs_xmax = self._dataview.full_shape[self.xdim]
        self._view_xmin = 0
        self._view_xmax = 0
        self._zoom_level = 0
        self._segment_length = 0
        self.sc = None

    @property
    def is_full_view(self):
        """True when the displayed index range equals the available data range."""
        return self._data_xmin == self._view_xmin and self._data_xmax == self._view_xmax

    @property
    def can_pan_horizontally(self):
        """True when the view can be moved to the left or to the right."""
        return self.can_pan_left or self.can_pan_right

    @property
    def can_pan_left(self):
        """True when the view starts after the absolute minimum index."""
        return self._view_xmin > self._abs_xmin

    @property
    def can_pan_right(self):
        """True when the view ends before the absolute maximum index."""
        return self._view_xmax < self._abs_xmax

    @property
    def horizontal_pan_position(self):
        """Right edge of the view as a fraction of the full extent (0.0 for empty data)."""
        if self._abs_xmax == 0:
            return 0.0
        return self._view_xmax/self._abs_xmax

    def horizontal_pan_to_position(self, new_position, zoomlevel):
        """Move the view so that its right edge is at ``new_position``.

        Args:
            new_position (float): relative position (0..1) of the right edge.
            zoomlevel (float): displayed fraction (0..1) of the full extent.
        """
        new_xmax = int(min(np.ceil(new_position * self._abs_xmax), self._abs_xmax))
        segment_length = zoomlevel * self._abs_xmax
        start = max(0, new_xmax - segment_length)
        # Load more data only while the requested right edge lies beyond what
        # is available. NOTE(review): assumes current_shape reports the
        # extent loaded so far -- confirm against DataView.
        while (not self._dataview.fully_loaded
               and new_xmax > self._dataview.current_shape[self.xdim]):
            self._dataview.request_more()
        self.plot(start, zoomlevel)

    def plot(self, start=0, zoomlevel=1.0):
        """Plot a segment of the data.

        Args:
            start (number): start index of the segment.
            zoomlevel (float): fraction of the full extent to show; values
                above 1 are clamped to 1.

        Returns:
            The matplotlib axes for 1-D data, otherwise ``None``.
        """
        zoomlevel = min(zoomlevel, 1.0)
        self._segment_length = zoomlevel * self._abs_xmax
        self._zoom_level = zoomlevel
        if self.dim_count == 1:
            return self.plot_1d(start)
        return None  # FIXME 2D events?

    def plot_1d(self, start=0, zoomlevel=1.0):
        """Draw the events and limit the x-axis to the requested segment.

        Args:
            start (number): start index of the segment; negative values are
                treated as 0.
            zoomlevel (float): unused; kept for interface compatibility.

        Returns:
            The matplotlib axes.
        """
        start = max(start, 0)
        end = min(start + self._segment_length, self._dataview.current_shape[self.xdim])
        segment = self._dataview._buffer[int(start):int(end)]
        x_values = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, len(segment), int(start))
        data = self._dataview._buffer[:]
        dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
        xlabel = create_label(dimensions[self.xdim])

        # All events are drawn at the same height.
        heights = np.ones(data.shape)
        if self.sc is None:
            self.sc = self.axis.scatter(data, heights)
            self.sc.set_pickradius(5)
        else:
            # scatter() returns a PathCollection; its points are updated via
            # set_offsets (it has no set_data method).
            self.sc.set_offsets(np.column_stack((data, heights)))
            self.figure.canvas.draw_idle()

        self._view_xmin = start
        self._view_xmax = end
        if x_values is not None and len(x_values) > 0:
            self.axis.set_xlim([x_values[0], x_values[-1]])
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        return self.axis
class CategoryPlotter(Plotter):
    """Plotter that displays 1-D or 2-D data as (grouped) bar plots.

    Args:
        file_handler: used to request the axis (category/series) labels.
        item: descriptor of the data array; ``best_xdim``, ``shape``,
            ``block_id`` and ``id`` are read.
        data_view: data view providing ``full_shape`` and ``_buffer``.
        parent: forwarded to the base class.
    """

    def __init__(self, file_handler, item, data_view, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dim_count = len(self._dataview.full_shape)
        self._xdim = self._item.best_xdim
        self.bars = []

    def _axis_labels(self, dim, prefix):
        """Return the labels of dimension ``dim`` or generated ``<prefix>-<i>`` names."""
        count = self._item.shape[dim]
        labels = self._file_handler.request_axis(self._item.block_id, self._item.id, dim, count)
        if labels is None or len(labels) == 0:
            labels = ["%s-%i" % (prefix, i) for i in range(count)]
        return labels

    def plot(self):
        """Plot the data.

        Returns:
            The matplotlib axes, or ``None`` for data that is neither 1-D nor 2-D.
        """
        dim_count = len(self._dataview.full_shape)
        if dim_count == 1:
            return self.plot_1d()
        if dim_count == 2:
            return self.plot_2d()
        return None

    def plot_1d(self):
        """Draw one bar per category and return the axes."""
        categories = self._axis_labels(self._item.best_xdim, "Cat")
        positions = range(1, len(categories) + 1)
        self.bars.append(self.axis.bar(positions, self._dataview._buffer, tick_label=categories))
        self.axis.set_ylabel(create_label(self._item))
        return self.axis

    def plot_2d(self):
        """Draw grouped bars (one group per category, one bar per series).

        Returns:
            The matplotlib axes.
        """
        xdim = self._item.best_xdim
        data = self._dataview._buffer
        if xdim == 1:
            # Bring the categories to the first axis so data[:, i] is series i.
            data = data.T
        categories = self._axis_labels(xdim, "Cat")
        series = self._axis_labels(1 - xdim, "Series")
        series_count = len(series)
        if series_count == 0 or len(categories) == 0:
            # Nothing to draw; also avoids a division by zero below.
            return self.axis

        bar_width = 1/series_count * 0.75
        group_positions = np.arange(len(categories))
        for i in range(series_count):
            x_values = group_positions + i * bar_width
            self.bars.append(self.axis.bar(x_values, data[:, i], width=bar_width, align="center")[0])
        # Bars are centred on their x position, so the centre of a group is
        # half-way between the first and the last bar centre.
        self.axis.set_xticks(group_positions + (series_count - 1) * bar_width/2)
        self.axis.set_xticklabels(categories)
        self.axis.legend(self.bars, series, loc=1)
        self.axis.set_ylabel(create_label(self._item))
        return self.axis
class ImagePlotter(Plotter):
    """Plotter that shows 2-D data, or 3-D data with up to 3 channels, via ``imshow``.

    Args:
        file_handler: used to request axis values and dimension descriptors.
        item: descriptor of the data array; ``name``, ``block_id`` and ``id``
            are read.
        data_view: data view providing ``full_shape``, ``fully_loaded``,
            ``request_more()`` and ``_buffer``.
        xdim (int): unused by this plotter.
        parent: forwarded to the base class.

    Raises:
        ValueError: if the data has more than 3 dimensions.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self._dim_count = len(self._dataview.full_shape)
        if self._dim_count > 3:
            details = (self._item.name, str(self._dataview.full_shape))
            raise ValueError("ImagePlotter cannot plot data with more than 3 dimensions! Shape of %s is %s " % details)
        self.image = None

    def plot(self):
        """Load all data, set the title and draw the image.

        Returns:
            The result of ``plot_2d``/``plot_3d`` for 2-D/3-D data,
            ``None`` for any other dimensionality.
        """
        view = self._dataview
        while not view.fully_loaded:
            view.request_more()
        self.axis.set_title(self._item.name)
        if self._dim_count == 2:
            return self.plot_2d()
        if self._dim_count == 3:
            return self.plot_3d()
        return None

    def plot_2d(self):
        """Draw the buffer content as an image and label both axes.

        Returns:
            The matplotlib axes.
        """
        pixels = self._dataview._buffer[:]
        if len(pixels.shape) == 3:
            pixels = pixels.astype("uint8")

        block_id = self._item.block_id
        array_id = self._item.id
        handler = self._file_handler
        x_values = handler.request_axis(block_id, array_id, 0, self._dataview.full_shape[0], 0)
        y_values = handler.request_axis(block_id, array_id, 1, self._dataview.full_shape[1], 0)
        dimensions = handler.request_dimensions(block_id, array_id)
        xlabel = create_label(dimensions[0])
        ylabel = create_label(dimensions[1])

        # NOTE(review): dimension 0 supplies the horizontal extent here while
        # imshow draws the first array axis vertically -- confirm the
        # intended orientation.
        extent = [x_values[0], x_values[-1], y_values[0], y_values[-1]]
        self.image = self.axis.imshow(pixels, extent=extent)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_3d(self):
        """Draw 3-D data as an image when the third dimension has at most 3 entries.

        Returns:
            The matplotlib axes, or ``None`` when the data has too many channels.
        """
        if self._dataview.full_shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels in the third dim")
            return None
        return self.plot_2d()
class LinePlotter(Plotter):
    """Plotter showing 1-D data as a single line or 2-D data as multiple lines.

    Only a segment of the data is displayed at a time; the segment is selected
    by a start index and a zoom level (fraction of the full extent).

    Args:
        file_handler: used to request axis values and dimension descriptors.
        item: descriptor of the data array; ``best_xdim``, ``block_id``,
            ``id`` and ``name`` are read.
        data_view: data view providing ``full_shape``, ``current_shape``,
            ``fully_loaded``, ``request_more()`` and ``_buffer``.
        xdim (int): dimension used as x-axis; ``-1`` selects ``item.best_xdim``.
        parent: forwarded to the base class.

    Raises:
        ValueError: if ``xdim`` is larger than 2.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
        self.lines = []
        self.dim_count = len(self._dataview.full_shape)
        if xdim == -1:
            self.xdim = item.best_xdim
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim
        # All extents below are indices along the x dimension.
        self._data_xmin = 0
        self._data_xmax = self._dataview.current_shape[self.xdim]
        self._abs_xmin = 0
        self._abs_xmax = self._dataview.full_shape[self.xdim]
        self._view_xmin = 0
        self._view_xmax = 0
        self._zoom_level = 0
        self._segment_length = 0

    @property
    def is_full_view(self):
        """True when the displayed index range equals the available data range."""
        return self._data_xmin == self._view_xmin and self._data_xmax == self._view_xmax

    @property
    def can_pan_horizontally(self):
        """True when the view can be moved to the left or to the right."""
        return self.can_pan_left or self.can_pan_right

    @property
    def can_pan_left(self):
        """True when the view starts after the absolute minimum index."""
        return self._view_xmin > self._abs_xmin

    @property
    def can_pan_right(self):
        """True when the view ends before the absolute maximum index."""
        return self._view_xmax < self._abs_xmax

    @property
    def horizontal_pan_position(self):
        """Right edge of the view as a fraction of the full extent (0.0 for empty data)."""
        if self._abs_xmax == 0:
            return 0.0
        return self._view_xmax/self._abs_xmax

    def horizontal_pan_to_position(self, new_position, zoomlevel):
        """Move the view so that its right edge is at ``new_position``.

        Args:
            new_position (float): relative position (0..1) of the right edge.
            zoomlevel (float): displayed fraction (0..1) of the full extent.
        """
        new_xmax = int(min(np.ceil(new_position * self._abs_xmax), self._abs_xmax))
        segment_length = zoomlevel * self._abs_xmax
        start = max(0, new_xmax - segment_length)
        # Load more data only while the requested right edge lies beyond what
        # is available. NOTE(review): assumes current_shape reports the
        # extent loaded so far -- confirm against DataView.
        while (not self._dataview.fully_loaded
               and new_xmax > self._dataview.current_shape[self.xdim]):
            self._dataview.request_more()
        self.plot(start, zoomlevel)

    def on_zoom_in(self, new_position):
        """Debug hook: print the zoom-in request."""
        print("plotter ZOOM In!", new_position)

    def on_zoom_out(self, new_position):
        """Debug hook: print the zoom-out request."""
        print("plotter ZOOM out!", new_position)

    def plot(self, start=0, zoomlevel=1.0):
        """Plot a segment of the data; data with more than 2 dimensions is ignored.

        Args:
            start (number): start index of the segment.
            zoomlevel (float): fraction of the full extent to show; values
                above 1 are clamped to 1.
        """
        zoomlevel = min(zoomlevel, 1.0)
        self._segment_length = zoomlevel * self._abs_xmax
        self._zoom_level = zoomlevel
        if self.dim_count > 2:
            return
        if self.dim_count == 1:
            self.plot_array_1d(start)
        else:
            self.plot_array_2d(start)

    def _update_abs_extremes(self, display_x_min, display_xmax):
        """Widen the known data index range to include the given range."""
        if self._data_xmin is None or display_x_min < self._data_xmin:
            self._data_xmin = display_x_min
        if self._data_xmax is None or display_xmax > self._data_xmax:
            self._data_xmax = display_xmax

    def _update_current_view(self, current_xmin, current_xmax):
        """Remember the index range that is currently displayed."""
        self._view_xmin = current_xmin
        self._view_xmax = current_xmax

    def __draw_1d(self, start, end):
        """Draw the 1-D data from start to end index.

        Args:
            start (number): start index in the data
            end (number): end index in the data
        """
        start = max(start, 0)
        end = min(end, self._dataview.current_shape[self.xdim])
        y_values = self._dataview._buffer[int(start):int(end)]
        if len(y_values) == 0:
            # Nothing to show; min/max of an empty segment would raise.
            return
        x_values = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, len(y_values), int(start))
        self._update_abs_extremes(start, end)
        self._update_current_view(start, end)

        if not self.lines:
            line, = self.axis.plot(x_values, y_values, label=self._item.name)
            line.set_pickradius(5)
            self.lines.append(line)
        else:
            self.lines[-1].set_data(x_values[:len(y_values)], y_values)
        self.axis.set_ylim([np.min(y_values), np.max(y_values)])
        self.axis.set_xlim([x_values[0], x_values[-1]])
        self.figure.canvas.draw_idle()

    def __draw_2d(self, start, end):
        """Draw one line per entry of the non-x dimension for the given index range.

        Args:
            start (number): start index along the x dimension
            end (number): end index along the x dimension
        """
        start = int(max(start, 0))
        end = int(min(end, self._dataview.current_shape[self.xdim]))
        if end <= start:
            return
        # Use the same integer bounds for the axis request and the data
        # slices so x and y always have matching lengths.
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, end - start, start)
        line_count = self._dataview.current_shape[1 - self.xdim]
        line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1-self.xdim, line_count, 0)
        # Track extents as indices, consistent with the 1-D case.
        self._update_abs_extremes(start, end)
        self._update_current_view(start, end)

        for i, line_label in enumerate(line_labels):
            if self.xdim == 0:
                y = self._dataview._buffer[start:end, i]
            else:
                y = self._dataview._buffer[i, start:end]
            if len(self.lines) <= i:
                line, = self.axis.plot(x, y, label=line_label)
                line.set_pickradius(5)
                self.lines.append(line)
            else:
                self.lines[i].set_data(x, y)

        # Adapt the y-range to the new segment; x-limits are set explicitly.
        self.axis.relim()
        self.axis.autoscale_view(scalex=False, scaley=True)
        if len(x) > 0:
            self.axis.set_xlim([x[0], x[-1]])
        self.figure.canvas.draw_idle()

    def plot_array_1d(self, start=0):
        """Draw a 1-D segment beginning at ``start``, label the axes and emit ``view_changed``."""
        self.__draw_1d(start, start + self._segment_length)
        self.axis.set_xlabel(create_label(self.dimensions[self.xdim]))
        self.axis.set_ylabel(create_label(self._item))
        self.view_changed.emit()

    def plot_array_2d(self, start=0):
        """Draw a 2-D segment beginning at ``start``, label the axes, add the legend and emit ``view_changed``."""
        self.__draw_2d(start, start + self._segment_length)
        self.axis.set_xlabel(create_label(self.dimensions[self.xdim]))
        self.axis.set_ylabel(create_label(self._item))
        self.axis.legend(loc=1)
        self.view_changed.emit()
class PlotContainer(QWidget):
    """Widget with a vertical layout that hosts at most one plotter widget.

    Args:
        parent: optional parent widget.
    """

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self.setLayout(QVBoxLayout())
        self.plotter = None

    def set_plotter(self, plotter):
        """Replace the hosted plotter.

        Args:
            plotter: the new plotter widget, or ``None`` to only remove the
                current one.
        """
        previous = self.plotter
        if previous is not None and previous is not plotter:
            self.layout().removeWidget(previous)
            # removeWidget() only takes the widget out of the layout; detach
            # it from this container so it is no longer shown.
            previous.setParent(None)
        if plotter is not None and plotter is not previous:
            self.layout().addWidget(plotter)
        self.plotter = plotter
class PlotScreen(QWidget):
    """Screen that shows a plot of the selected entity with pan and zoom controls.

    Layout from top to bottom: plot container, pan/zoom sliders, close button.
    ``close_signal`` is emitted when the close button is clicked.

    Args:
        parent: optional parent widget.
    """
    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = FileHandler()
        self._data_view = None
        # True while a slider is moved by code rather than by the user.
        self._software_slide = False
        self.plotter = None

        self.setLayout(QVBoxLayout())
        self._container = PlotContainer(self)
        self.layout().addWidget(self._container)
        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)
        self._create_plot_controls()
        self.layout().addWidget(close_btn)

    def _create_plot_controls(self):
        """Create the pan and zoom sliders and add them to the layout."""
        plot_controls = QGroupBox()
        plot_controls.setFlat(True)
        plot_controls.setLayout(QHBoxLayout())
        plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        self._zoom_slider = QSlider(Qt.Horizontal)
        self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self._zoom_slider.setFixedWidth(120)
        self._zoom_slider.setFixedHeight(20)
        self._zoom_slider.setTickPosition(QSlider.TicksBelow)
        # The range has to be set before the position, otherwise the
        # position is clamped to the default range of the slider.
        self._zoom_slider.setMinimum(1)
        self._zoom_slider.setMaximum(1000)
        self._zoom_slider.setSliderPosition(500)
        self._zoom_slider.setTickInterval(250)
        self._zoom_slider.valueChanged.connect(self.on_zoom)

        self._pan_slider = QSlider(Qt.Horizontal)
        self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._pan_slider.setFixedHeight(20)
        self._pan_slider.setTickPosition(QSlider.TicksBelow)
        self._pan_slider.setMinimum(1)
        self._pan_slider.setMaximum(1000)
        self._pan_slider.setSliderPosition(1)
        self._pan_slider.setTickInterval(250)
        self._pan_slider.valueChanged.connect(self.on_pan)

        pl = QLabel("horiz. pos.:")
        pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        pl.setMaximumWidth(150)
        zl = QLabel("zoom:")
        zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        zl.setMaximumWidth(75)

        plot_controls.layout().addWidget(pl)
        plot_controls.layout().addWidget(self._pan_slider)
        plot_controls.layout().addWidget(zl)
        plot_controls.layout().addWidget(QLabel("+"))
        plot_controls.layout().addWidget(self._zoom_slider)
        plot_controls.layout().addWidget(QLabel("-"))
        self.layout().addWidget(plot_controls)

    @staticmethod
    def _initial_zoom_position(sample_count, target_samples=50000):
        """Return the zoom slider position (1..1000) that shows about ``target_samples`` samples."""
        if sample_count <= 0:
            return 1000
        position = int(round(1000 * target_samples / sample_count))
        return max(1, min(1000, position))

    def _pan_plotter(self, position, zoomlevel):
        """Forward a pan/zoom request to the plotter if it supports panning."""
        pan = getattr(self.plotter, "horizontal_pan_to_position", None)
        if pan is not None:
            pan(position, zoomlevel)

    def on_close(self):
        """Emit ``close_signal``."""
        self.close_signal.emit()

    def on_zoom(self, new_position):
        """React to a change of the zoom slider (ignored for programmatic changes)."""
        if self.plotter is None:
            return
        if self._software_slide:
            self._software_slide = False
            return
        self._pan_plotter(self._pan_slider.sliderPosition()/1000, self._zoom_slider.sliderPosition()/1000)

    def on_pan(self, new_position):
        """React to a change of the pan slider (ignored for programmatic changes)."""
        if self._software_slide:
            self._software_slide = False
            return
        # Sliders are usable before anything has been plotted.
        if self.plotter is None:
            return
        self._pan_plotter(new_position/1000., self._zoom_slider.sliderPosition()/1000.)

    def on_view_changed(self):
        """Enable the pan slider only when the plotter can be panned."""
        self._pan_slider.setEnabled(self.plotter.can_pan_horizontally)

    def _show_line_plotter(self, item):
        """Create a LinePlotter for ``item`` and plot the initial segment."""
        self._zoom_slider.setEnabled(True)
        self._pan_slider.setEnabled(True)
        zoom_slider_position = self._initial_zoom_position(self._data_view.full_shape[item.best_xdim])
        self._software_slide = True
        try:
            # valueChanged is delivered synchronously, so the flag is only
            # needed for the duration of this call.
            self._zoom_slider.setSliderPosition(zoom_slider_position)
        finally:
            self._software_slide = False
        self.plotter = LinePlotter(self._file_handler, item, self._data_view)
        self.plotter.view_changed.connect(self.on_view_changed)
        self._container.set_plotter(self.plotter)
        self.plotter.plot(zoomlevel=zoom_slider_position/1000.)

    def _show_static_plotter(self, plotter_class, item):
        """Create a plotter without pan/zoom support for ``item`` and plot it."""
        self._zoom_slider.setEnabled(False)
        self._pan_slider.setEnabled(False)
        try:
            plotter = plotter_class(self._file_handler, item, self._data_view)
        except ValueError as e:
            # e.g. ImagePlotter rejects data with more than 3 dimensions
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return
        self.plotter = plotter
        self._container.set_plotter(self.plotter)
        self.plotter.plot()

    def plot_data_array(self, item):
        """Plot a data array with the plotter suggested by ``item.suggested_plotter``.

        Errors raised while creating the data view are reported via
        ``communicator.plot_error``.
        """
        try:
            self._data_view = DataView(item, self._file_handler)
        except ValueError as e:
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return
        if self._data_view is None:
            return

        suggested = item.suggested_plotter
        if suggested == PlotterTypes.LinePlotter:
            self._show_line_plotter(item)
        elif suggested == PlotterTypes.ImagePlotter:
            self._show_static_plotter(ImagePlotter, item)
        elif suggested == PlotterTypes.EventPlotter:
            self._show_static_plotter(EventPlotter, item)
        elif suggested == PlotterTypes.CategoryPlotter:
            self._show_static_plotter(CategoryPlotter, item)
        else:
            # Do not keep a reference to a plotter that is no longer shown.
            self.plotter = None
            self._container.set_plotter(None)

    def plot_tag(self, item):
        """Plot a tag or multi-tag (not implemented yet)."""
        pass

    def plot(self, item):
        """Plot ``item`` according to its entity type.

        Data arrays and (multi-)tags are dispatched to the respective
        methods; for all other types an error is reported via
        ``communicator.plot_error``.
        """
        if item.entity_type == NodeType.DataArray:
            self.plot_data_array(item)
        elif item.entity_type == NodeType.MultiTag or item.entity_type == NodeType.Tag:
            self.plot_tag(item)
        else:
            communicator.plot_error.emit("error in plotscreen.plot cannot plot entity of type%s" % item.entity_type)
            return