import numpy as np
from nixview.constants import io_chunksize as chunksize

class DataView():
    """Incrementally loaded, in-memory view of one data entity.

    A buffer with the entity's full shape is allocated up front and filled
    chunk by chunk along the entity's longest dimension; every call to
    ``request_more`` fetches the next chunk through the file handler.

    Requirements visible from this class:

    * ``item_descriptor`` provides ``shape``, ``name`` and ``entity_type``.
    * ``file_handler`` provides ``valid_count(full_shape, offset, count)``,
      expected to return ``count`` clipped so the chunk stays inside the
      data, and ``request_data(item_descriptor, offset, count)``, expected to
      return an array of exactly that (clipped) extent.
    """

    def __init__(self, item_descriptor, file_handler) -> None:
        super().__init__()
        self._item_descriptor = item_descriptor
        self._file_handler = file_handler
        self._full_shape = item_descriptor.shape
        self._buffer = None
        # start index of the next chunk; only the entry at _max_dim ever moves
        self._offset = np.zeros(len(self._full_shape), dtype=int)
        self._count = None    # nominal chunk extent per dimension
        self._max_dim = None  # index of the dimension the chunks advance along
        self.init_buffer()
        self.request_more()

    def request_more(self):
        """Fetch the next chunk into the buffer and advance the offset.

        Prints a message and does nothing if everything has been loaded.
        """
        if self.fully_loaded:
            print("all data fetched")
            return
        # first make sure that the count is valid, i.e. inside the data
        valid_count = self._file_handler.valid_count(self._full_shape, self._offset, self._count)
        target = tuple(slice(o, o + c) for o, c in zip(self._offset, valid_count))
        self._buffer[target] = self._file_handler.request_data(self._item_descriptor, self._offset,
                                                               valid_count)
        # Chunks span all other dimensions completely, so only the chunked
        # dimension advances -- and only by what was actually read, so the
        # offset lands exactly on the end of the data when done.
        self._offset[self._max_dim] += valid_count[self._max_dim]

    def init_buffer(self):
        """Choose the chunk extent and allocate the buffer for the full data.

        The chunk covers every dimension completely except the longest one
        (stored in ``self._max_dim``), whose extent is chosen so that one
        chunk holds about ``chunksize`` elements, but never less than one
        index.

        Raises:
            ValueError: if the buffer for the full shape cannot be allocated.
        """
        max_dim = int(np.argmax(self._full_shape))
        buffer_shape = np.array(self._full_shape, dtype=int)
        others = [d for i, d in enumerate(self._full_shape) if i != max_dim]
        # elements per single index step along the chunked dimension;
        # guarded so a zero-sized dimension cannot cause a division by zero
        per_step = max(int(np.prod(others)), 1)
        # at least one index per chunk, otherwise request_more never advances
        buffer_shape[max_dim] = max(int(chunksize // per_step), 1)
        self._count = buffer_shape
        self._max_dim = max_dim
        try:
            # float64 buffer (NumPy default dtype) for the whole entity
            self._buffer = np.empty(self._full_shape)
        except (MemoryError, ValueError) as err:
            # FIXME: the full entity is allocated up front
            raise ValueError("Cannot handle so many data points!") from err

    @property
    def fully_loaded(self):
        """True once the offset along the chunked dimension reached its end."""
        if self._buffer is None or self._max_dim is None:
            return False
        return bool(self._offset[self._max_dim] >= self._full_shape[self._max_dim])

    def __str__(self) -> str:
        lines = [self._item_descriptor.name + " " + str(self._item_descriptor.entity_type)]
        if self._buffer is not None:
            lines.append("buffer size: " + str(self._buffer.shape))
        lines.append("chunk size: " + str(self._count))
        lines.append("is fully loaded: " + str(self.fully_loaded))
        return "\n".join(lines)