[lineplotter] fix 2d plotting

This commit is contained in:
Jan Grewe 2021-01-21 16:46:38 +01:00
parent 8c9eeef076
commit fb4fc20442

View File

@ -1,7 +1,7 @@
from nixview.data_models.tree_model import PropertyTreeItem from nixview.data_models.tree_model import PropertyTreeItem
from nixview.util import dataview from nixview.util import dataview
from nixview.util.enums import PlotterTypes from nixview.util.enums import PlotterTypes
from PyQt5.QtWidgets import QHBoxLayout, QPushButton, QVBoxLayout, QWidget from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
from PyQt5.QtCore import QObject, pyqtSignal, Qt from PyQt5.QtCore import QObject, pyqtSignal, Qt
import matplotlib import matplotlib
matplotlib.use('Qt5Agg') matplotlib.use('Qt5Agg')
@ -11,6 +11,7 @@ import nixio as nix
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.widgets import Slider from matplotlib.widgets import Slider
from IPython import embed
from nixview.util.file_handler import FileHandler from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView from nixview.util.dataview import DataView
@ -29,10 +30,11 @@ def create_label(item):
class MplCanvas(FigureCanvasQTAgg): class MplCanvas(FigureCanvasQTAgg):
view_changed = pyqtSignal()
def __init__(self, parent=None, width=5, height=4, dpi=100): def __init__(self, parent=None, width=5, height=4, dpi=100):
figure = Figure(figsize=(width, height), dpi=dpi) figure = Figure(figsize=(width, height), dpi=dpi)
self.axes = figure.add_subplot(111) self.axis = figure.add_subplot(111)
super(MplCanvas, self).__init__(figure) super(MplCanvas, self).__init__(figure)
self._figure = figure self._figure = figure
@ -237,9 +239,9 @@ class LinePlotter(Plotter):
if self.dim_count > 2: if self.dim_count > 2:
return return
if self.dim_count == 1: if self.dim_count == 1:
return self.plot_array_1d() self.plot_array_1d()
else: else:
return self.plot_array_2d() self.plot_array_2d()
def __add_slider(self): def __add_slider(self):
steps = self.array.shape[self.xdim] / self.maxpoints steps = self.array.shape[self.xdim] / self.maxpoints
@ -283,21 +285,18 @@ class LinePlotter(Plotter):
def __draw_2d(self, start, end): def __draw_2d(self, start, end):
if start < 0: if start < 0:
start = 0 start = 0
if end > self.array.shape[self.xdim]: if end > self._dataview.current_shape[self.xdim]:
end = self.array.shape[self.xdim] end = self._dataview.current_shape[self.xdim]
x_dimension = self.array.dimensions[self.xdim] x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, int(end-start), start)
x = np.asarray(x_dimension.axis(int(end-start), start)) line_count = self._dataview.current_shape[1 - self.xdim]
y_dimension = self.array.dimensions[1-self.xdim] line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1-self.xdim, line_count, 0)
labels = y_dimension.labels
if len(labels) == 0:
labels = list(map(str, range(self.array.shape[1-self.xdim])))
for i, l in enumerate(labels): for i, l in enumerate(line_labels):
if (self.xdim == 0): if (self.xdim == 0):
y = self.array[int(start):int(end), i] y = self._dataview._buffer[int(start):int(end), i]
else: else:
y = self.array[i, int(start):int(end)] y = self._dataview._buffer[i, int(start):int(end)]
if len(self.lines) <= i: if len(self.lines) <= i:
ll, = self.axis.plot(x, y, label=l) ll, = self.axis.plot(x, y, label=l)
@ -310,21 +309,21 @@ class LinePlotter(Plotter):
def plot_array_1d(self): def plot_array_1d(self):
self.__draw_1d(0, self.maxpoints) self.__draw_1d(0, self.maxpoints)
xlabel = create_label(self.dimensions[self.xdim]) xlabel = create_label(self.dimensions[self.xdim])
ylabel = create_label(self._item) ylabel = create_label(self._item)
self.axes.set_xlabel(xlabel) self.axis.set_xlabel(xlabel)
self.axes.set_ylabel(ylabel) self.axis.set_ylabel(ylabel)
return self.axes self.view_changed.emit()
def plot_array_2d(self): def plot_array_2d(self):
self.__draw_2d(0, self.maxpoints) self.__draw_2d(0, self.maxpoints)
xlabel = create_label(self.array.dimensions[self.xdim]) xlabel = create_label(self.dimensions[self.xdim])
ylabel = create_label(self.array) ylabel = create_label(self._item)
self.axis.set_xlabel(xlabel) self.axis.set_xlabel(xlabel)
self.axis.set_ylabel(ylabel) self.axis.set_ylabel(ylabel)
self.axis.legend(loc=1) self.axis.legend(loc=1)
return self.axis self.view_changed.emit()
class PlotContainer(QWidget): class PlotContainer(QWidget):
def __init__(self, parent=None) -> None: def __init__(self, parent=None) -> None:
@ -338,6 +337,7 @@ class PlotContainer(QWidget):
self.layout().addWidget(plotter) self.layout().addWidget(plotter)
self.plotter = plotter self.plotter = plotter
class PlotScreen(QWidget): class PlotScreen(QWidget):
close_signal = pyqtSignal() close_signal = pyqtSignal()
@ -352,12 +352,62 @@ class PlotScreen(QWidget):
close_btn = QPushButton("close") close_btn = QPushButton("close")
close_btn.clicked.connect(self.on_close) close_btn.clicked.connect(self.on_close)
self._create_plot_controls()
self.layout().addWidget(close_btn) self.layout().addWidget(close_btn)
self._data_view = None self._data_view = None
def _create_plot_controls(self):
plot_controls = QGroupBox()
plot_controls.setFlat(True)
plot_controls.setLayout(QHBoxLayout())
plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
self._zoom_slider = QSlider(Qt.Horizontal)
self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
self._zoom_slider.setFixedWidth(120)
self._zoom_slider.setFixedHeight(20)
self._zoom_slider.setTickPosition(QSlider.TicksBelow)
self._zoom_slider.setSliderPosition(50)
self._zoom_slider.setMinimum(0)
self._zoom_slider.setMaximum(100)
self._zoom_slider.setTickInterval(25)
self._zoom_slider.valueChanged.connect(self.on_zoom)
self._pan_slider = QSlider(Qt.Horizontal)
self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
self._pan_slider.setFixedHeight(20)
self._pan_slider.setTickPosition(QSlider.TicksBelow)
self._pan_slider.setSliderPosition(0)
self._pan_slider.setMinimum(0)
self._pan_slider.setMaximum(100)
self._pan_slider.setTickInterval(25)
self._pan_slider.valueChanged.connect(self.on_pan)
pl = QLabel("horiz. pos.:")
pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
pl.setMaximumWidth(150)
zl = QLabel("zoom:")
zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
zl.setMaximumWidth(75)
plot_controls.layout().addWidget(pl)
plot_controls.layout().addWidget(self._pan_slider)
plot_controls.layout().addWidget(zl)
plot_controls.layout().addWidget(QLabel("+"))
plot_controls.layout().addWidget(self._zoom_slider)
plot_controls.layout().addWidget(QLabel("-"))
self.layout().addWidget(plot_controls)
def on_close(self): def on_close(self):
self.close_signal.emit() self.close_signal.emit()
def on_zoom(self, new_position):
print("zoom", new_position)
def on_pan(self, new_position):
print("pan", new_position)
def on_view_changed(self):
print("view changed!")
def plot(self, item): def plot(self, item):
try: try:
self._data_view = DataView(item, self._file_handler) self._data_view = DataView(item, self._file_handler)
@ -370,6 +420,7 @@ class PlotScreen(QWidget):
self._data_view.request_more() # TODO this is just a test, needs to be removed self._data_view.request_more() # TODO this is just a test, needs to be removed
if item.suggested_plotter == PlotterTypes.LinePlotter: if item.suggested_plotter == PlotterTypes.LinePlotter:
self.plotter = LinePlotter(self._file_handler, item, self._data_view) self.plotter = LinePlotter(self._file_handler, item, self._data_view)
self.plotter.view_changed.connect(self.on_view_changed)
self._container.set_plotter(self.plotter) self._container.set_plotter(self.plotter)
self.plotter.plot() self.plotter.plot()