352 lines
11 KiB
Python
352 lines
11 KiB
Python
from PyQt5.QtWidgets import QHBoxLayout, QPushButton, QVBoxLayout, QWidget
|
|
from PyQt5.QtCore import QObject, pyqtSignal, Qt
|
|
import matplotlib
|
|
matplotlib.use('Qt5Agg')
|
|
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
|
|
from matplotlib.figure import Figure
|
|
import nixio as nix
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.widgets import Slider
|
|
|
|
from nixview.file_handler import FileHandler, DataView
|
|
|
|
|
|
|
|
|
|
def create_label(entity):
    """Return an axis label for *entity*, e.g. ``"time [s]"``.

    The entity's ``label`` attribute is used when present and non-empty;
    otherwise its ``name`` is used.  A non-None ``unit`` attribute is
    appended in square brackets.  Missing or None attributes are skipped.

    :param entity: a nix entity (DataArray, dimension, ...) or any object
        exposing some of ``label``, ``name`` and ``unit``.
    :returns: the label string (possibly empty).
    """
    label = getattr(entity, "label", None) or ""
    if not label:
        # fall back to the entity name; a None name contributes nothing
        label = getattr(entity, "name", None) or ""
    unit = getattr(entity, "unit", None)
    if unit is not None:
        label += " [%s]" % unit
    return label


def guess_best_xdim(array):
    """Guess which dimension of *array* should be drawn along the x-axis.

    The plotters in this module call this when no ``xdim`` is given
    (``xdim == -1``).

    :param array: a nix DataArray, or anything with a ``dimensions``
        sequence whose items have a ``dimension_type``.
    :returns: 0 for 1-D data.  For 2-D data: the index of a sampled
        dimension if there is one, 1 for a (set, range) layout, else 0.
        None for data with zero or more than two dimensions.
    """
    dim_count = len(array.dimensions)
    if dim_count == 1:
        return 0
    if dim_count != 2:
        return None
    first = array.dimensions[0].dimension_type
    second = array.dimensions[1].dimension_type
    if first == nix.DimensionType.Sample:
        return 0
    if second == nix.DimensionType.Sample:
        return 1
    if first == nix.DimensionType.Set and second == nix.DimensionType.Range:
        return 1
    return 0
|
|
|
|
|
|
class Plotter:
    """Common base class of the plotters; provides :meth:`show`."""

    def show(self):
        """Display the open matplotlib figures via ``plt.show()``."""
        plt.show()
