[plot] works for 1 and 2d categories

This commit is contained in:
Jan Grewe 2021-04-02 20:06:18 +02:00
parent 41ccbfb0a1
commit d77b0b9e9a
3 changed files with 107 additions and 127 deletions

View File

@ -1,12 +1,10 @@
from nixview.util import dataview
from nixview.util.enums import PlotterTypes from nixview.util.enums import PlotterTypes
from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
from PyQt5.QtCore import pyqtSignal, Qt from PyQt5.QtCore import pyqtSignal, Qt
import matplotlib import matplotlib
matplotlib.use('Qt5Agg') matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas, NavigationToolbar2QT as NavigationToolbar from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure from matplotlib.figure import Figure
import nixio as nix
import numpy as np import numpy as np
try: try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -50,7 +48,6 @@ class MplCanvas(FigureCanvas):
self._figure.canvas.mpl_connect('axes_leave_event', self.on_leave_axes) self._figure.canvas.mpl_connect('axes_leave_event', self.on_leave_axes)
self._figure.canvas.mpl_connect('pick_event', self.on_pick) self._figure.canvas.mpl_connect('pick_event', self.on_pick)
def on_enter_figure(self, event): def on_enter_figure(self, event):
# print('enter_figure', event.canvas.figure) # print('enter_figure', event.canvas.figure)
# event.canvas.figure.patch.set_facecolor('red') # event.canvas.figure.patch.set_facecolor('red')
@ -207,82 +204,55 @@ class EventPlotter(Plotter):
class CategoryPlotter(Plotter): class CategoryPlotter(Plotter):
def __init__(self, data_array, xdim=-1): def __init__(self, file_handler, item, data_view, parent=None):
self.array = data_array super().__init__(file_handler, item, data_view, parent)
self.dim_count = len(self._dataview.full_shape)
self._xdim = self._item.best_xdim
self.bars = [] self.bars = []
"""
if xdim == -1:
self.xdim = guess_best_xdim(self.array)
elif xdim > 2:
raise ValueError("CategoryPlotter: xdim is larger than 2! "
"Cannot plot that kind of data")
else:
self.xdim = xdim
"""
def plot(self, axis=None): def plot(self):
if axis is None: if len(self._dataview.full_shape) == 1:
self.fig = plt.figure()
self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
self.axis.set_title(self.array.name)
else:
self.fig = axis.figure
self.axis = axis
if len(self.array.dimensions) == 1:
return self.plot_1d() return self.plot_1d()
elif len(self.array.dimensions) == 2: elif len(self._dataview.full_shape) == 2:
return self.plot_2d() return self.plot_2d()
else: else:
return None return None
def plot_1d(self): def plot_1d(self):
data = self.array[:]
dim = self.array.dimensions[self.xdim]
categories = None categories = None
if dim.dimension_type == nix.DimensionType.Set: ylabel = create_label(self._item)
categories = list(dim.labels) categories = self._file_handler.request_axis(self._item.block_id, self._item.id, self._item.best_xdim, self._item.shape[self._item.best_xdim])
else: if categories is None or len(categories) == 0:
return None categories = ["Cat-%i" % i for i in range(len(categories))]
if categories is None: ylabel = create_label(self._item)
categories = ["Cat-%i" % i for i in range(len(data))]
ylabel = create_label(self.array) self.bars.append(self.axis.bar(range(1, len(categories) + 1), self._dataview._buffer, tick_label=categories))
if len(categories) == 0:
raise ValueError("Cannot plot a bar chart without any labels")
self.bars.append(self.axis.bar(range(1, len(categories)+1), data,
tick_label=categories))
self.axis.set_ylabel(ylabel) self.axis.set_ylabel(ylabel)
return self.axis return self.axis
def plot_2d(self): def plot_2d(self):
data = self.array[:]
if self.xdim == 1:
data = data.T
categories = None categories = None
dim = self.array.dimensions[self.xdim] ylabel = create_label(self._item)
if dim.dimension_type == nix.DimensionType.Set: data = self._dataview._buffer
categories = list(dim.labels) if self._item.best_xdim == 1:
if len(categories) == 0: data = data.T
categories = ["Cat-%i" % i for i in range(data.shape[self.xdim])] categories = self._file_handler.request_axis(self._item.block_id, self._item.id, self._item.best_xdim, self._item.shape[self._item.best_xdim])
if categories is None or len(categories) == 0:
dim = self.array.dimensions[1-self.xdim] categories = ["Cat-%i" % i for i in range(self._item.shape[self._item.best_xdim])]
series_names = []
if dim.dimension_type == nix.DimensionType.Set: series = self._file_handler.request_axis(self._item.block_id, self._item.id, 1 - self._item.best_xdim, self._item.shape[1 - self._item.best_xdim])
series_names = list(dim.labels) if len(series) == 0:
if len(series_names) == 0: series = ["Series-%i" % i for i in range(self._item.shape[1 - self._item.best_xdim])]
series_names = ["Series-%i" % i
for i in range(data.shape[1-self.xdim])] ylabel = create_label(self._item)
bar_width = 1/len(series) * 0.75
bar_width = 1/data.shape[1] * 0.75 for i in range(len(series)):
for i in range(data.shape[1]): x_values = np.arange(len(categories)) + i * bar_width
x_values = np.arange(data.shape[0]) + i * bar_width self.bars.append(self.axis.bar(x_values, data[:, i], width=bar_width, align="center")[0])
self.bars.append(self.axis.bar(x_values, data[:, i], self.axis.set_xticks(np.arange(len(categories)) + len(series) * bar_width/2)
width=bar_width,
align="center")[0])
self.axis.set_xticks(np.arange(data.shape[0]) +
data.shape[1] * bar_width/2)
self.axis.set_xticklabels(categories) self.axis.set_xticklabels(categories)
self.axis.legend(self.bars, series_names, loc=1) self.axis.legend(self.bars, series, loc=1)
return self.axis self.axis.set_ylabel(ylabel)
class ImagePlotter(Plotter): class ImagePlotter(Plotter):
@ -616,3 +586,11 @@ class PlotScreen(QWidget):
self.plotter = EventPlotter(self._file_handler, item, self._data_view) self.plotter = EventPlotter(self._file_handler, item, self._data_view)
self._container.set_plotter(self.plotter) self._container.set_plotter(self.plotter)
self.plotter.plot() self.plotter.plot()
elif item.suggested_plotter == PlotterTypes.CategoryPlotter:
self._zoom_slider.setEnabled(False)
self._pan_slider.setEnabled(False)
self.plotter = CategoryPlotter(self._file_handler, item, self._data_view)
self._container.set_plotter(self.plotter)
self.plotter.plot()
else:
self._container.set_plotter(None)

