nixview-python/nixview/ui/plotscreen.py

427 lines
15 KiB
Python

from nixview.data_models.tree_model import PropertyTreeItem
from nixview.util import dataview
from nixview.util.enums import PlotterTypes
from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
from PyQt5.QtCore import QObject, pyqtSignal, Qt
import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from IPython import embed
from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView
from nixview.communicator import communicator
def create_label(item):
    """Compose a human-readable axis label for *item*.

    The text comes from ``item.label`` when that attribute exists and is not
    ``None``. If that leaves the text empty, ``item.name`` is used instead
    (when the attribute exists). A non-``None`` ``item.unit`` is appended in
    square brackets, e.g. ``"Voltage [mV]"``. Missing attributes are skipped,
    so the result may be the empty string.
    """
    text = ""
    if hasattr(item, "label") and item.label is not None:
        text += item.label
    if not text and hasattr(item, "name"):
        # No (or an empty) label: fall back to the item's name.
        text += item.name
    unit = getattr(item, "unit", None)
    if unit is not None:
        text += " [%s]" % unit
    return text
class MplCanvas(FigureCanvasQTAgg):
    """Qt widget hosting a matplotlib figure with a single subplot.

    The subplot is available as ``self.axis`` and the figure as
    ``self._figure``. ``view_changed`` is a no-argument signal that subclasses
    emit after drawing.
    """

    view_changed = pyqtSignal()

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Create the figure and its single subplot.

        Args:
            parent: optional Qt parent widget; ``None`` leaves the canvas
                unparented.
            width: figure width in inches.
            height: figure height in inches.
            dpi: figure resolution in dots per inch.
        """
        figure = Figure(figsize=(width, height), dpi=dpi)
        # The axes are created before the base-class init, following the usual
        # FigureCanvasQTAgg embedding pattern.
        self.axis = figure.add_subplot(111)
        super(MplCanvas, self).__init__(figure)
        # The canvas base constructor only takes the figure, so the Qt parent
        # must be applied explicitly (it used to be silently dropped).
        self.setParent(parent)
        self._figure = figure
class Plotter(MplCanvas):
    """Base canvas for plotting one item obtained through a file handler.

    The handler and item are kept as ``_file_handler`` and ``_item`` for use
    by subclasses.
    """

    def __init__(self, file_handler, item, parent=None) -> None:
        super(Plotter, self).__init__(parent=parent)
        self._file_handler, self._item = file_handler, item

    def show(self):
        """Call ``plt.show()`` to display pyplot-managed figures.

        NOTE(review): this overrides ``QWidget.show`` for Python callers. The
        canvas's own figure is created via ``Figure(...)`` rather than pyplot,
        so this presumably does not show the canvas widget itself. It serves
        the legacy plotters here that draw into ``plt.figure()`` figures.
        """
        plt.show()
class EventPlotter(Plotter):
    """Legacy plotter that draws a 1-D array of event positions as dots.

    NOTE(review): unlike ``LinePlotter``, this class does not call
    ``Plotter.__init__``. It draws into a pyplot figure or a caller-supplied
    axis rather than into this canvas. TODO confirm whether it is still used.
    """

    def __init__(self, data_array, xdim=-1):
        """Store the array and choose the x dimension.

        Args:
            data_array: array exposing ``dimensions``, ``name`` and slicing;
                presumably a ``nix.DataArray``, to be confirmed.
            xdim: index of the dimension used as x axis, or -1 to guess it.

        Raises:
            ValueError: if ``xdim`` is larger than 1.
        """
        self.array = data_array
        self.sc = None
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            # NOTE(review): guess_best_xdim is neither defined nor imported in
            # the visible module; TODO confirm it exists, else this NameErrors.
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 1:
            # The message previously said "larger than 2", contradicting the
            # check above.
            raise ValueError("EventPlotter: xdim is larger than 1! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Draw the events.

        Args:
            axis: axis to draw into. If None, a new pyplot figure and axis
                titled with the array name are created.

        Returns:
            The axis drawn into, or None if the array is not 1-D.
        """
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        return None

    def plot_1d(self):
        """Scatter every data value at y == 1 and return the axis."""
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        xlabel = create_label(dim)
        # The array's own label goes on the y axis only for a non-alias range
        # dimension; otherwise the y axis is left unlabelled.
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""
        # All events share y == 1, so the y axis carries no information.
        self.sc = self.axis.scatter(data, np.ones(data.shape))
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis
class CategoryPlotter(Plotter):
    """Legacy plotter that draws bar charts over a set (category) dimension.

    NOTE(review): does not call ``Plotter.__init__``. It draws into a pyplot
    figure or a caller-supplied axis rather than into this canvas.
    """

    def __init__(self, data_array, xdim=-1):
        """Store the array and choose the category dimension.

        Args:
            data_array: array exposing ``dimensions``, ``name`` and slicing;
                presumably a ``nix.DataArray``, to be confirmed.
            xdim: index of the category dimension, or -1 to guess it.

        Raises:
            ValueError: if ``xdim`` is larger than 2.
        """
        self.array = data_array
        self.bars = []
        if xdim == -1:
            # NOTE(review): guess_best_xdim is neither defined nor imported in
            # the visible module; TODO confirm it exists.
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("CategoryPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Draw the bar chart.

        Args:
            axis: axis to draw into. If None, a new pyplot figure and axis
                titled with the array name are created.

        Returns:
            The axis drawn into, or None for arrays that are not 1-D or 2-D.
        """
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        dim_count = len(self.array.dimensions)
        if dim_count == 1:
            return self.plot_1d()
        if dim_count == 2:
            return self.plot_2d()
        return None

    @staticmethod
    def _labels_or_default(dim, count, prefix):
        """Return a set dimension's labels, or ``prefix-i`` names for ``count`` entries."""
        labels = []
        if dim.dimension_type == nix.DimensionType.Set:
            raw = dim.labels
            labels = list(raw) if raw is not None else []
        if not labels:
            labels = ["%s-%i" % (prefix, i) for i in range(count)]
        return labels

    def plot_1d(self):
        """Draw one bar per value; return the axis, or None if xdim is not a set dimension.

        Raises:
            ValueError: if there are no categories to plot (empty data).
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None
        # Unlabelled set dimensions get generated names. The old
        # `categories is None` check could never be true for a list, so empty
        # labels always raised.
        categories = self._labels_or_default(dim, len(data), "Cat")
        if len(categories) == 0:
            raise ValueError("Cannot plot a bar chart without any labels")
        self.bars.append(self.axis.bar(range(1, len(categories) + 1), data,
                                       tick_label=categories))
        self.axis.set_ylabel(create_label(self.array))
        return self.axis

    def plot_2d(self):
        """Draw grouped bars: one group per category, one bar per series. Returns the axis."""
        data = self.array[:]
        if self.xdim == 1:
            # Normalise the layout to (categories, series).
            data = data.T
        category_count, series_count = data.shape
        # After the optional transpose, categories are always axis 0 and series
        # axis 1. Indexing the shape by xdim here would swap them.
        categories = self._labels_or_default(
            self.array.dimensions[self.xdim], category_count, "Cat")
        series_names = self._labels_or_default(
            self.array.dimensions[1 - self.xdim], series_count, "Series")
        bar_width = 0.75 / series_count
        group_starts = np.arange(category_count)
        new_bars = []
        for i in range(series_count):
            new_bars.append(self.axis.bar(group_starts + i * bar_width, data[:, i],
                                          width=bar_width,
                                          align="center")[0])
        self.bars.extend(new_bars)
        # Bars are centred at start + i * bar_width for i = 0 .. n-1, so each
        # group's centre is start + (n - 1) * bar_width / 2.
        self.axis.set_xticks(group_starts + (series_count - 1) * bar_width / 2)
        self.axis.set_xticklabels(categories)
        # Only this call's bars are legend handles, so stale bars from earlier
        # plots are not paired with the current series names.
        self.axis.legend(new_bars, series_names, loc=1)
        return self.axis
class ImagePlotter(Plotter):
    """Legacy plotter that shows 2-D arrays (or 3-D with <= 3 channels) as images.

    NOTE(review): does not call ``Plotter.__init__``. It draws into a pyplot
    figure or a caller-supplied axis rather than into this canvas.
    """

    def __init__(self, data_array, xdim=-1):
        """Store the array.

        ``xdim`` is accepted for signature parity with the other plotters but
        is unused: dimension 0 is always drawn along x.
        """
        self.array = data_array
        self.image = None

    def plot(self, axis=None):
        """Draw the image.

        Args:
            axis: axis to draw into. If None, a new pyplot figure and axis
                titled with the array name are created.

        Returns:
            The axis drawn into, or None if the array is not 2-D or 3-D (or
            has too many channels).
        """
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if dim_count == 2:
            return self.plot_2d()
        if dim_count == 3:
            return self.plot_3d()
        return None

    def plot_2d(self):
        """Show dimension 0 along x and dimension 1 along y; return the axis."""
        data = self.array[:]
        x = self.array.dimensions[0].axis(data.shape[0])
        y = self.array.dimensions[1].axis(data.shape[1])
        # imshow draws the first array axis vertically. Swap the first two axes
        # (keeping any colour channels last) so dimension 0 runs along x, which
        # is how it is labelled and scaled below. origin="lower" puts the first
        # sample of dimension 1 at y[0].
        image = np.swapaxes(data, 0, 1)
        self.image = self.axis.imshow(image, origin="lower",
                                      extent=[x[0], x[-1], y[0], y[-1]])
        self.axis.set_xlabel(create_label(self.array.dimensions[0]))
        self.axis.set_ylabel(create_label(self.array.dimensions[1]))
        return self.axis

    def plot_3d(self):
        """Treat the third dimension as colour channels; reject more than 3 of them."""
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()
class LinePlotter(Plotter):
    """Canvas that draws 1-D or 2-D data as line plots from a DataView buffer.

    At most ``maxpoints`` samples along the x dimension are drawn at a time.
    For 2-D data, one line is drawn per entry of the other dimension.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        """Set up the plotter.

        Args:
            file_handler: used to request dimension descriptors and axis values.
            item: tree item of the plotted array. Its ``block_id``, ``id``,
                ``name`` and ``best_xdim`` are read.
            data_view: provides ``full_shape``, ``current_shape`` and the
                loaded samples in ``_buffer``.
            xdim: dimension drawn along x; -1 uses ``item.best_xdim``.
            parent: optional Qt parent widget.

        Raises:
            ValueError: if ``xdim`` is larger than 2.
        """
        super().__init__(file_handler, item, parent)
        self._dataview = data_view
        self.dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
        self.lines = []
        self.dim_count = len(self._dataview.full_shape)
        if xdim == -1:
            self.xdim = item.best_xdim
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, maxpoints=100000):
        """Draw the first ``maxpoints`` samples; arrays with more than 2 dimensions are ignored."""
        self.maxpoints = maxpoints
        if self.dim_count > 2:
            return
        if self.dim_count == 1:
            self.plot_array_1d()
        else:
            self.plot_array_2d()

    def __add_slider(self):
        """Add a matplotlib slider that scrolls the drawn window.

        NOTE(review): not called anywhere in the visible module.
        """
        steps = self._dataview.full_shape[self.xdim] / self.maxpoints
        slider_ax = self.figure.add_axes([0.15, 0.025, 0.8, 0.025])
        self.slider = Slider(slider_ax, 'Slider', 1., steps, valinit=1.,
                             valstep=0.25)
        self.slider.on_changed(self.__update)

    def __update(self, val):
        """Slider callback: draw the window of samples ending at ``val * maxpoints``."""
        if len(self.lines) > 0:
            start = max(val * self.maxpoints - self.maxpoints, 0)
            end = val * self.maxpoints
            self.__draw(start, end)
            # self is the canvas, so redraw it directly.
            self.draw_idle()

    def __draw(self, start, end):
        """Draw samples [start, end) with the 1-D or 2-D routine."""
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    def _clip_range(self, start, end):
        """Clip [start, end) to the loaded samples along xdim and return integer bounds."""
        loaded = self._dataview.current_shape[self.xdim]
        return int(max(start, 0)), int(min(end, loaded))

    def _set_xlim(self, x):
        """Fit the x axis to the drawn samples; skipped when the window is empty."""
        if len(x) > 0:
            self.axis.set_xlim([x[0], x[-1]])

    def __draw_1d(self, start, end):
        """Create or update the single line with samples [start, end)."""
        start, end = self._clip_range(start, end)
        y = self._dataview._buffer[start:end]
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, len(y), start)
        if len(self.lines) == 0:
            # MplCanvas provides the axes as `self.axis`; `self.axes` does not exist.
            line, = self.axis.plot(x, y, label=self._item.name)
            self.lines.append(line)
        else:
            self.lines[0].set_ydata(y)
            self.lines[0].set_xdata(x)
        self._set_xlim(x)

    def __draw_2d(self, start, end):
        """Create or update one line per entry of the non-x dimension for samples [start, end)."""
        start, end = self._clip_range(start, end)
        # Use the integer slice length so x matches the y slices exactly.
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, end - start, start)
        line_count = self._dataview.current_shape[1 - self.xdim]
        line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1 - self.xdim, line_count, 0)
        for i, label in enumerate(line_labels):
            if self.xdim == 0:
                y = self._dataview._buffer[start:end, i]
            else:
                y = self._dataview._buffer[i, start:end]
            if len(self.lines) <= i:
                line, = self.axis.plot(x, y, label=label)
                self.lines.append(line)
            else:
                self.lines[i].set_ydata(y)
                self.lines[i].set_xdata(x)
        self._set_xlim(x)

    def _label_axes(self):
        """Label x with the x dimension and y with the item, then request a redraw."""
        self.axis.set_xlabel(create_label(self.dimensions[self.xdim]))
        self.axis.set_ylabel(create_label(self._item))
        # Make sure the new content is shown even if the canvas is already visible.
        self.draw_idle()

    def plot_array_1d(self):
        """Draw the first window of 1-D data and emit ``view_changed``."""
        self.__draw_1d(0, self.maxpoints)
        self._label_axes()
        self.view_changed.emit()

    def plot_array_2d(self):
        """Draw the first window of 2-D data with a legend and emit ``view_changed``."""
        self.__draw_2d(0, self.maxpoints)
        self.axis.legend(loc=1)
        self._label_axes()
        self.view_changed.emit()
