diff --git a/nixview/ui/plotscreen.py b/nixview/ui/plotscreen.py index a24781d..00a60bd 100644 --- a/nixview/ui/plotscreen.py +++ b/nixview/ui/plotscreen.py @@ -1,12 +1,10 @@ -from nixview.util import dataview from nixview.util.enums import PlotterTypes from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget from PyQt5.QtCore import pyqtSignal, Qt import matplotlib matplotlib.use('Qt5Agg') -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas, NavigationToolbar2QT as NavigationToolbar +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.figure import Figure -import nixio as nix import numpy as np try: import matplotlib.pyplot as plt @@ -50,7 +48,6 @@ class MplCanvas(FigureCanvas): self._figure.canvas.mpl_connect('axes_leave_event', self.on_leave_axes) self._figure.canvas.mpl_connect('pick_event', self.on_pick) - def on_enter_figure(self, event): # print('enter_figure', event.canvas.figure) # event.canvas.figure.patch.set_facecolor('red') @@ -62,7 +59,7 @@ class MplCanvas(FigureCanvas): # event.canvas.figure.patch.set_facecolor('grey') # event.canvas.draw() pass - + def on_enter_axes(self, event): # print('enter_axes', event.inaxes) # event.inaxes.patch.set_facecolor('yellow') @@ -84,7 +81,7 @@ class MplCanvas(FigureCanvas): #def clear(self): # self.clear() - + #@property #def enter_figure(self): # return self._fig @@ -116,7 +113,7 @@ class EventPlotter(Plotter): def __init__(self, file_handler, item, data_view, xdim=-1, parent=None): super().__init__(file_handler, item, data_view, parent) self.dim_count = len(self._dataview.full_shape) - + if xdim == -1: self.xdim = self._item.best_xdim else: @@ -149,22 +146,22 @@ class EventPlotter(Plotter): @property def horizontal_pan_position(self): return self._view_xmax/self._abs_xmax - + def horizontal_pan_to_position(self, new_position, zoomlevel): new_xmax = int(np.min([np.ceil(new_position * self._abs_xmax), self._abs_xmax])) segment_length = zoomlevel * self._abs_xmax start = np.max([0, new_xmax - segment_length]) while not self._dataview.fully_loaded and new_xmax < self._dataview.current_shape[self.xdim]: self._dataview.request_more() - + self.plot(start, zoomlevel) - + def plot(self, start=0, zoomlevel=1.0): if zoomlevel > 1: zoomlevel = 1.0 self._segment_length = zoomlevel * self._abs_xmax self._zoom_level = zoomlevel - + if self.dim_count == 1: return self.plot_1d(start) else: @@ -196,7 +193,7 @@ class EventPlotter(Plotter): self.figure.canvas.draw_idle() # self.axis.set_ylim([np.min(y_values), np.max(y_values)]) self.axis.set_xlim([x_values[0], x_values[-1]]) - + self.axis.set_ylim([0.5, 1.5]) self.axis.set_yticks([1.]) self.axis.set_yticklabels([]) @@ -207,82 +204,55 @@ class EventPlotter(Plotter): class CategoryPlotter(Plotter): - def __init__(self, data_array, xdim=-1): - self.array = data_array + def __init__(self, file_handler, item, data_view, parent=None): + super().__init__(file_handler, item, data_view, parent) + self.dim_count = len(self._dataview.full_shape) + self._xdim = self._item.best_xdim self.bars = [] - """ - if xdim == -1: - self.xdim = guess_best_xdim(self.array) - elif xdim > 2: - raise ValueError("CategoryPlotter: xdim is larger than 2! " - "Cannot plot that kind of data") - else: - self.xdim = xdim - """ - def plot(self, axis=None): - if axis is None: - self.fig = plt.figure() - self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75]) - self.axis.set_title(self.array.name) - else: - self.fig = axis.figure - self.axis = axis - if len(self.array.dimensions) == 1: + def plot(self): + if len(self._dataview.full_shape) == 1: return self.plot_1d() - elif len(self.array.dimensions) == 2: + elif len(self._dataview.full_shape) == 2: return self.plot_2d() - else: + else: return None def plot_1d(self): - data = self.array[:] - dim = self.array.dimensions[self.xdim] categories = None - if dim.dimension_type == nix.DimensionType.Set: - categories = list(dim.labels) - else: - return None - if categories is None: - categories = ["Cat-%i" % i for i in range(len(data))] - ylabel = create_label(self.array) - if len(categories) == 0: - raise ValueError("Cannot plot a bar chart without any labels") - self.bars.append(self.axis.bar(range(1, len(categories)+1), data, - tick_label=categories)) + ylabel = create_label(self._item) + categories = self._file_handler.request_axis(self._item.block_id, self._item.id, self._item.best_xdim, self._item.shape[self._item.best_xdim]) + if categories is None or len(categories) == 0: + categories = ["Cat-%i" % i for i in range(len(categories))] + ylabel = create_label(self._item) + + self.bars.append(self.axis.bar(range(1, len(categories) + 1), self._dataview._buffer, tick_label=categories)) self.axis.set_ylabel(ylabel) return self.axis def plot_2d(self): - data = self.array[:] - if self.xdim == 1: - data = data.T categories = None - dim = self.array.dimensions[self.xdim] - if dim.dimension_type == nix.DimensionType.Set: - categories = list(dim.labels) - if len(categories) == 0: - categories = ["Cat-%i" % i for i in range(data.shape[self.xdim])] - - dim = self.array.dimensions[1-self.xdim] - series_names = [] - if dim.dimension_type == nix.DimensionType.Set: - series_names = list(dim.labels) - if len(series_names) == 0: - series_names = ["Series-%i" % i - for i in range(data.shape[1-self.xdim])] - - bar_width = 1/data.shape[1] * 0.75 - for i in range(data.shape[1]): - x_values = np.arange(data.shape[0]) + i * bar_width - self.bars.append(self.axis.bar(x_values, data[:, i], - width=bar_width, - align="center")[0]) - self.axis.set_xticks(np.arange(data.shape[0]) + - data.shape[1] * bar_width/2) + ylabel = create_label(self._item) + data = self._dataview._buffer + if self._item.best_xdim == 1: + data = data.T + categories = self._file_handler.request_axis(self._item.block_id, self._item.id, self._item.best_xdim, self._item.shape[self._item.best_xdim]) + if categories is None or len(categories) == 0: + categories = ["Cat-%i" % i for i in range(self._item.shape[self._item.best_xdim])] + + series = self._file_handler.request_axis(self._item.block_id, self._item.id, 1 - self._item.best_xdim, self._item.shape[1 - self._item.best_xdim]) + if len(series) == 0: + series = ["Series-%i" % i for i in range(self._item.shape[1 - self._item.best_xdim])] + + ylabel = create_label(self._item) + bar_width = 1/len(series) * 0.75 + for i in range(len(series)): + x_values = np.arange(len(categories)) + i * bar_width + self.bars.append(self.axis.bar(x_values, data[:, i], width=bar_width, align="center")[0]) + self.axis.set_xticks(np.arange(len(categories)) + len(series) * bar_width/2) self.axis.set_xticklabels(categories) - self.axis.legend(self.bars, series_names, loc=1) - return self.axis + self.axis.legend(self.bars, series, loc=1) + self.axis.set_ylabel(ylabel) class ImagePlotter(Plotter): @@ -359,7 +329,7 @@ class LinePlotter(Plotter): self._segment_length = 0 # self.axis.callbacks.connect('xlim_changed', self.on_xlims_change) # self.axis.callbacks.connect('ylim_changed', self.on_ylims_change) - + @property def is_full_view(self): full = self._data_xmin == self._view_xmin and self._data_xmax == self._view_xmax @@ -368,7 +338,7 @@ class LinePlotter(Plotter): @property def can_pan_horizontally(self): return self.can_pan_left or self.can_pan_right - + @property def can_pan_left(self): return self._view_xmin > self._abs_xmin @@ -380,7 +350,7 @@ class LinePlotter(Plotter): @property def horizontal_pan_position(self): return self._view_xmax/self._abs_xmax - + def horizontal_pan_to_position(self, new_position, zoomlevel): new_xmax = int(np.min([np.ceil(new_position * self._abs_xmax), self._abs_xmax])) segment_length = zoomlevel * self._abs_xmax @@ -389,13 +359,13 @@ class LinePlotter(Plotter): self._dataview.request_more() self.plot(start, zoomlevel) - + def on_zoom_in(self, new_position): print("plotter ZOOM In!", new_position) - + def on_zoom_out(self, new_position): print("plotter ZOOM out!", new_position) - + def plot(self, start=0, zoomlevel=1.0): if zoomlevel > 1: zoomlevel = 1.0 @@ -413,11 +383,11 @@ class LinePlotter(Plotter): self._data_xmin = display_x_min if self._data_xmax is None or display_xmax > self._data_xmax: self._data_xmax = display_xmax - + def _update_current_view(self, current_xmin, current_xmax): self._view_xmax = current_xmax self._view_xmin = current_xmin - + def __draw_1d(self, start, end): """ draw the data from start to end index. @@ -519,7 +489,7 @@ class PlotScreen(QWidget): self._create_plot_controls() self.layout().addWidget(close_btn) self._data_view = None - + self._software_slide = False self.plotter = None @@ -616,3 +586,11 @@ class PlotScreen(QWidget): self.plotter = EventPlotter(self._file_handler, item, self._data_view) self._container.set_plotter(self.plotter) self.plotter.plot() + elif item.suggested_plotter == PlotterTypes.CategoryPlotter: + self._zoom_slider.setEnabled(False) + self._pan_slider.setEnabled(False) + self.plotter = CategoryPlotter(self._file_handler, item, self._data_view) + self._container.set_plotter(self.plotter) + self.plotter.plot() + else: + self._container.set_plotter(None) diff --git a/nixview/util/dataview.py b/nixview/util/dataview.py index 9140bf3..1249b32 100644 --- a/nixview/util/dataview.py +++ b/nixview/util/dataview.py @@ -28,10 +28,10 @@ class DataView(): for i, x in enumerate(zip(self._offset, valid_count)): if i == self._cut_dim: new_ofst[i] = sum(x) - + self._offset = tuple(new_ofst) self._fetched_data = tuple([sum(x) for x in zip(self._fetched_data, valid_count)]) - + def init_chunking(self): """decides on the chunks size for reading. Heuristic is based on the dimensionality of the data and the "best xdim" if available. If data is 2D the best xdim is loaded in chunks (if necessary) while the other is fully loaded. For 3D and more it is the last dimension that is cut. If the number of data points in the first n-1 dimensions exceeds the maximum chunksize (settings) an error will be thrown. @@ -59,16 +59,15 @@ class DataView(): self._buffer = np.empty(self._full_shape) except: raise ValueError("Error reserving buffer! Cannot handle so many data points!") #FIXME - print("init buffer") @property def fully_loaded(self): return np.all(self._buffer is not None and self._fetched_data == self._full_shape) - + @property def full_shape(self): return self._full_shape - + @property def current_shape(self): return self._fetched_data @@ -79,4 +78,3 @@ class DataView(): r += " max chunk size: " + str(self._count) r += " is fully loaded: " + str(self.fully_loaded) return r - \ No newline at end of file diff --git a/nixview/util/file_handler.py b/nixview/util/file_handler.py index 9cb514c..c990ae5 100644 --- a/nixview/util/file_handler.py +++ b/nixview/util/file_handler.py @@ -8,6 +8,7 @@ from nixview.util.enums import NodeType, PlotterTypes class Singleton(type): _instances = {} + def __call__(cls, *args, **kwargs): if cls not in cls._instances: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) @@ -19,29 +20,29 @@ class EntityBuffer(): def __init__(self) -> None: super().__init__() self._buffer = {} - + def put(self, entity): if not hasattr(entity, "id"): return id = entity.id if id not in self._buffer.keys(): self._buffer[id] = entity - + def has(self, id): return id in self._buffer.keys() - + def get(self, id): if self.has(id): return self._buffer[id] else: return None - + def clear(self): self._buffer.clear() class FileHandler(metaclass=Singleton): - + def __init__(self) -> None: super().__init__() self._filename = None @@ -50,24 +51,24 @@ class FileHandler(metaclass=Singleton): self._entity_buffer = EntityBuffer() self._file_descriptor = None self._file_version = None - + def open(self, filename): self.close() - + if not os.path.exists(filename): return False, "File %s could not be found!" % filename try: self._nix_file = nix.File.open(filename, nix.FileMode.ReadOnly) self._filename = filename - self._file_descriptor = FileDescriptor(self.filename, self._nix_file.format, self._nix_file.version, - self._nix_file.created_at, self._nix_file.updated_at, os.path.getsize(self.filename)/1e+6) + self._file_descriptor = FileDescriptor(self.filename, self._nix_file.format, self._nix_file.version, + self._nix_file.created_at, self._nix_file.updated_at, os.path.getsize(self.filename) / 1e+6) self.file_descriptor.block_count = len(self._nix_file.blocks) for b in self._nix_file.blocks: self.file_descriptor.data_array_count += len(b.data_arrays) self.file_descriptor.group_count += len(b.groups) self.file_descriptor.tag_count += len(b.tags) self.file_descriptor.tag_count += len(b.multi_tags) - if hasattr(b, "data_frames"): + if hasattr(b, "data_frames"): self.file_descriptor.data_frame_count += len(b.data_frames) self._file_version = self._nix_file.version return True, "Successfully opened file %s." % filename.split(os.sep)[-1] @@ -92,11 +93,11 @@ class FileHandler(metaclass=Singleton): @property def is_valid(self): return self._nix_file is not None and self._nix_file.is_open() - + @property def filename(self): return self._filename - + def valid_count(self, shape, offset, count): valid_count = np.empty(len(shape), dtype=int) for i, (o, c) in enumerate(zip(offset, count)): @@ -105,13 +106,13 @@ class FileHandler(metaclass=Singleton): else: valid_count[i] = c return valid_count - + def count_is_valid(self, shape, offset, count): res = True for s, o, c in zip(shape, offset, count): res = res and o + c <= s return res - + def request_data(self, entity_descriptor, offset=None, count=None): entity = self._entity_buffer.get(entity_descriptor.id) if entity is None: @@ -131,7 +132,7 @@ class FileHandler(metaclass=Singleton): else: item = ItemDescriptor(fs.name, fs.id, fs.type, definition=fs.definition, entity_type=NodeType.Section) return item - + def request_metadata(self, root_id=None, depth=1): """[summary] @@ -145,19 +146,19 @@ class FileHandler(metaclass=Singleton): self._entity_buffer.put(s) sub_sections.append(ItemDescriptor(s.name, s.id, s.type, definition=s.definition, entity_type=NodeType.Section)) return sub_sections - + def get_properties(section): props = [] for p in section.props: value = "" - if self._file_version < (1,1,1): - vals = p.values + if self._file_version < (1, 1, 1): + vals = p.values if len(vals) > 1: value += "[" value += ",".join(map(str, [v.value for v in vals])) value += "]" else: - value = str(vals[0].value) + value = str(vals[0].value) else: vals = p.values value += "[" @@ -207,7 +208,7 @@ class FileHandler(metaclass=Singleton): point_or_segment = "segment" if e.extent else "point" start = str(e.position) end = ("to " + str(tuple(np.array(e.position) + np.array(e.extent)))) if e.extent else "" - itd.value = "tags %s %s %s" %(point_or_segment, start, end) + itd.value = "tags %s %s %s" % (point_or_segment, start, end) # TODO set the value to something meaningful for the various entity types return infos @@ -219,7 +220,7 @@ class FileHandler(metaclass=Singleton): if not b: b = self._nix_file.blocks[id] return b - + def request_data_arrays(self, block_id): b = self.get_block(block_id) return self._entity_info(b.data_arrays, block_id, NodeType.DataArray) @@ -262,10 +263,11 @@ class FileHandler(metaclass=Singleton): dimensions = [] for i, d in enumerate(da.dimensions): dim_name = "%i. dim: %s" % (i+1, d.label if hasattr(d, "label") else "") - dim_type= "%s %s" % (d.dimension_type, "dimension") + dim_type = "%s %s" % (d.dimension_type, "dimension") unit = d.unit if hasattr(d, "unit") else None label = d.label if hasattr(d, "label") else None - dimensions.append(ItemDescriptor(dim_name, type=dim_type, entity_type=NodeType.Dimension, block_id=block_id, unit=unit, label=label)) + dimensions.append(ItemDescriptor(dim_name, type=dim_type, entity_type=NodeType.Dimension, + block_id=block_id, unit=unit, label=label)) return dimensions def request_axis(self, block_id, array_id, dimension_index, count, start=0): @@ -273,6 +275,7 @@ class FileHandler(metaclass=Singleton): if da is None: b = self.get_block(block_id) da = b.data_arrays[array_id] + dim = da.dimensions[dimension_index] if dim.dimension_type == nix.DimensionType.Set: labels = dim.labels @@ -280,7 +283,7 @@ class FileHandler(metaclass=Singleton): raise ValueError("Invalid argument for start or count for SetDimension") axis = labels[start:start + count] if len(labels) == 0: - axis = list(map(str, range(start, start+count))) + axis = list(map(str, range(start, start+count))) else: axis = np.asarray(dim.axis(count, start)) return axis @@ -311,7 +314,7 @@ class FileHandler(metaclass=Singleton): for src in srcs: sources.extend(get_subsources(src)) return sources - + def guess_best_xdim(self, array): data_extent = array.shape if len(data_extent) > 2: @@ -320,23 +323,24 @@ class FileHandler(metaclass=Singleton): if len(data_extent) == 1: return 0 - d1 = array.dimensions[0] - d2 = array.dimensions[1] + d0 = array.dimensions[0] + d1 = array.dimensions[1] + shape = array.data_extent - if d1.dimension_type == nix.DimensionType.Sample: + if d0.dimension_type == nix.DimensionType.Sample: return 0 - elif d2.dimension_type == nix.DimensionType.Sample: + elif d1.dimension_type == nix.DimensionType.Sample: return 1 else: - if (d1.dimension_type == nix.DimensionType.Set) and \ - (d2.dimension_type == nix.DimensionType.Range): + if (d0.dimension_type == nix.DimensionType.Set) and \ + (d1.dimension_type == nix.DimensionType.Range): return 1 - elif (d1.dimension_type == nix.DimensionType.Range) and \ - (d2.dimension_type == nix.DimensionType.Set): - return 0 + elif (d0.dimension_type == nix.DimensionType.Set) and \ + (d1.dimension_type == nix.DimensionType.Set): + return int(np.argmax(shape)) else: return 0 - + def suggested_plotter(self, array): if len(array.dimensions) > 3: print("cannot handle more than 3D") @@ -358,19 +362,19 @@ class FileHandler(metaclass=Singleton): elif dim_count == 2: if dim_types[0] == nix.DimensionType.Sample: if dim_types[1] == nix.DimensionType.Sample or \ - dim_types[1] == nix.DimensionType.Range: + dim_types[1] == nix.DimensionType.Range: return PlotterTypes.ImagePlotter else: return PlotterTypes.LinePlotter elif dim_types[0] == nix.DimensionType.Range: if dim_types[1] == nix.DimensionType.Sample or \ - dim_types[1] == nix.DimensionType.Range: + dim_types[1] == nix.DimensionType.Range: return PlotterTypes.ImagePlotter else: return PlotterTypes.LinePlotter elif dim_types[0] == nix.DimensionType.Set: if dim_types[1] == nix.DimensionType.Sample or \ - dim_types[1] == nix.DimensionType.Range: + dim_types[1] == nix.DimensionType.Range: return PlotterTypes.LinePlotter else: return PlotterTypes.CategoryPlotter