nixview-python/nixview/ui/plotscreen.py

352 lines
11 KiB
Python

from PyQt5.QtWidgets import QHBoxLayout, QPushButton, QVBoxLayout, QWidget
from PyQt5.QtCore import QObject, pyqtSignal, Qt
import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from nixview.file_handler import FileHandler, DataView
def create_label(entity):
    """Build an axis label of the form ``"<label or name> [<unit>]"``.

    The ``label`` attribute is preferred; when it is missing, ``None`` or
    empty, the ``name`` attribute is used instead.  A ``unit`` attribute
    that is not ``None`` is appended in square brackets.

    Args:
        entity: Any object; the attributes ``label``, ``name`` and ``unit``
            are all optional.

    Returns:
        str: The assembled label, possibly an empty string.
    """
    label = getattr(entity, "label", None) or ""
    if not label:
        # A name of None used to raise TypeError on concatenation; treat it
        # like a missing name, the same way a None label is treated.
        label = getattr(entity, "name", None) or ""
    unit = getattr(entity, "unit", None)
    if unit is not None:
        label += " [%s]" % unit
    return label
class Plotter:
    """Common base of the plotter classes; only provides ``show``."""

    def show(self):
        """Display all open pyplot figures by delegating to ``plt.show()``."""
        plt.show()