class PlotContainer(QWidget):
    """Widget holding at most one plotter in a vertical layout."""

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self.setLayout(QVBoxLayout())
        self.plotter = None

    def set_plotter(self, plotter):
        """Show *plotter* in place of the current plotter, if any.

        The previous plotter is removed from the layout and detached from this
        container.
        """
        if plotter is self.plotter:
            return
        if self.plotter is not None:
            self.layout().removeWidget(self.plotter)
            # removeWidget() only takes the widget out of the layout. It would
            # otherwise stay a visible child overlapping the new plotter and be
            # kept alive by this container. Unparenting also hides it.
            self.plotter.setParent(None)
        self.layout().addWidget(plotter)
        self.plotter = plotter
class PlotScreen(QWidget):
    """Screen with a plot area, pan/zoom slider controls and a close button.

    Emits ``close_signal`` when the close button is clicked.
    """

    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = FileHandler()
        # Initialised up front so both attributes exist before the first plot().
        self._data_view = None
        self.plotter = None
        self.setLayout(QVBoxLayout())
        self._container = PlotContainer(self)
        self.layout().addWidget(self._container)
        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)
        # The control row is added between the plot area and the close button.
        self._create_plot_controls()
        self.layout().addWidget(close_btn)

    def _create_plot_controls(self):
        """Add a row with the horizontal-position slider and the zoom slider."""
        plot_controls = QGroupBox()
        plot_controls.setFlat(True)
        plot_controls.setLayout(QHBoxLayout())
        plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._zoom_slider = QSlider(Qt.Horizontal)
        self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self._zoom_slider.setFixedWidth(120)
        self._zoom_slider.setFixedHeight(20)
        self._zoom_slider.setTickPosition(QSlider.TicksBelow)
        self._zoom_slider.setSliderPosition(50)
        self._zoom_slider.setMinimum(0)
        self._zoom_slider.setMaximum(100)
        self._zoom_slider.setTickInterval(25)
        self._zoom_slider.valueChanged.connect(self.on_zoom)
        self._pan_slider = QSlider(Qt.Horizontal)
        self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._pan_slider.setFixedHeight(20)
        self._pan_slider.setTickPosition(QSlider.TicksBelow)
        self._pan_slider.setSliderPosition(0)
        self._pan_slider.setMinimum(0)
        self._pan_slider.setMaximum(100)
        self._pan_slider.setTickInterval(25)
        self._pan_slider.valueChanged.connect(self.on_pan)
        pl = QLabel("horiz. pos.:")
        pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        pl.setMaximumWidth(150)
        zl = QLabel("zoom:")
        zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        zl.setMaximumWidth(75)
        plot_controls.layout().addWidget(pl)
        plot_controls.layout().addWidget(self._pan_slider)
        plot_controls.layout().addWidget(zl)
        plot_controls.layout().addWidget(QLabel("+"))
        plot_controls.layout().addWidget(self._zoom_slider)
        plot_controls.layout().addWidget(QLabel("-"))
        self.layout().addWidget(plot_controls)

    def on_close(self):
        """Forward the close-button click as ``close_signal``."""
        self.close_signal.emit()

    def on_zoom(self, new_position):
        # Placeholder: zooming is not implemented yet.
        print("zoom", new_position)

    def on_pan(self, new_position):
        # Placeholder: panning is not implemented yet.
        print("pan", new_position)

    def on_view_changed(self):
        # Placeholder handler for the plotter's view_changed signal.
        print("view changed!")

    def plot(self, item):
        """Load *item* and show it with its suggested plotter.

        A ValueError raised while creating the DataView is reported through
        ``communicator.plot_error``. Only ``PlotterTypes.LinePlotter`` is
        handled; for other suggested plotters the current plot stays as is.
        """
        try:
            self._data_view = DataView(item, self._file_handler)
        except ValueError as e:
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return
        # TODO this is just a test, needs to be removed: loads all data up front
        # and loops until the view reports it is fully loaded.
        while not self._data_view.fully_loaded:
            self._data_view.request_more()
        if item.suggested_plotter == PlotterTypes.LinePlotter:
            self.plotter = LinePlotter(self._file_handler, item, self._data_view)
            self.plotter.view_changed.connect(self.on_view_changed)
            self._container.set_plotter(self.plotter)
            self.plotter.plot()