import numpy as np
from nixview.constants import max_chunksize as chunksize

class DataView():
    """Chunk-wise buffered view onto the data of a single item.

    A buffer for the full data shape is reserved up front (``init_buffer``)
    and then filled chunk by chunk via ``request_more``. Chunks are cut along
    a single dimension (``_cut_dim``); every other dimension is read
    completely with each request.
    """

    def __init__(self, item_descriptor, file_handler) -> None:
        """Create the view, reserve the buffer and load the first chunk.

        :param item_descriptor: descriptor of the item; ``shape``,
            ``best_xdim``, ``name`` and ``entity_type`` are read from it.
        :param file_handler: object providing ``valid_count`` and
            ``request_data`` used to read the data.
        """
        super().__init__()
        self._item_descriptor = item_descriptor
        self._file_handler = file_handler
        self._full_shape = item_descriptor.shape
        self._buffer = None
        # start position of the next chunk; only the cut dimension advances
        self._offset = np.zeros(len(self._full_shape), dtype=int)
        # extent of the data already read into the buffer, per dimension
        self._fetched_data = np.zeros(len(self._full_shape), dtype=int)
        self._count = None
        self._cut_dim = None
        self.init_buffer()
        self.request_more()

    def request_more(self):
        """Read the next chunk into the buffer; no-op if already fully loaded.

        The requested count (``self._count``) is first validated by the file
        handler so the read stays inside the data. Afterwards the offset along
        the cut dimension advances by the number of elements read.
        """
        if self.fully_loaded:
            return
        # first make sure that the count is valid, i.e. inside the data
        valid_count = self._file_handler.valid_count(self._full_shape, self._offset, self._count)
        sl = tuple(slice(o, o + c) for o, c in zip(self._offset, valid_count))
        self._buffer[sl] = self._file_handler.request_data(self._item_descriptor, self._offset,
                                                           valid_count)
        new_offset = []
        fetched = []
        for i, (o, c) in enumerate(zip(self._offset, valid_count)):
            if i == self._cut_dim:
                # the chunked dimension grows by what was just read
                new_offset.append(o + c)
                fetched.append(o + c)
            else:
                # non-cut dimensions are read completely every time: their
                # offset stays 0 and their extent is the count itself (summing
                # it across requests would over-count and never match the shape)
                new_offset.append(0)
                fetched.append(c)
        self._offset = tuple(new_offset)
        self._fetched_data = tuple(fetched)

    def init_chunking(self, max_chunksize=None):
        """Decide on the chunk shape used for reading.

        The dimension that is cut into chunks is the item's ``best_xdim`` if
        it is set, otherwise the last dimension. All other dimensions are
        read completely with each chunk, so the product of their sizes must
        not exceed the maximum chunk size.

        :param max_chunksize: maximum number of elements per chunk; defaults
            to the configured ``max_chunksize`` constant.
        :raises ValueError: if the elements of all non-cut dimensions do not
            fit into one chunk.
        """
        if max_chunksize is None:
            max_chunksize = chunksize
        if self._item_descriptor.best_xdim is not None:
            cut_dim = self._item_descriptor.best_xdim
        else:
            cut_dim = len(self._full_shape) - 1
        # number of elements in one "slice" along the cut dimension
        other_count = int(np.prod([d for i, d in enumerate(self._full_shape) if i != cut_dim]))
        if other_count > max_chunksize:
            raise ValueError("Cannot load data in chunks! maxchunksize too small: product of elements in all dimensions except dimension %i exceeds max chunksize! (%i > %i)" % (cut_dim, other_count, max_chunksize))
        chunk_shape = np.array(self._full_shape, dtype=int)
        # integer division; max(..., 1) guards against a zero-length dimension
        chunk_shape[cut_dim] = max_chunksize // max(other_count, 1)
        self._cut_dim = cut_dim
        self._count = chunk_shape

    def init_buffer(self):
        """Decide on the chunking and reserve an uninitialised buffer of the full shape.

        :raises ValueError: if the buffer cannot be allocated.
        """
        self.init_chunking()
        try:
            self._buffer = np.empty(self._full_shape)
        except (MemoryError, ValueError, OverflowError, TypeError) as e:
            raise ValueError("Error reserving buffer! Cannot handle so many data points!") from e  # FIXME
        print("init buffer")

    @property
    def fully_loaded(self):
        """True once the buffer exists and every dimension has been read completely."""
        return self._buffer is not None and np.array_equal(self._fetched_data, self._full_shape)

    @property
    def full_shape(self):
        """Shape of the complete data as reported by the item descriptor."""
        return self._full_shape

    @property
    def current_shape(self):
        """Extent of the data read into the buffer so far, per dimension."""
        return self._fetched_data

    def __str__(self) -> str:
        r = self._item_descriptor.name + " " + str(self._item_descriptor.entity_type)
        # parenthesised so the newline is appended regardless of the buffer state
        r += (" buffer size: " + str(self._buffer.shape) if self._buffer is not None else "") + "\n"
        r += " max chunk size: " + str(self._count)
        r += " is fully loaded: " + str(self.fully_loaded)
        return r