class EventPlotter(Plotter):
    """Draw the values of a 1-D data array as markers on a single row.

    Attributes:
        array: The data array handed to the constructor.
        xdim: Index of the dimension used for the x axis.
        dim_count: Number of dimensions of ``array``.
        sc: Return value of ``Axes.scatter`` once ``plot_1d`` has run.
        fig, axis: Figure and axes drawn on; ``None`` until ``plot`` runs.
    """

    def __init__(self, data_array, xdim=-1):
        """Store the array and resolve the x dimension.

        Args:
            data_array: Object exposing ``dimensions`` and ``name`` and
                supporting ``[:]`` slicing.
            xdim: Dimension index to use for x; ``-1`` asks
                ``guess_best_xdim`` to choose.

        Raises:
            ValueError: If ``xdim`` is larger than 1.
        """
        self.array = data_array
        self.sc = None
        # Initialised here so the attributes exist before plot() is called,
        # as LinePlotter does.
        self.fig = None
        self.axis = None
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            # NOTE(review): guess_best_xdim is not defined or imported in the
            # visible part of this module -- confirm it is available.
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 1:
            # The message previously said "larger than 2" although the guard
            # rejects everything above 1.
            raise ValueError("EventPlotter: xdim is larger than 1! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot onto ``axis``, or onto a newly created figure.

        Args:
            axis: Existing matplotlib axes; when ``None`` a new figure with
                one axes titled after the array is created.

        Returns:
            The axes drawn on, or ``None`` when the array is not 1-D.
        """
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        return None

    def plot_1d(self):
        """Scatter every value at y == 1 and label the axes.

        Returns:
            The axes drawn on.
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        xlabel = create_label(dim)
        # The y label is only filled in for a range dimension that is not an
        # alias of the array; in every other case it stays empty.
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""
        self.sc = self.axis.scatter(data, np.ones(data.shape))
        # Single unlabelled row: the y position carries no information.
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis
class CategoryPlotter(Plotter):
    """Draw 1-D or 2-D category data as a (grouped) bar chart.

    Attributes:
        array: The data array handed to the constructor.
        xdim: Index of the dimension holding the categories.
        bars: Bar artists created so far.
        fig, axis: Figure and axes drawn on; ``None`` until ``plot`` runs.
    """

    def __init__(self, data_array, xdim=-1):
        """Store the array and resolve the category dimension.

        Args:
            data_array: Object exposing ``dimensions`` and ``name`` and
                supporting ``[:]`` slicing.
            xdim: Dimension index holding the categories; ``-1`` asks
                ``guess_best_xdim`` to choose.

        Raises:
            ValueError: If ``xdim`` is larger than 2.
        """
        self.array = data_array
        self.bars = []
        # Initialised here so the attributes exist before plot() is called,
        # as LinePlotter does.
        self.fig = None
        self.axis = None
        if xdim == -1:
            # NOTE(review): guess_best_xdim is not defined or imported in the
            # visible part of this module -- confirm it is available.
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("CategoryPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot onto ``axis``, or onto a newly created figure.

        Args:
            axis: Existing matplotlib axes; when ``None`` a new figure with
                one axes titled after the array is created.

        Returns:
            The axes drawn on, or ``None`` when the array is neither 1-D
            nor 2-D.
        """
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        dim_count = len(self.array.dimensions)
        if dim_count == 1:
            return self.plot_1d()
        if dim_count == 2:
            return self.plot_2d()
        return None

    def _dimension_labels(self, index):
        """Return the labels of dimension ``index`` as a list.

        An empty list is returned when the dimension is not a Set dimension
        or carries no labels.
        """
        dim = self.array.dimensions[index]
        if dim.dimension_type != nix.DimensionType.Set:
            return []
        labels = dim.labels
        return [] if labels is None else list(labels)

    def plot_1d(self):
        """Draw one bar per category.

        Returns:
            The axes drawn on, or ``None`` when the x dimension is not a
            Set dimension.

        Raises:
            ValueError: If there are neither labels nor data to derive
                default labels from.
        """
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None
        data = self.array[:]
        categories = self._dimension_labels(self.xdim)
        if not categories:
            # The original tested "categories is None", which can never be
            # true for a list, so these defaults were unreachable.
            categories = ["Cat-%i" % i for i in range(len(data))]
        if not categories:
            raise ValueError("Cannot plot a bar chart without any labels")
        ylabel = create_label(self.array)
        self.bars.append(self.axis.bar(range(1, len(categories) + 1), data,
                                       tick_label=categories))
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_2d(self):
        """Draw one group of bars per category and one bar per series.

        Returns:
            The axes drawn on.
        """
        data = self.array[:]
        if self.xdim == 1:
            data = data.T
        # After the optional transpose axis 0 always holds the categories
        # and axis 1 the series, so default names are sized from these axes
        # and not from shape[self.xdim].
        n_categories, n_series = data.shape[0], data.shape[1]

        # Both name lists are always bound, also when a dimension is not a
        # Set dimension (previously an unbound local in that case).
        categories = self._dimension_labels(self.xdim)
        if not categories:
            categories = ["Cat-%i" % i for i in range(n_categories)]
        series_names = self._dimension_labels(1 - self.xdim)
        if not series_names:
            series_names = ["Series-%i" % i for i in range(n_series)]

        bar_width = 1 / n_series * 0.75
        group_positions = np.arange(n_categories)
        for i in range(n_series):
            x_values = group_positions + i * bar_width
            # Keep only the first rectangle of each series as legend handle.
            self.bars.append(self.axis.bar(x_values, data[:, i],
                                           width=bar_width,
                                           align="center")[0])
        self.axis.set_xticks(group_positions + n_series * bar_width / 2)
        self.axis.set_xticklabels(categories)
        self.axis.legend(self.bars, series_names, loc=1)
        return self.axis
class ImagePlotter(Plotter):
    """Show a 2-D array, or a 3-D array with up to 3 channels, as an image.

    Attributes:
        array: The data array handed to the constructor.
        image: Return value of ``Axes.imshow`` once ``plot_2d`` has run.
        fig, axis: Figure and axes drawn on; ``None`` until ``plot`` runs.
    """

    def __init__(self, data_array, xdim=-1):
        """Store the array.

        Args:
            data_array: Object exposing ``dimensions``, ``shape`` and
                ``name`` and supporting ``[:]`` slicing.
            xdim: Not used by this class; kept so the constructor matches
                the other plotters.
        """
        self.array = data_array
        self.image = None
        # Initialised here so the attributes exist before plot() is called,
        # as LinePlotter does.
        self.fig = None
        self.axis = None

    def plot(self, axis=None):
        """Plot onto ``axis``, or onto a newly created figure.

        Args:
            axis: Existing matplotlib axes; when ``None`` a new figure with
                one axes titled after the array is created.

        Returns:
            The axes drawn on, or ``None`` when the array is neither 2-D
            nor 3-D (or is rejected by ``plot_3d``).
        """
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if dim_count == 2:
            return self.plot_2d()
        if dim_count == 3:
            return self.plot_3d()
        return None

    def plot_2d(self):
        """Draw the array with ``imshow`` and label both axes.

        Returns:
            The axes drawn on.
        """
        data = self.array[:]
        first_dim, second_dim = self.array.dimensions[0], self.array.dimensions[1]
        x = first_dim.axis(data.shape[0])
        y = second_dim.axis(data.shape[1])
        # NOTE(review): the first array dimension is used for the horizontal
        # extent and label -- confirm this matches how imshow lays out rows
        # and columns for the data being shown.
        self.image = self.axis.imshow(data, extent=[x[0], x[-1], y[0], y[-1]])
        self.axis.set_xlabel(create_label(first_dim))
        self.axis.set_ylabel(create_label(second_dim))
        # A stray no-op expression ("self.axis.set") was removed here.
        return self.axis

    def plot_3d(self):
        """Draw a 3-D array whose third dimension has at most 3 entries.

        Returns:
            The axes drawn on, or ``None`` (after printing a message) when
            the third dimension has more than 3 entries.
        """
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()
class LinePlotter(Plotter):
    """Draw 1-D or 2-D arrays as lines, at most ``maxpoints`` samples a time.

    When the plotter creates its own figure, a slider is added that moves
    the visible window along the x dimension.

    Attributes:
        array: The data array handed to the constructor.
        xdim: Index of the dimension used for the x axis.
        dim_count: Number of dimensions of ``array``.
        lines: Line artists created so far, one per plotted series.
        fig, axis: Figure and axes drawn on; ``None`` until ``plot`` runs.
    """

    def __init__(self, data_array, xdim=-1):
        """Store the array and resolve the x dimension.

        Args:
            data_array: Object exposing ``dimensions``, ``shape`` and
                ``name`` and supporting slicing.
            xdim: Dimension index to use for x; ``-1`` asks
                ``guess_best_xdim`` to choose.

        Raises:
            ValueError: If ``xdim`` is larger than 2.
        """
        self.array = data_array
        self.lines = []
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            # NOTE(review): guess_best_xdim is not defined or imported in the
            # visible part of this module -- confirm it is available.
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim
        self.fig = None
        self.axis = None

    def plot(self, axis=None, maxpoints=100000):
        """Plot the first window of data.

        Args:
            axis: Existing matplotlib axes; when ``None`` a new figure with
                one axes and a slider is created.
            maxpoints: Maximum number of samples along x drawn at once.

        Returns:
            The axes drawn on, or ``None`` when the array has more than two
            dimensions.
        """
        self.maxpoints = maxpoints
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
            self.__add_slider()
        else:
            # Keep fig in sync with the supplied axes, as the other plotters
            # do, instead of leaving it None.
            self.fig = axis.figure
            self.axis = axis
        dim_count = len(self.array.dimensions)
        if dim_count > 2:
            return None
        if dim_count == 1:
            return self.plot_array_1d()
        return self.plot_array_2d()

    def __add_slider(self):
        """Add a slider below the plot that selects the visible window."""
        steps = self.array.shape[self.xdim] / self.maxpoints
        slider_ax = self.fig.add_axes([0.15, 0.025, 0.8, 0.025])
        self.slider = Slider(slider_ax, 'Slider', 1., steps, valinit=1.,
                             valstep=0.25)
        self.slider.on_changed(self.__update)

    def __update(self, val):
        """Slider callback: redraw the window that ends at ``val * maxpoints``."""
        if len(self.lines) > 0:
            end = val * self.maxpoints
            start = max(end - self.maxpoints, 0)
            self.__draw(start, end)
            self.fig.canvas.draw_idle()

    def __draw(self, start, end):
        """Dispatch to the 1-D or 2-D drawing routine."""
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    def __clip_range(self, start, end):
        """Clamp ``[start, end)`` to the x extent and return integer bounds."""
        length = self.array.shape[self.xdim]
        first = int(max(start, 0))
        last = int(min(end, length))
        return first, last

    def __draw_1d(self, start, end):
        """Create or update the single line for samples ``start`` to ``end``."""
        first, last = self.__clip_range(start, end)
        y = self.array[first:last]
        dim = self.array.dimensions[self.xdim]
        x = np.asarray(dim.axis(len(y), first))
        if len(self.lines) == 0:
            line, = self.axis.plot(x, y, label=self.array.name)
            self.lines.append(line)
        else:
            self.lines[0].set_ydata(y)
            self.lines[0].set_xdata(x)
        if len(x) > 0:
            # Guarded: x[0] / x[-1] would raise on an empty window.
            self.axis.set_xlim([x[0], x[-1]])

    def __draw_2d(self, start, end):
        """Create or update one line per series for samples ``start`` to ``end``."""
        # Integer bounds are computed once and used for both x and y, so the
        # two always have the same length; the slider can deliver fractional
        # positions.
        first, last = self.__clip_range(start, end)
        x_dimension = self.array.dimensions[self.xdim]
        x = np.asarray(x_dimension.axis(max(last - first, 0), first))
        y_dimension = self.array.dimensions[1 - self.xdim]
        # Dimensions without a "labels" attribute fall back to index labels.
        labels = getattr(y_dimension, "labels", None)
        if labels is None or len(labels) == 0:
            labels = list(map(str, range(self.array.shape[1 - self.xdim])))
        for i, series_label in enumerate(labels):
            if self.xdim == 0:
                y = self.array[first:last, i]
            else:
                y = self.array[i, first:last]
            if len(self.lines) <= i:
                line, = self.axis.plot(x, y, label=series_label)
                self.lines.append(line)
            else:
                self.lines[i].set_ydata(y)
                self.lines[i].set_xdata(x)
        if len(x) > 0:
            # Guarded: x[0] / x[-1] would raise on an empty window.
            self.axis.set_xlim([x[0], x[-1]])

    def plot_array_1d(self):
        """Draw the first window of a 1-D array and label the axes.

        Returns:
            The axes drawn on.
        """
        self.__draw_1d(0, self.maxpoints)
        self.axis.set_xlabel(create_label(self.array.dimensions[self.xdim]))
        self.axis.set_ylabel(create_label(self.array))
        return self.axis

    def plot_array_2d(self):
        """Draw the first window of a 2-D array, label the axes, add a legend.

        Returns:
            The axes drawn on.
        """
        self.__draw_2d(0, self.maxpoints)
        self.axis.set_xlabel(create_label(self.array.dimensions[self.xdim]))
        self.axis.set_ylabel(create_label(self.array))
        self.axis.legend(loc=1)
        return self.axis
class MplCanvas(FigureCanvasQTAgg):
    """Qt canvas wrapping a matplotlib figure with a single subplot.

    The subplot is exposed as ``self.axes``.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Create the figure and its single axes.

        Args:
            parent: Accepted but not used by this constructor.
            width: Figure width passed to ``Figure(figsize=...)``.
            height: Figure height passed to ``Figure(figsize=...)``.
            dpi: Figure resolution passed to ``Figure``.
        """
        figure = Figure(figsize=(width, height), dpi=dpi)
        # The axes are attached before the canvas base class is initialised,
        # in the same order as before.
        self.axes = figure.add_subplot(111)
        super().__init__(figure)
class PlotScreen(QWidget):
    """Widget that stacks a matplotlib canvas above a "close" button."""

    # Emitted when the close button is clicked.
    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        """Build the canvas with a sample line and the close button."""
        super().__init__(parent=parent)
        self._file_handler = FileHandler()

        canvas = MplCanvas(self, width=5, height=4, dpi=100)
        canvas.axes.plot([0, 1, 2, 3, 4], [10, 1, 20, 3, 40])

        box = QVBoxLayout()
        self.setLayout(box)
        box.addWidget(canvas)

        close_button = QPushButton("close")
        close_button.clicked.connect(self.on_close)
        box.addWidget(close_button)

        # Holds the DataView of the item passed to plot(); none yet.
        self._data_view = None

    def on_close(self):
        """Forward a click on the close button as ``close_signal``."""
        self.close_signal.emit()

    def plot(self, item):
        """Create a DataView for ``item`` and request its first data chunk.

        Args:
            item: Object exposing ``entity_type`` and ``shape``; it is
                passed to ``DataView`` together with the file handler.
        """
        print("plot!", item)
        print(item.entity_type, item.shape)
        view = DataView(item, self._file_handler)
        self._data_view = view
        view.request_more()
        print(view)