diff --git a/nixview/ui/plotscreen.py b/nixview/ui/plotscreen.py index 8f3e616..f1c1b82 100644 --- a/nixview/ui/plotscreen.py +++ b/nixview/ui/plotscreen.py @@ -102,9 +102,6 @@ class Plotter(MplCanvas): self._item = item self._dataview = data_view - def current_view(self): - raise NotImplementedError("current_view is not implemented on the current plotter") - @property def is_full_view(self): raise NotImplementedError("is_full_view is not implemented on the current plotter") @@ -242,40 +239,39 @@ class CategoryPlotter(Plotter): class ImagePlotter(Plotter): - def __init__(self, data_array, xdim=-1): - self.array = data_array + def __init__(self, file_handler, item, data_view, xdim=-1, parent=None): + super().__init__(file_handler, item, data_view, parent) + self._dim_count = len(self._dataview.full_shape) + if self._dim_count > 3: + raise ValueError("ImagePlotter cannot plot data with more than 3 dimensions! Shape of %s is %s " %(self._item.name, str(self._dataview.full_shape))) self.image = None - def plot(self, axis=None): - dim_count = len(self.array.dimensions) - if axis is None: - self.fig = plt.figure() - self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75]) - self.axis.set_title(self.array.name) - else: - self.fig = axis.figure - self.axis = axis - if dim_count == 2: + def plot(self): + while not self._dataview.fully_loaded: + self._dataview.request_more() + + self.axis.set_title(self._item.name) + if self._dim_count == 2: return self.plot_2d() - elif dim_count == 3: + elif self._dim_count == 3: return self.plot_3d() - else: - return None - + def plot_2d(self): - data = self.array[:] - x = self.array.dimensions[0].axis(data.shape[0]) - y = self.array.dimensions[1].axis(data.shape[1]) - xlabel = create_label(self.array.dimensions[0]) - ylabel = create_label(self.array.dimensions[1]) - self.image = self.axis.imshow(data, extent=[x[0], x[-1], y[0], y[-1]]) + data = self._dataview._buffer[:] + if len(data.shape) == 3: + data = data.astype("uint8") + x_values = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, self._dataview.full_shape[0], 0) + y_values = self._file_handler.request_axis(self._item.block_id, self._item.id, 1, self._dataview.full_shape[1], 0) + dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id) + xlabel = create_label(dimensions[0]) + ylabel = create_label(dimensions[1]) + self.image = self.axis.imshow(data, extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]]) self.axis.set_xlabel(xlabel) self.axis.set_ylabel(ylabel) - self.axis.set return self.axis def plot_3d(self): - if self.array.shape[2] > 3: + if self._dataview.full_shape[2] > 3: print("cannot plot 3d data with more than 3 channels " "in the third dim") return None @@ -311,22 +307,10 @@ class LinePlotter(Plotter): self._view_xmin = 0 self._view_xmax = 0 - self.axis.callbacks.connect('xlim_changed', self.on_xlims_change) - self.axis.callbacks.connect('ylim_changed', self.on_ylims_change) self._zoom_level = 0 self._segment_length = 0 - - def on_xlims_change(self, event_ax): - #print("updated xlims: ", event_ax.get_xlim()) - pass - - def on_ylims_change(self, event_ax): - # print("updated ylims: ", event_ax.get_ylim()) - pass - - def current_view(self): - cv = [] - return cv + # self.axis.callbacks.connect('xlim_changed', self.on_xlims_change) + # self.axis.callbacks.connect('ylim_changed', self.on_ylims_change) @property def is_full_view(self): @@ -562,6 +546,8 @@ class PlotScreen(QWidget): return if item.suggested_plotter == PlotterTypes.LinePlotter: + self._zoom_slider.setEnabled(True) + self._pan_slider.setEnabled(True) zoom_slider_position = np.round(1000 / (self._data_view.full_shape[item.best_xdim] / 50000)) self._software_slide = True self._zoom_slider.setSliderPosition(zoom_slider_position) @@ -570,3 +556,9 @@ class PlotScreen(QWidget): self._container.set_plotter(self.plotter) self.plotter.plot(zoomlevel=zoom_slider_position/1000.) self._software_slide = False + if item.suggested_plotter == PlotterTypes.ImagePlotter: + self._zoom_slider.setEnabled(False) + self._pan_slider.setEnabled(False) + self.plotter = ImagePlotter(self._file_handler, item, self._data_view) + self._container.set_plotter(self.plotter) + self.plotter.plot()