nixview-python/nixview/ui/plotscreen.py

597 lines
22 KiB
Python

from nixview.util.enums import PlotterTypes
from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
from PyQt5.QtCore import pyqtSignal, Qt
import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import numpy as np
try:
import matplotlib.pyplot as plt
except ImportError as e:
print("cannot import matplotlib, headless mode?", e)
from matplotlib.widgets import Slider
from IPython import embed
from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView
from nixview.communicator import communicator
def create_label(item):
    """Build a human-readable axis label for *item*.

    The text comes from ``item.label`` when that attribute exists, is not
    None and is non-empty; otherwise from ``item.name`` if present. If
    ``item.unit`` exists and is not None it is appended as `` [unit]``.

    Args:
        item: any object that may expose ``label``, ``name`` and ``unit``.

    Returns:
        str: the assembled label (possibly empty).
    """
    text = ""
    if hasattr(item, "label") and item.label is not None:
        text += item.label
    if not text and hasattr(item, "name"):
        text += item.name
    unit = getattr(item, "unit", None)
    if unit is not None:
        text += " [%s]" % unit
    return text
class MplCanvas(FigureCanvas):
    """Qt widget hosting a matplotlib figure with a single axes.

    MplCanvas extends the ``FigureCanvasQTAgg`` backend canvas. It creates a
    figure with one subplot, exposed as ``self.axis``, and connects the
    figure's enter/leave/pick events to the handler methods below.

    Args:
        parent: accepted for symmetry with Qt widget constructors; it is not
            used here (the widget gets reparented when added to a layout).
        width (float): figure width in inches.
        height (float): figure height in inches.
        dpi (int): figure resolution in dots per inch.
    """

    # Emitted (e.g. by LinePlotter) after the displayed data range changed.
    view_changed = pyqtSignal()

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axis = figure.add_subplot(111)
        super(MplCanvas, self).__init__(figure)
        self._figure = figure
        self._figure.canvas.mpl_connect('figure_enter_event', self.on_enter_figure)
        self._figure.canvas.mpl_connect('figure_leave_event', self.on_leave_figure)
        self._figure.canvas.mpl_connect('axes_enter_event', self.on_enter_axes)
        self._figure.canvas.mpl_connect('axes_leave_event', self.on_leave_axes)
        self._figure.canvas.mpl_connect('pick_event', self.on_pick)

    def on_enter_figure(self, event):
        """Hook for the mouse entering the figure; currently a no-op."""

    def on_leave_figure(self, event):
        """Hook for the mouse leaving the figure; currently a no-op."""

    def on_enter_axes(self, event):
        """Hook for the mouse entering the axes; currently a no-op."""

    def on_leave_axes(self, event):
        """Hook for the mouse leaving the axes; currently a no-op."""

    def on_pick(self, event):
        """Print the label and the picked points of the picked artist.

        Lines (``Line2D``) expose their points through ``get_data``; scatter
        collections only through ``get_offsets``, so both are handled.
        """
        artist = event.artist
        # Artists have no public ``label`` attribute; get_label() is the accessor.
        print(artist.get_label())
        indices = event.ind
        if hasattr(artist, "get_data"):
            xdata, ydata = artist.get_data()
            points = np.array([np.asarray(xdata)[indices], np.asarray(ydata)[indices]]).T
        else:
            points = np.asarray(artist.get_offsets())[indices]
        print('on pick line:', points)