View File

@ -59,7 +59,6 @@ class DataView():
self._buffer = np.empty(self._full_shape) self._buffer = np.empty(self._full_shape)
except: except:
raise ValueError("Error reserving buffer! Cannot handle so many data points!") #FIXME raise ValueError("Error reserving buffer! Cannot handle so many data points!") #FIXME
print("init buffer")
@property @property
def fully_loaded(self): def fully_loaded(self):
@ -79,4 +78,3 @@ class DataView():
r += " max chunk size: " + str(self._count) r += " max chunk size: " + str(self._count)
r += " is fully loaded: " + str(self.fully_loaded) r += " is fully loaded: " + str(self.fully_loaded)
return r return r

View File

@ -8,6 +8,7 @@ from nixview.util.enums import NodeType, PlotterTypes
class Singleton(type): class Singleton(type):
_instances = {} _instances = {}
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
if cls not in cls._instances: if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
@ -60,7 +61,7 @@ class FileHandler(metaclass=Singleton):
self._nix_file = nix.File.open(filename, nix.FileMode.ReadOnly) self._nix_file = nix.File.open(filename, nix.FileMode.ReadOnly)
self._filename = filename self._filename = filename
self._file_descriptor = FileDescriptor(self.filename, self._nix_file.format, self._nix_file.version, self._file_descriptor = FileDescriptor(self.filename, self._nix_file.format, self._nix_file.version,
self._nix_file.created_at, self._nix_file.updated_at, os.path.getsize(self.filename)/1e+6) self._nix_file.created_at, self._nix_file.updated_at, os.path.getsize(self.filename) / 1e+6)
self.file_descriptor.block_count = len(self._nix_file.blocks) self.file_descriptor.block_count = len(self._nix_file.blocks)
for b in self._nix_file.blocks: for b in self._nix_file.blocks:
self.file_descriptor.data_array_count += len(b.data_arrays) self.file_descriptor.data_array_count += len(b.data_arrays)
@ -150,7 +151,7 @@ class FileHandler(metaclass=Singleton):
props = [] props = []
for p in section.props: for p in section.props:
value = "" value = ""
if self._file_version < (1,1,1): if self._file_version < (1, 1, 1):
vals = p.values vals = p.values
if len(vals) > 1: if len(vals) > 1:
value += "[" value += "["
@ -207,7 +208,7 @@ class FileHandler(metaclass=Singleton):
point_or_segment = "segment" if e.extent else "point" point_or_segment = "segment" if e.extent else "point"
start = str(e.position) start = str(e.position)
end = ("to " + str(tuple(np.array(e.position) + np.array(e.extent)))) if e.extent else "" end = ("to " + str(tuple(np.array(e.position) + np.array(e.extent)))) if e.extent else ""
itd.value = "tags %s %s %s" %(point_or_segment, start, end) itd.value = "tags %s %s %s" % (point_or_segment, start, end)
# TODO set the value to something meaningful for the various entity types # TODO set the value to something meaningful for the various entity types
return infos return infos
@ -262,10 +263,11 @@ class FileHandler(metaclass=Singleton):
dimensions = [] dimensions = []
for i, d in enumerate(da.dimensions): for i, d in enumerate(da.dimensions):
dim_name = "%i. dim: %s" % (i+1, d.label if hasattr(d, "label") else "") dim_name = "%i. dim: %s" % (i+1, d.label if hasattr(d, "label") else "")
dim_type= "%s %s" % (d.dimension_type, "dimension") dim_type = "%s %s" % (d.dimension_type, "dimension")
unit = d.unit if hasattr(d, "unit") else None unit = d.unit if hasattr(d, "unit") else None
label = d.label if hasattr(d, "label") else None label = d.label if hasattr(d, "label") else None
dimensions.append(ItemDescriptor(dim_name, type=dim_type, entity_type=NodeType.Dimension, block_id=block_id, unit=unit, label=label)) dimensions.append(ItemDescriptor(dim_name, type=dim_type, entity_type=NodeType.Dimension,
block_id=block_id, unit=unit, label=label))
return dimensions return dimensions
def request_axis(self, block_id, array_id, dimension_index, count, start=0): def request_axis(self, block_id, array_id, dimension_index, count, start=0):
@ -273,6 +275,7 @@ class FileHandler(metaclass=Singleton):
if da is None: if da is None:
b = self.get_block(block_id) b = self.get_block(block_id)
da = b.data_arrays[array_id] da = b.data_arrays[array_id]
dim = da.dimensions[dimension_index] dim = da.dimensions[dimension_index]
if dim.dimension_type == nix.DimensionType.Set: if dim.dimension_type == nix.DimensionType.Set:
labels = dim.labels labels = dim.labels
@ -320,20 +323,21 @@ class FileHandler(metaclass=Singleton):
if len(data_extent) == 1: if len(data_extent) == 1:
return 0 return 0
d1 = array.dimensions[0] d0 = array.dimensions[0]
d2 = array.dimensions[1] d1 = array.dimensions[1]
shape = array.data_extent
if d1.dimension_type == nix.DimensionType.Sample: if d0.dimension_type == nix.DimensionType.Sample:
return 0 return 0
elif d2.dimension_type == nix.DimensionType.Sample: elif d1.dimension_type == nix.DimensionType.Sample:
return 1 return 1
else: else:
if (d1.dimension_type == nix.DimensionType.Set) and \ if (d0.dimension_type == nix.DimensionType.Set) and \
(d2.dimension_type == nix.DimensionType.Range): (d1.dimension_type == nix.DimensionType.Range):
return 1 return 1
elif (d1.dimension_type == nix.DimensionType.Range) and \ elif (d0.dimension_type == nix.DimensionType.Set) and \
(d2.dimension_type == nix.DimensionType.Set): (d1.dimension_type == nix.DimensionType.Set):
return 0 return int(np.argmax(shape))
else: else:
return 0 return 0