[plot] make imagePlotter work

This commit is contained in:
Jan Grewe 2021-02-28 00:30:42 +01:00
parent 17b3e99d43
commit c12fb006d1

View File

@ -102,9 +102,6 @@ class Plotter(MplCanvas):
self._item = item self._item = item
self._dataview = data_view self._dataview = data_view
def current_view(self):
raise NotImplementedError("current_view is not implemented on the current plotter")
@property @property
def is_full_view(self): def is_full_view(self):
raise NotImplementedError("is_full_view is not implemented on the current plotter") raise NotImplementedError("is_full_view is not implemented on the current plotter")
@ -242,40 +239,39 @@ class CategoryPlotter(Plotter):
class ImagePlotter(Plotter): class ImagePlotter(Plotter):
def __init__(self, data_array, xdim=-1): def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
self.array = data_array super().__init__(file_handler, item, data_view, parent)
self._dim_count = len(self._dataview.full_shape)
if self._dim_count > 3:
raise ValueError("ImagePlotter cannot plot data with more than 3 dimensions! Shape of %s is %s " %(self._item.name, str(self._dataview.full_shape)))
self.image = None self.image = None
def plot(self, axis=None): def plot(self):
dim_count = len(self.array.dimensions) while not self._dataview.fully_loaded:
if axis is None: self._dataview.request_more()
self.fig = plt.figure()
self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75]) self.axis.set_title(self._item.name)
self.axis.set_title(self.array.name) if self._dim_count == 2:
else:
self.fig = axis.figure
self.axis = axis
if dim_count == 2:
return self.plot_2d() return self.plot_2d()
elif dim_count == 3: elif self._dim_count == 3:
return self.plot_3d() return self.plot_3d()
else:
return None
def plot_2d(self): def plot_2d(self):
data = self.array[:] data = self._dataview._buffer[:]
x = self.array.dimensions[0].axis(data.shape[0]) if len(data.shape) == 3:
y = self.array.dimensions[1].axis(data.shape[1]) data = data.astype("uint8")
xlabel = create_label(self.array.dimensions[0]) x_values = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, self._dataview.full_shape[0], 0)
ylabel = create_label(self.array.dimensions[1]) y_values = self._file_handler.request_axis(self._item.block_id, self._item.id, 1, self._dataview.full_shape[1], 0)
self.image = self.axis.imshow(data, extent=[x[0], x[-1], y[0], y[-1]]) dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
xlabel = create_label(dimensions[0])
ylabel = create_label(dimensions[1])
self.image = self.axis.imshow(data, extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]])
self.axis.set_xlabel(xlabel) self.axis.set_xlabel(xlabel)
self.axis.set_ylabel(ylabel) self.axis.set_ylabel(ylabel)
self.axis.set
return self.axis return self.axis
def plot_3d(self): def plot_3d(self):
if self.array.shape[2] > 3: if self._dataview.full_shape[2] > 3:
print("cannot plot 3d data with more than 3 channels " print("cannot plot 3d data with more than 3 channels "
"in the third dim") "in the third dim")
return None return None
@ -311,22 +307,10 @@ class LinePlotter(Plotter):
self._view_xmin = 0 self._view_xmin = 0
self._view_xmax = 0 self._view_xmax = 0
self.axis.callbacks.connect('xlim_changed', self.on_xlims_change)
self.axis.callbacks.connect('ylim_changed', self.on_ylims_change)
self._zoom_level = 0 self._zoom_level = 0
self._segment_length = 0 self._segment_length = 0
# self.axis.callbacks.connect('xlim_changed', self.on_xlims_change)
def on_xlims_change(self, event_ax): # self.axis.callbacks.connect('ylim_changed', self.on_ylims_change)
#print("updated xlims: ", event_ax.get_xlim())
pass
def on_ylims_change(self, event_ax):
# print("updated ylims: ", event_ax.get_ylim())
pass
def current_view(self):
cv = []
return cv
@property @property
def is_full_view(self): def is_full_view(self):
@ -562,6 +546,8 @@ class PlotScreen(QWidget):
return return
if item.suggested_plotter == PlotterTypes.LinePlotter: if item.suggested_plotter == PlotterTypes.LinePlotter:
self._zoom_slider.setEnabled(True)
self._pan_slider.setEnabled(True)
zoom_slider_position = np.round(1000 / (self._data_view.full_shape[item.best_xdim] / 50000)) zoom_slider_position = np.round(1000 / (self._data_view.full_shape[item.best_xdim] / 50000))
self._software_slide = True self._software_slide = True
self._zoom_slider.setSliderPosition(zoom_slider_position) self._zoom_slider.setSliderPosition(zoom_slider_position)
@ -570,3 +556,9 @@ class PlotScreen(QWidget):
self._container.set_plotter(self.plotter) self._container.set_plotter(self.plotter)
self.plotter.plot(zoomlevel=zoom_slider_position/1000.) self.plotter.plot(zoomlevel=zoom_slider_position/1000.)
self._software_slide = False self._software_slide = False
if item.suggested_plotter == PlotterTypes.ImagePlotter:
self._zoom_slider.setEnabled(False)
self._pan_slider.setEnabled(False)
self.plotter = ImagePlotter(self._file_handler, item, self._data_view)
self._container.set_plotter(self.plotter)
self.plotter.plot()