|
|
|
|
|
|
class EventPlotter(Plotter):
    """Plot the values of a 1-D DataArray as markers on a single row."""

    def __init__(self, data_array, xdim=-1):
        """
        :param data_array: the nix DataArray to plot.
        :param xdim: index of the dimension used as x-axis; -1 lets
            ``guess_best_xdim`` decide.
        :raises ValueError: if *xdim* is larger than 1.
        """
        self.array = data_array
        self.sc = None  # scatter artist, created by plot_1d
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 1:
            raise ValueError("EventPlotter: xdim is larger than 1! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Draw into *axis*, or into a new titled figure if *axis* is None.

        :returns: the axis drawn into, or None if the array is not 1-D.
        """
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        return None

    def plot_1d(self):
        """Scatter all values at y == 1 and label the x-axis.

        The array's own label is used as y-label only when the x dimension
        is a RangeDimension that is not an alias.

        :returns: the axis.
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        xlabel = create_label(dim)
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""
        self.sc = self.axis.scatter(data, np.ones(data.shape))
        # single row of markers: one (unlabeled) tick at y == 1
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis
|
|
|
|
|
|
class CategoryPlotter(Plotter):
    """Plot 1-D or 2-D DataArrays with a set dimension as bar charts."""

    def __init__(self, data_array, xdim=-1):
        """
        :param data_array: the nix DataArray to plot.
        :param xdim: index of the category dimension; -1 lets
            ``guess_best_xdim`` decide.
        :raises ValueError: if *xdim* is larger than 2.
        """
        self.array = data_array
        self.bars = []  # bar artists created by the plot_* methods
        if xdim == -1:
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("CategoryPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Draw into *axis*, or into a new titled figure if *axis* is None.

        :returns: the axis drawn into, or None for arrays that are neither
            1-D nor 2-D (or when ``plot_1d`` cannot handle the dimension).
        """
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        dim_count = len(self.array.dimensions)
        if dim_count == 1:
            return self.plot_1d()
        if dim_count == 2:
            return self.plot_2d()
        return None

    @staticmethod
    def _set_labels(dim, count, template):
        """Return *dim*'s labels if it is a labeled SetDimension, otherwise
        *count* generated names built from *template* (e.g. ``"Cat-%i"``)."""
        labels = None
        if dim.dimension_type == nix.DimensionType.Set:
            labels = dim.labels
        if labels is None or len(labels) == 0:
            return [template % i for i in range(count)]
        return list(labels)

    def plot_1d(self):
        """Draw one bar per category of a 1-D array with a SetDimension.

        Missing labels are replaced by ``Cat-0``, ``Cat-1``, ...

        :returns: the axis, or None if the dimension is not a SetDimension.
        :raises ValueError: if there is nothing to plot.
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None
        categories = self._set_labels(dim, len(data), "Cat-%i")
        if len(categories) == 0:
            raise ValueError("Cannot plot a bar chart without any labels")
        self.bars.append(self.axis.bar(range(1, len(categories) + 1), data,
                                       tick_label=categories))
        self.axis.set_ylabel(create_label(self.array))
        return self.axis

    def plot_2d(self):
        """Draw grouped bars: one group per category, one bar per series.

        Dimension ``self.xdim`` provides the categories, the other dimension
        the series listed in the legend.  Names for unlabeled or non-set
        dimensions are generated (``Cat-i`` / ``Series-i``).

        :returns: the axis.
        """
        data = self.array[:]
        if self.xdim == 1:
            # categories along axis 0, series along axis 1
            data = data.T
        category_count, series_count = data.shape
        categories = self._set_labels(self.array.dimensions[self.xdim],
                                      category_count, "Cat-%i")
        series_names = self._set_labels(self.array.dimensions[1 - self.xdim],
                                        series_count, "Series-%i")

        bar_width = 0.75 / series_count
        group_positions = np.arange(category_count)
        handles = []
        for i in range(series_count):
            container = self.axis.bar(group_positions + i * bar_width,
                                      data[:, i], width=bar_width,
                                      align="center")
            handles.append(container[0])  # one legend handle per series
        self.bars.extend(handles)

        # with align="center" the bars of group k are centred at
        # k, k + w, ..., k + (n - 1) * w; the group middle is k + (n - 1) * w / 2
        self.axis.set_xticks(group_positions +
                             (series_count - 1) * bar_width / 2)
        self.axis.set_xticklabels(categories)
        self.axis.set_ylabel(create_label(self.array))
        self.axis.legend(handles, series_names, loc=1)
        return self.axis
|
|
|
|
|
|
class ImagePlotter(Plotter):
    """Show 2-D DataArrays, or 3-D ones with at most three entries in the
    last dimension, as an image with dimension 0 along the x-axis."""

    def __init__(self, data_array, xdim=-1):
        # xdim is accepted for interface parity with the other plotters but
        # is not used: dimension 0 is always drawn along the x-axis.
        self.array = data_array
        self.image = None  # AxesImage created by plot_2d

    def plot(self, axis=None):
        """Draw into *axis*, or into a new titled figure if *axis* is None.

        :returns: the axis drawn into, or None for unsupported data.
        """
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if dim_count == 2:
            return self.plot_2d()
        if dim_count == 3:
            return self.plot_3d()
        return None

    @staticmethod
    def _axis_values(dim, count):
        """Positions of the first *count* entries of *dim*; plain indices
        for dimensions without an ``axis`` method (e.g. set dimensions)."""
        if hasattr(dim, "axis"):
            return np.asarray(dim.axis(count))
        return np.arange(count)

    def plot_2d(self):
        """Draw the data with dimension 0 horizontal and dimension 1 vertical.

        :returns: the axis.
        """
        data = self.array[:]
        x_dim = self.array.dimensions[0]
        y_dim = self.array.dimensions[1]
        x = self._axis_values(x_dim, data.shape[0])
        y = self._axis_values(y_dim, data.shape[1])
        # imshow draws array axis 0 as rows (vertical).  Swap the first two
        # axes so dimension 0 runs along x, matching the extent and labels;
        # origin="lower" puts y[0] at the bottom.  aspect="auto" lets the
        # image fill the axes although x and y have unrelated units.
        self.image = self.axis.imshow(np.swapaxes(data, 0, 1),
                                      extent=[x[0], x[-1], y[0], y[-1]],
                                      origin="lower", aspect="auto")
        self.axis.set_xlabel(create_label(x_dim))
        self.axis.set_ylabel(create_label(y_dim))
        return self.axis

    def plot_3d(self):
        """Draw 3-D data whose last dimension holds at most three channels.

        :returns: the axis, or None (after printing a message) when the
            third dimension is longer than 3.
        """
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()
|
|
|
|
|
|
class LinePlotter(Plotter):
    """Plot 1-D or 2-D DataArrays as line plots.

    At most ``maxpoints`` samples along the x dimension are drawn at once.
    When ``plot`` creates its own figure and the data are longer than one
    window, a slider below the axes moves the visible window.
    """

    def __init__(self, data_array, xdim=-1):
        """
        :param data_array: the nix DataArray to plot.
        :param xdim: index of the x dimension; -1 lets ``guess_best_xdim``
            decide.
        :raises ValueError: if *xdim* is larger than 2.
        """
        self.array = data_array
        self.lines = []  # one Line2D per plotted series
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim
        self.fig = None
        self.axis = None
        self.slider = None

    def plot(self, axis=None, maxpoints=100000):
        """Draw the first ``maxpoints`` samples.

        :param axis: axis to draw into; None creates a new titled figure
            (with a window slider when the data exceed ``maxpoints``).
        :param maxpoints: number of samples shown at a time.
        :returns: the axis, or None for data with more than two dimensions.
        """
        self.maxpoints = maxpoints
        # check first: the slider needs a valid xdim, which guess_best_xdim
        # does not provide for >2-D data
        if self.dim_count > 2:
            return None
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
            self.__add_slider()
        else:
            self.fig = axis.figure
            self.axis = axis
        if self.dim_count == 1:
            return self.plot_array_1d()
        return self.plot_array_2d()

    def __add_slider(self):
        """Add a slider selecting which window of ``maxpoints`` samples is
        shown.  Nothing is added when all samples fit in one window, since
        the slider range [1, length / maxpoints] would then be empty."""
        steps = self.array.shape[self.xdim] / self.maxpoints
        if steps <= 1.:
            return
        slider_ax = self.fig.add_axes([0.15, 0.025, 0.8, 0.025])
        self.slider = Slider(slider_ax, 'Slider', 1., steps, valinit=1.,
                             valstep=0.25)
        self.slider.on_changed(self.__update)

    def __update(self, val):
        """Slider callback: show samples [(val - 1), val) * maxpoints."""
        if len(self.lines) > 0:
            end = val * self.maxpoints
            self.__draw(max(end - self.maxpoints, 0), end)
            self.fig.canvas.draw_idle()

    def __draw(self, start, end):
        """Redraw the window [start, end) for 1-D or 2-D data."""
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    def __window(self, start, end):
        """Clamp [start, end) to the x dimension and convert it to ints."""
        start = max(int(start), 0)
        end = min(int(end), self.array.shape[self.xdim])
        return start, end

    def __draw_1d(self, start, end):
        """Create or update the single line with samples [start, end)."""
        start, end = self.__window(start, end)
        y = self.array[start:end]
        dim = self.array.dimensions[self.xdim]
        x = np.asarray(dim.axis(len(y), start))
        if len(self.lines) == 0:
            line, = self.axis.plot(x, y, label=self.array.name)
            self.lines.append(line)
        else:
            self.lines[0].set_data(x, y)
        if len(x) > 0:
            self.axis.set_xlim([x[0], x[-1]])

    def __draw_2d(self, start, end):
        """Create or update one line per entry of the non-x dimension."""
        start, end = self.__window(start, end)
        x_dimension = self.array.dimensions[self.xdim]
        # integer bounds keep len(x) equal to the length of each y slice
        x = np.asarray(x_dimension.axis(end - start, start))
        other = 1 - self.xdim
        # only set dimensions carry labels; fall back to "0", "1", ...
        labels = getattr(self.array.dimensions[other], "labels", None)
        if labels is None or len(labels) == 0:
            labels = [str(i) for i in range(self.array.shape[other])]

        for i, label in enumerate(labels):
            if self.xdim == 0:
                y = self.array[start:end, i]
            else:
                y = self.array[i, start:end]
            if len(self.lines) <= i:
                line, = self.axis.plot(x, y, label=label)
                self.lines.append(line)
            else:
                self.lines[i].set_data(x, y)
        if len(x) > 0:
            self.axis.set_xlim([x[0], x[-1]])

    def plot_array_1d(self):
        """Draw the first window of 1-D data and label both axes."""
        self.__draw_1d(0, self.maxpoints)
        self.axis.set_xlabel(create_label(self.array.dimensions[self.xdim]))
        self.axis.set_ylabel(create_label(self.array))
        return self.axis

    def plot_array_2d(self):
        """Draw the first window of 2-D data, label the axes, add a legend."""
        self.__draw_2d(0, self.maxpoints)
        self.axis.set_xlabel(create_label(self.array.dimensions[self.xdim]))
        self.axis.set_ylabel(create_label(self.array))
        self.axis.legend(loc=1)
        return self.axis
|
|
|
|
|
|
class MplCanvas(FigureCanvasQTAgg):
    """Qt canvas widget showing a matplotlib Figure with a single subplot.

    The subplot is available as :attr:`axes`.  *width*/*height* are passed
    to ``Figure(figsize=...)`` and *dpi* to ``Figure(dpi=...)``; *parent*
    is accepted but not used.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axes = figure.add_subplot(111)
        super().__init__(figure)
|
|
|
|
|
|
class PlotScreen(QWidget):
    """Widget with a matplotlib canvas above a "close" button.

    Clicking the button emits :attr:`close_signal`.
    """

    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = FileHandler()

        canvas = MplCanvas(self, width=5, height=4, dpi=100)
        # NOTE(review): fixed demo data, apparently a placeholder
        canvas.axes.plot([0, 1, 2, 3, 4], [10, 1, 20, 3, 40])

        vbox = QVBoxLayout()
        self.setLayout(vbox)
        vbox.addWidget(canvas)

        close_button = QPushButton("close")
        close_button.clicked.connect(self.on_close)
        vbox.addWidget(close_button)

        self._data_view = None

    def on_close(self):
        """Slot for the close button: emit ``close_signal``."""
        self.close_signal.emit()

    def plot(self, item):
        """Create a DataView for *item* and call its ``request_more``.

        This method itself draws nothing; it only prints diagnostics.
        """
        print("plot!", item)
        print(item.entity_type, item.shape)
        view = DataView(item, self._file_handler)
        self._data_view = view
        view.request_more()
        print(view)
|
|
|
|
|