class Plotter(MplCanvas):
    """Abstract base class for the visual display of data (plotting).

    Plotter extends MplCanvas and stores the handles that every concrete
    plotter uses. Subclasses are expected to implement ``is_full_view`` and
    ``can_pan_horizontally``.

    Args:
        file_handler: handler the subclasses query for axis values/dimensions.
        item: the data item to plot.
        data_view: DataView buffering the item's data.
        parent: optional parent widget, forwarded to MplCanvas.
    """

    def __init__(self, file_handler, item, data_view, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = file_handler
        self._item = item
        self._dataview = data_view

    @property
    def is_full_view(self):
        """Whether the whole data range is displayed; must be overridden."""
        raise NotImplementedError("is_full_view is not implemented on the current plotter")

    @property
    def can_pan_horizontally(self):
        """Whether the view can be moved left or right; must be overridden."""
        raise NotImplementedError("can_pan_horizontally is not implemented on the current plotter")
class EventPlotter(Plotter):
    """Plot 1-D event data as scatter markers on a single horizontal row.

    All buffered events are drawn at y == 1; the x limits are set from the
    axis values of the current segment.

    Args:
        file_handler: handler queried for axis values and dimensions.
        item: the data item to plot.
        data_view: DataView buffering the item's data.
        xdim (int): dimension used for the x axis; -1 selects ``item.best_xdim``.
        parent: optional parent widget.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dim_count = len(self._dataview.full_shape)
        self.xdim = self._item.best_xdim if xdim == -1 else xdim
        # Index range covered by the data drawn so far (read by is_full_view).
        self._data_xmin = 0
        self._data_xmax = self._dataview.current_shape[self.xdim]
        # Index range of the complete data along xdim.
        self._abs_xmin = 0
        self._abs_xmax = self._dataview.full_shape[self.xdim]
        # Index range currently displayed.
        self._view_xmin = 0
        self._view_xmax = 0
        self._zoom_level = 0
        self._segment_length = 0
        self.sc = None

    @property
    def is_full_view(self):
        """True when the displayed index range equals the drawn data range."""
        return self._data_xmin == self._view_xmin and self._data_xmax == self._view_xmax

    @property
    def can_pan_horizontally(self):
        """True if the view can move left or right."""
        return self.can_pan_left or self.can_pan_right

    @property
    def can_pan_left(self):
        """True if the view does not start at the first index."""
        return self._view_xmin > self._abs_xmin

    @property
    def can_pan_right(self):
        """True if the view ends before the last index."""
        return self._view_xmax < self._abs_xmax

    @property
    def horizontal_pan_position(self):
        """Right edge of the view as a fraction of the full range."""
        return self._view_xmax / self._abs_xmax

    def horizontal_pan_to_position(self, new_position, zoomlevel):
        """Show the segment whose right edge is at *new_position*.

        Args:
            new_position (float): right edge as a fraction of the full range.
            zoomlevel (float): segment length as a fraction of the full range.
        """
        new_xmax = int(min(np.ceil(new_position * self._abs_xmax), self._abs_xmax))
        segment_length = zoomlevel * self._abs_xmax
        start = max(0, new_xmax - segment_length)
        # Request data until everything up to the new right edge is loaded.
        while not self._dataview.fully_loaded and new_xmax > self._dataview.current_shape[self.xdim]:
            self._dataview.request_more()
        self.plot(start, zoomlevel)

    def plot(self, start=0, zoomlevel=1.0):
        """Plot a segment starting at index *start*.

        Args:
            start: first index of the segment.
            zoomlevel (float): segment length as a fraction of the full range;
                values above 1 are clipped to 1.

        Returns:
            The axes for 1-D data, None otherwise.
        """
        if zoomlevel > 1:
            zoomlevel = 1.0
        self._segment_length = zoomlevel * self._abs_xmax
        self._zoom_level = zoomlevel
        if self.dim_count == 1:
            return self.plot_1d(start)
        return None  # FIXME 2D events?

    def plot_1d(self, start=0, zoomlevel=1.0):
        """Draw all events and zoom the x axis to the current segment.

        Args:
            start: first index of the segment (clipped to >= 0).
            zoomlevel: unused; the segment length set by ``plot`` applies.

        Returns:
            The matplotlib axes drawn into.
        """
        start = max(int(start), 0)
        end = int(min(start + self._segment_length, self._dataview.current_shape[self.xdim]))
        end = max(end, start)
        segment = self._dataview._buffer[start:end]
        x_values = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, len(segment), start)
        events = self._dataview._buffer[:]
        dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
        xlabel = create_label(dimensions[self.xdim])
        markers_y = np.ones(events.shape)
        if self.sc is None:
            self.sc = self.axis.scatter(events, markers_y, label=self._item.name)
            self.sc.set_pickradius(5)
        else:
            # Scatter collections have no set_data(); their positions are offsets.
            self.sc.set_offsets(np.column_stack([events, markers_y]))
        self._view_xmin, self._view_xmax = start, end
        self._data_xmin = min(self._data_xmin, start)
        self._data_xmax = max(self._data_xmax, end)
        if len(x_values) > 0:
            self.axis.set_xlim([x_values[0], x_values[-1]])
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.figure.canvas.draw_idle()
        return self.axis
class CategoryPlotter(Plotter):
    """Bar-chart plotter for categorical 1-D or 2-D data.

    1-D data gives one bar per category. 2-D data gives groups of bars: one
    group per category along ``item.best_xdim``, one bar (series) per entry
    of the other dimension.
    """

    def __init__(self, file_handler, item, data_view, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dim_count = len(self._dataview.full_shape)
        self._xdim = self._item.best_xdim
        self.bars = []

    def plot(self):
        """Load all data and draw it; returns the axes, or None if unsupported."""
        # bar() needs one value per category label, so buffer everything first.
        while not self._dataview.fully_loaded:
            self._dataview.request_more()
        if self.dim_count == 1:
            return self.plot_1d()
        if self.dim_count == 2:
            return self.plot_2d()
        return None

    def _category_labels(self):
        """Return tick labels for the category dimension, with a generated fallback."""
        xdim = self._item.best_xdim
        count = self._item.shape[xdim]
        labels = self._file_handler.request_axis(self._item.block_id, self._item.id, xdim, count)
        if labels is None or len(labels) == 0:
            labels = ["Cat-%i" % i for i in range(count)]
        return labels

    def plot_1d(self):
        """Draw one bar per category; returns the axes."""
        categories = self._category_labels()
        positions = range(1, len(categories) + 1)
        self.bars.append(self.axis.bar(positions, self._dataview._buffer, tick_label=categories))
        self.axis.set_ylabel(create_label(self._item))
        return self.axis

    def plot_2d(self):
        """Draw grouped bars, one group per category and one bar per series."""
        data = self._dataview._buffer
        xdim = self._item.best_xdim
        if xdim == 1:
            data = data.T  # make rows the categories and columns the series
        categories = self._category_labels()
        series_dim = 1 - xdim
        series_count = self._item.shape[series_dim]
        series = self._file_handler.request_axis(self._item.block_id, self._item.id, series_dim, series_count)
        if series is None or len(series) == 0:
            series = ["Series-%i" % i for i in range(series_count)]
        bar_width = 1 / len(series) * 0.75
        group_starts = np.arange(len(categories))
        handles = []
        for i in range(len(series)):
            x_values = group_starts + i * bar_width
            handles.append(self.axis.bar(x_values, data[:, i], width=bar_width, align="center")[0])
        self.bars.extend(handles)
        # Bars are centred at group_start + i * bar_width for i in 0..n-1, so
        # the middle of each group is (n - 1) * bar_width / 2 past its start.
        self.axis.set_xticks(group_starts + (len(series) - 1) * bar_width / 2)
        self.axis.set_xticklabels(categories)
        self.axis.legend(handles, series, loc=1)
        self.axis.set_ylabel(create_label(self._item))
        return self.axis
class ImagePlotter(Plotter):
    """Show 2-D data, or 3-D data with at most 3 channels, as an image.

    Args:
        file_handler: handler queried for axis values and dimensions.
        item: the data item to plot.
        data_view: DataView buffering the item's data.
        xdim (int): accepted for signature compatibility; not used.
        parent: optional parent widget.

    Raises:
        ValueError: if the data has more than 3 dimensions.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self._dim_count = len(self._dataview.full_shape)
        if self._dim_count > 3:
            raise ValueError("ImagePlotter cannot plot data with more than 3 dimensions! Shape of %s is %s " %(self._item.name, str(self._dataview.full_shape)))
        self.image = None

    def plot(self):
        """Load all data and draw it; returns the axes, or None if unsupported."""
        while not self._dataview.fully_loaded:
            self._dataview.request_more()
        self.axis.set_title(self._item.name)
        if self._dim_count == 2:
            return self.plot_2d()
        if self._dim_count == 3:
            return self.plot_3d()
        return None

    def plot_2d(self):
        """Draw the buffered data with imshow and label both axes."""
        data = self._dataview._buffer[:]
        if len(data.shape) == 3:
            data = data.astype("uint8")
        block_id, item_id = self._item.block_id, self._item.id
        full_shape = self._dataview.full_shape
        # imshow draws array rows (dimension 0) along the vertical axis and
        # columns (dimension 1) along the horizontal axis, so dimension 1
        # belongs on x and dimension 0 on y.
        row_values = self._file_handler.request_axis(block_id, item_id, 0, full_shape[0], 0)
        col_values = self._file_handler.request_axis(block_id, item_id, 1, full_shape[1], 0)
        dimensions = self._file_handler.request_dimensions(block_id, item_id)
        # With origin="upper" row 0 is at the top edge, so the extent
        # (left, right, bottom, top) is (first col, last col, last row, first row).
        self.image = self.axis.imshow(data, origin="upper",
                                      extent=[col_values[0], col_values[-1], row_values[-1], row_values[0]])
        self.axis.set_xlabel(create_label(dimensions[1]))
        self.axis.set_ylabel(create_label(dimensions[0]))
        return self.axis

    def plot_3d(self):
        """Draw 3-D data as a multi-channel image (at most 3 channels)."""
        if self._dataview.full_shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()
class LinePlotter(Plotter):
    """Line plot of 1-D data (a single line) or 2-D data (one line per series).

    Only a segment of the x dimension is drawn at a time. Its length is a
    fraction (``zoomlevel``) of the full x extent, and
    ``horizontal_pan_to_position`` moves it.

    Args:
        file_handler: handler queried for axis values and dimensions.
        item: the data item to plot.
        data_view: DataView buffering the item's data.
        xdim (int): dimension plotted on the x axis; -1 uses ``item.best_xdim``.
        parent: optional parent widget.

    Raises:
        ValueError: if ``xdim`` is larger than 2.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)
        self.dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
        self.lines = []
        self.dim_count = len(self._dataview.full_shape)
        if xdim == -1:
            self.xdim = item.best_xdim
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim
        # Index range of the data drawn so far.
        self._data_xmin = 0
        self._data_xmax = self._dataview.current_shape[self.xdim]
        # Index range of the complete data along xdim.
        self._abs_xmin = 0
        self._abs_xmax = self._dataview.full_shape[self.xdim]
        # Index range currently displayed.
        self._view_xmin = 0
        self._view_xmax = 0
        self._zoom_level = 0
        self._segment_length = 0

    @property
    def is_full_view(self):
        """True when the displayed index range equals the drawn data range."""
        return self._data_xmin == self._view_xmin and self._data_xmax == self._view_xmax

    @property
    def can_pan_horizontally(self):
        """True if the view can move left or right."""
        return self.can_pan_left or self.can_pan_right

    @property
    def can_pan_left(self):
        """True if the view does not start at the first index."""
        return self._view_xmin > self._abs_xmin

    @property
    def can_pan_right(self):
        """True if the view ends before the last index."""
        return self._view_xmax < self._abs_xmax

    @property
    def horizontal_pan_position(self):
        """Right edge of the view as a fraction of the full range."""
        return self._view_xmax / self._abs_xmax

    def horizontal_pan_to_position(self, new_position, zoomlevel):
        """Show the segment whose right edge is at *new_position*.

        Args:
            new_position (float): right edge as a fraction of the full range.
            zoomlevel (float): segment length as a fraction of the full range.
        """
        new_xmax = int(min(np.ceil(new_position * self._abs_xmax), self._abs_xmax))
        segment_length = zoomlevel * self._abs_xmax
        start = max(0, new_xmax - segment_length)
        # Request data until everything up to the new right edge is loaded.
        while not self._dataview.fully_loaded and new_xmax > self._dataview.current_shape[self.xdim]:
            self._dataview.request_more()
        self.plot(start, zoomlevel)

    def on_zoom_in(self, new_position):
        print("plotter ZOOM In!", new_position)

    def on_zoom_out(self, new_position):
        print("plotter ZOOM out!", new_position)

    def plot(self, start=0, zoomlevel=1.0):
        """Plot a segment starting at index *start*.

        Args:
            start: first index of the segment.
            zoomlevel (float): segment length as a fraction of the full range;
                values above 1 are clipped to 1.
        """
        if zoomlevel > 1:
            zoomlevel = 1.0
        self._segment_length = zoomlevel * self._abs_xmax
        self._zoom_level = zoomlevel
        if self.dim_count > 2:
            return
        if self.dim_count == 1:
            self.plot_array_1d(start)
        else:
            self.plot_array_2d(start)

    def _update_abs_extremes(self, display_x_min, display_xmax):
        """Widen the drawn-data index range to include [display_x_min, display_xmax]."""
        if self._data_xmin is None or display_x_min < self._data_xmin:
            self._data_xmin = display_x_min
        if self._data_xmax is None or display_xmax > self._data_xmax:
            self._data_xmax = display_xmax

    def _update_current_view(self, current_xmin, current_xmax):
        """Record the currently displayed index range."""
        self._view_xmax = current_xmax
        self._view_xmin = current_xmin

    def _clamp_range(self, start, end):
        """Convert start/end to ints clipped to the currently loaded data."""
        start = max(int(start), 0)
        end = min(int(end), self._dataview.current_shape[self.xdim])
        return start, max(end, start)

    def __draw_1d(self, start, end):
        """Draw the data from start to end index.

        Args:
            start (int): start index in the data
            end (int): end index in the data
        """
        start, end = self._clamp_range(start, end)
        y_values = self._dataview._buffer[start:end]
        x_values = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, len(y_values), start)
        self._update_abs_extremes(start, end)
        self._update_current_view(start, end)
        if len(y_values) == 0:
            return  # nothing to show; np.min/x_values[0] would fail
        if len(self.lines) == 0:
            line, = self.axis.plot(x_values, y_values, label=self._item.name)
            line.set_pickradius(5)
            self.lines.append(line)
        else:
            self.lines[-1].set_data(x_values[:len(y_values)], y_values)
        self.axis.set_ylim([np.min(y_values), np.max(y_values)])
        self.axis.set_xlim([x_values[0], x_values[-1]])
        self.figure.canvas.draw_idle()

    def __draw_2d(self, start, end):
        """Draw one line per entry of the non-x dimension from start to end index."""
        start, end = self._clamp_range(start, end)
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, end - start, start)
        line_count = self._dataview.current_shape[1 - self.xdim]
        line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1 - self.xdim, line_count, 0)
        # Track index ranges (not axis values), as in the 1-D case, so that
        # is_full_view and the pan properties compare like with like.
        self._update_abs_extremes(start, end)
        self._update_current_view(start, end)
        for i, label in enumerate(line_labels):
            if self.xdim == 0:
                y = self._dataview._buffer[start:end, i]
            else:
                y = self._dataview._buffer[i, start:end]
            if len(self.lines) <= i:
                line, = self.axis.plot(x, y, label=label)
                line.set_pickradius(5)
                self.lines.append(line)
            else:
                self.lines[i].set_data(x, y)
        if len(x) > 0:
            self.axis.set_xlim([x[0], x[-1]])
        # set_data() does not update the data limits; rescale y to the new segment.
        self.axis.relim()
        self.axis.autoscale_view(scalex=False)
        self.figure.canvas.draw_idle()

    def plot_array_1d(self, start=0):
        """Draw a 1-D segment, label the axes and emit view_changed."""
        self.__draw_1d(start, start + self._segment_length)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.view_changed.emit()

    def plot_array_2d(self, start=0):
        """Draw a 2-D segment, label the axes, add a legend and emit view_changed."""
        self.__draw_2d(start, start + self._segment_length)
        xlabel = create_label(self.dimensions[self.xdim])
        ylabel = create_label(self._item)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.axis.legend(loc=1)
        self.view_changed.emit()
class PlotContainer(QWidget):
    """Widget that shows at most one plotter in a vertical layout."""

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self.setLayout(QVBoxLayout())
        self.plotter = None

    def set_plotter(self, plotter):
        """Replace the displayed plotter with *plotter*.

        Passing None just clears the container.
        """
        if self.plotter is not None and self.plotter is not plotter:
            self.layout().removeWidget(self.plotter)
            # removeWidget() only detaches the widget from the layout; it would
            # otherwise stay a visible child. Unparenting hides it and hands
            # ownership back to Python.
            self.plotter.setParent(None)
        if plotter is not None:
            self.layout().addWidget(plotter)
        self.plotter = plotter
class PlotScreen(QWidget):
    """Screen showing one plot plus pan/zoom sliders and a close button.

    Emits ``close_signal`` when the close button is clicked.
    """
    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = FileHandler()
        self.setLayout(QVBoxLayout())
        self._container = PlotContainer(self)
        self.layout().addWidget(self._container)
        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)
        self._create_plot_controls()
        self.layout().addWidget(close_btn)
        self._data_view = None
        # True while a slider is moved programmatically, so the resulting
        # valueChanged callback is ignored once.
        self._software_slide = False
        self.plotter = None

    def _create_plot_controls(self):
        """Build the pan and zoom sliders and add them to this screen's layout."""
        plot_controls = QGroupBox()
        plot_controls.setFlat(True)
        plot_controls.setLayout(QHBoxLayout())
        plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._zoom_slider = QSlider(Qt.Horizontal)
        self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self._zoom_slider.setFixedWidth(120)
        self._zoom_slider.setFixedHeight(20)
        self._zoom_slider.setTickPosition(QSlider.TicksBelow)
        # Set the range before the position: QSlider clamps the position to
        # the current range (default 0..99), so 500 would otherwise become 99.
        self._zoom_slider.setMinimum(1)
        self._zoom_slider.setMaximum(1000)
        self._zoom_slider.setSliderPosition(500)
        self._zoom_slider.setTickInterval(250)
        self._zoom_slider.valueChanged.connect(self.on_zoom)
        self._pan_slider = QSlider(Qt.Horizontal)
        self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._pan_slider.setFixedHeight(20)
        self._pan_slider.setTickPosition(QSlider.TicksBelow)
        self._pan_slider.setMinimum(1)
        self._pan_slider.setMaximum(1000)
        self._pan_slider.setSliderPosition(0)
        self._pan_slider.setTickInterval(250)
        self._pan_slider.valueChanged.connect(self.on_pan)
        pl = QLabel("horiz. pos.:")
        pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        pl.setMaximumWidth(150)
        zl = QLabel("zoom:")
        zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        zl.setMaximumWidth(75)
        plot_controls.layout().addWidget(pl)
        plot_controls.layout().addWidget(self._pan_slider)
        plot_controls.layout().addWidget(zl)
        plot_controls.layout().addWidget(QLabel("+"))
        plot_controls.layout().addWidget(self._zoom_slider)
        plot_controls.layout().addWidget(QLabel("-"))
        self.layout().addWidget(plot_controls)

    def _set_sliders_enabled(self, enabled):
        """Enable or disable both the zoom and the pan slider."""
        self._zoom_slider.setEnabled(enabled)
        self._pan_slider.setEnabled(enabled)

    def _initial_zoom_position(self, x_length):
        """Return the zoom slider position that shows about 50000 samples.

        The slider works in thousandths of the full x range and accepts only
        ints, so the value is rounded and clipped to the slider's range.
        """
        if x_length <= 0:
            return self._zoom_slider.maximum()
        position = int(np.round(1000 / (x_length / 50000)))
        return min(max(position, self._zoom_slider.minimum()), self._zoom_slider.maximum())

    def on_close(self):
        self.close_signal.emit()

    def on_zoom(self, new_position):
        """Re-plot with the zoom level given by the zoom slider (thousandths)."""
        if self.plotter is None:
            return
        if self._software_slide:
            self._software_slide = False
            return
        self.plotter.horizontal_pan_to_position(self._pan_slider.sliderPosition() / 1000, self._zoom_slider.sliderPosition() / 1000)

    def on_pan(self, new_position):
        """Move the view to the pan slider position (thousandths of the range)."""
        if self.plotter is None:
            return
        if self._software_slide:
            self._software_slide = False
            return
        self.plotter.horizontal_pan_to_position(new_position / 1000., self._zoom_slider.sliderPosition() / 1000.)

    def on_view_changed(self):
        self._pan_slider.setEnabled(self.plotter.can_pan_horizontally)

    def plot(self, item):
        """Create the plotter suggested by *item* and draw it.

        A ValueError raised while creating the DataView is reported through
        ``communicator.plot_error`` instead of propagating.
        """
        try:
            self._data_view = DataView(item, self._file_handler)
        except ValueError as e:
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return
        if self._data_view is None:
            return
        if item.suggested_plotter == PlotterTypes.LinePlotter:
            self._set_sliders_enabled(True)
            zoom_slider_position = self._initial_zoom_position(self._data_view.full_shape[item.best_xdim])
            self._software_slide = True
            self._zoom_slider.setSliderPosition(zoom_slider_position)
            self.plotter = LinePlotter(self._file_handler, item, self._data_view)
            self.plotter.view_changed.connect(self.on_view_changed)
            self._container.set_plotter(self.plotter)
            self.plotter.plot(zoomlevel=zoom_slider_position / 1000.)
            self._software_slide = False
        elif item.suggested_plotter == PlotterTypes.ImagePlotter:
            self._set_sliders_enabled(False)
            self.plotter = ImagePlotter(self._file_handler, item, self._data_view)
            self._container.set_plotter(self.plotter)
            self.plotter.plot()
        elif item.suggested_plotter == PlotterTypes.EventPlotter:
            self._set_sliders_enabled(False)
            self.plotter = EventPlotter(self._file_handler, item, self._data_view)
            self._container.set_plotter(self.plotter)
            self.plotter.plot()
        elif item.suggested_plotter == PlotterTypes.CategoryPlotter:
            self._set_sliders_enabled(False)
            self.plotter = CategoryPlotter(self._file_handler, item, self._data_view)
            self._container.set_plotter(self.plotter)
            self.plotter.plot()
        else:
            # No plotter for this item: drop the stale one so slider callbacks
            # do not act on a plot that is no longer shown.
            self._set_sliders_enabled(False)
            self.plotter = None
            self._container.set_plotter(None)