[lineplotter] fix 2d plotting

This commit is contained in:
Jan Grewe 2021-01-21 16:46:38 +01:00
parent 8c9eeef076
commit fb4fc20442

View File

@ -1,7 +1,7 @@
from nixview.data_models.tree_model import PropertyTreeItem
from nixview.util import dataview
from nixview.util.enums import PlotterTypes
from PyQt5.QtWidgets import QHBoxLayout, QPushButton, QVBoxLayout, QWidget
from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
from PyQt5.QtCore import QObject, pyqtSignal, Qt
import matplotlib
matplotlib.use('Qt5Agg')
@ -11,6 +11,7 @@ import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from IPython import embed
from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView
@ -29,10 +30,11 @@ def create_label(item):
class MplCanvas(FigureCanvasQTAgg):
view_changed = pyqtSignal()
def __init__(self, parent=None, width=5, height=4, dpi=100):
figure = Figure(figsize=(width, height), dpi=dpi)
self.axes = figure.add_subplot(111)
self.axis = figure.add_subplot(111)
super(MplCanvas, self).__init__(figure)
self._figure = figure
@ -237,9 +239,9 @@ class LinePlotter(Plotter):
if self.dim_count > 2:
return
if self.dim_count == 1:
return self.plot_array_1d()
self.plot_array_1d()
else:
return self.plot_array_2d()
self.plot_array_2d()
def __add_slider(self):
steps = self.array.shape[self.xdim] / self.maxpoints
@ -283,21 +285,18 @@ class LinePlotter(Plotter):
def __draw_2d(self, start, end):
if start < 0:
start = 0
if end > self.array.shape[self.xdim]:
end = self.array.shape[self.xdim]
x_dimension = self.array.dimensions[self.xdim]
x = np.asarray(x_dimension.axis(int(end-start), start))
y_dimension = self.array.dimensions[1-self.xdim]
labels = y_dimension.labels
if len(labels) == 0:
labels = list(map(str, range(self.array.shape[1-self.xdim])))
if end > self._dataview.current_shape[self.xdim]:
end = self._dataview.current_shape[self.xdim]
for i, l in enumerate(labels):
x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, int(end-start), start)
line_count = self._dataview.current_shape[1 - self.xdim]
line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1-self.xdim, line_count, 0)
for i, l in enumerate(line_labels):
if (self.xdim == 0):
y = self.array[int(start):int(end), i]
y = self._dataview._buffer[int(start):int(end), i]
else:
y = self.array[i, int(start):int(end)]
y = self._dataview._buffer[i, int(start):int(end)]
if len(self.lines) <= i:
ll, = self.axis.plot(x, y, label=l)
@ -310,21 +309,21 @@ class LinePlotter(Plotter):
def plot_array_1d(self):
self.__draw_1d(0, self.maxpoints)
xlabel = create_label(self.dimensions[self.xdim])
ylabel = create_label(self._item)
self.axes.set_xlabel(xlabel)
self.axes.set_ylabel(ylabel)
return self.axes
self.axis.set_xlabel(xlabel)
self.axis.set_ylabel(ylabel)
self.view_changed.emit()
def plot_array_2d(self):
self.__draw_2d(0, self.maxpoints)
xlabel = create_label(self.array.dimensions[self.xdim])
ylabel = create_label(self.array)
xlabel = create_label(self.dimensions[self.xdim])
ylabel = create_label(self._item)
self.axis.set_xlabel(xlabel)
self.axis.set_ylabel(ylabel)
self.axis.legend(loc=1)
return self.axis
self.view_changed.emit()
class PlotContainer(QWidget):
def __init__(self, parent=None) -> None:
@ -337,7 +336,8 @@ class PlotContainer(QWidget):
self.layout().removeWidget(self.plotter)
self.layout().addWidget(plotter)
self.plotter = plotter
class PlotScreen(QWidget):
close_signal = pyqtSignal()
@ -347,17 +347,67 @@ class PlotScreen(QWidget):
self.setLayout(QVBoxLayout())
self._container = PlotContainer(self)
self.layout().addWidget(self._container)
self.layout().addWidget(self._container)
close_btn = QPushButton("close")
close_btn.clicked.connect(self.on_close)
self._create_plot_controls()
self.layout().addWidget(close_btn)
self._data_view = None
def _create_plot_controls(self):
plot_controls = QGroupBox()
plot_controls.setFlat(True)
plot_controls.setLayout(QHBoxLayout())
plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
self._zoom_slider = QSlider(Qt.Horizontal)
self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
self._zoom_slider.setFixedWidth(120)
self._zoom_slider.setFixedHeight(20)
self._zoom_slider.setTickPosition(QSlider.TicksBelow)
self._zoom_slider.setSliderPosition(50)
self._zoom_slider.setMinimum(0)
self._zoom_slider.setMaximum(100)
self._zoom_slider.setTickInterval(25)
self._zoom_slider.valueChanged.connect(self.on_zoom)
self._pan_slider = QSlider(Qt.Horizontal)
self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
self._pan_slider.setFixedHeight(20)
self._pan_slider.setTickPosition(QSlider.TicksBelow)
self._pan_slider.setSliderPosition(0)
self._pan_slider.setMinimum(0)
self._pan_slider.setMaximum(100)
self._pan_slider.setTickInterval(25)
self._pan_slider.valueChanged.connect(self.on_pan)
pl = QLabel("horiz. pos.:")
pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
pl.setMaximumWidth(150)
zl = QLabel("zoom:")
zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
zl.setMaximumWidth(75)
plot_controls.layout().addWidget(pl)
plot_controls.layout().addWidget(self._pan_slider)
plot_controls.layout().addWidget(zl)
plot_controls.layout().addWidget(QLabel("+"))
plot_controls.layout().addWidget(self._zoom_slider)
plot_controls.layout().addWidget(QLabel("-"))
self.layout().addWidget(plot_controls)
def on_close(self):
self.close_signal.emit()
def on_zoom(self, new_position):
print("zoom", new_position)
def on_pan(self, new_position):
print("pan", new_position)
def on_view_changed(self):
print("view changed!")
def plot(self, item):
try:
self._data_view = DataView(item, self._file_handler)
@ -370,6 +420,7 @@ class PlotScreen(QWidget):
self._data_view.request_more() # TODO this is just a test, needs to be removed
if item.suggested_plotter == PlotterTypes.LinePlotter:
self.plotter = LinePlotter(self._file_handler, item, self._data_view)
self.plotter.view_changed.connect(self.on_view_changed)
self._container.set_plotter(self.plotter)
self.plotter.plot()