from nixview.file_utils import suggested_plotter
import os
import nixio as nix
import numpy as np

from nixview.util.descriptors import FileDescriptor, ItemDescriptor
from nixview.util.enums import NodeType, PlotterTypes



class Singleton(type):
    """Metaclass that hands out exactly one instance per class.

    The first call to a class using this metaclass constructs the instance
    (passing along the call's arguments); every later call returns that same
    object and ignores its arguments.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]

class EntityBuffer:
    """Cache of entities keyed by their ``id`` attribute.

    The first entity stored under an id is kept; later ``put`` calls with
    the same id do not replace it.
    """

    def __init__(self) -> None:
        super().__init__()
        self._buffer = {}

    def put(self, entity):
        """Cache *entity* under ``entity.id`` unless that id is already present.

        Objects without an ``id`` attribute are ignored.
        """
        if hasattr(entity, "id"):
            self._buffer.setdefault(entity.id, entity)

    def has(self, id):
        """Return True if an entity with *id* is cached."""
        return id in self._buffer

    def get(self, id):
        """Return the cached entity with *id*, or None if there is none."""
        return self._buffer.get(id)

    def clear(self):
        """Drop all cached entities."""
        self._buffer.clear()


class FileHandler(metaclass=Singleton):
    """Single access point to the currently opened NIX file.

    Because of the ``Singleton`` metaclass, every ``FileHandler()`` call returns
    the same instance. Files are opened read-only. Entities visited while
    building descriptors are cached in an ``EntityBuffer`` keyed by id, so
    later requests (data, dimensions, references, ...) can reuse them.
    """

    def __init__(self) -> None:
        super().__init__()
        self._filename = None
        self._nix_file = None
        self._file_requests = []
        self._entity_buffer = EntityBuffer()
        self._file_descriptor = None
        self._file_version = None

    def open(self, filename):
        """Open *filename* read-only, closing any previously opened file first.

        Args:
            filename: path to the NIX file.

        Returns:
            tuple: ``(success, message)``; *success* is a bool and *message*
            is a human-readable status string.
        """
        self.close()

        if not os.path.exists(filename):
            return False, "File %s could not be found!" % filename
        try:
            self._nix_file = nix.File.open(filename, nix.FileMode.ReadOnly)
            self._filename = filename
            self._file_descriptor = FileDescriptor(self.filename, self._nix_file.format, self._nix_file.version,
                                                   self._nix_file.created_at, self._nix_file.updated_at,
                                                   os.path.getsize(self.filename) / 1e+6)
            fd = self._file_descriptor
            fd.block_count = len(self._nix_file.blocks)
            for b in self._nix_file.blocks:
                fd.data_array_count += len(b.data_arrays)
                fd.group_count += len(b.groups)
                # tags and multi tags are reported together in tag_count
                fd.tag_count += len(b.tags)
                fd.tag_count += len(b.multi_tags)
                # blocks may lack a data_frames attribute (presumably older nixio versions)
                if hasattr(b, "data_frames"):
                    fd.data_frame_count += len(b.data_frames)
            self._file_version = self._nix_file.version
            return True, "Successfully opened file %s." % os.path.basename(filename)
        except RuntimeError as e:
            # do not leave a half-initialised handler (open file, partial descriptor) behind
            self.close()
            return False, "Failed to open file %s! \n Error message is: %s" % (filename, e)
        except OSError as e:
            self.close()
            return False, "Failed to open file %s! \n Error message is: %s\n Probably no nix file?!" % (filename, e)

    def close(self):
        """Close the current file, if it is open, and reset all per-file state."""
        if self._nix_file is not None and self._nix_file.is_open():
            self._nix_file.close()
        self._nix_file = None
        self._filename = None
        self._file_requests = []
        self._entity_buffer.clear()
        self._file_descriptor = None
        self._file_version = None

    @property
    def file_descriptor(self):
        """The ``FileDescriptor`` of the open file, or None."""
        return self._file_descriptor

    @property
    def is_valid(self):
        """True if a file is currently open."""
        return self._nix_file is not None and self._nix_file.is_open()

    @property
    def filename(self):
        """Path of the currently opened file, or None."""
        return self._filename

    def valid_count(self, shape, offset, count):
        """Clip *count* so that ``offset + count`` stays within *shape*.

        Args:
            shape: extent of the data per dimension.
            offset: start index per dimension.
            count: requested number of entries per dimension.

        Returns:
            numpy.ndarray: one non-negative count per dimension of *shape*.
            Dimensions without a given offset/count keep their full extent.
        """
        valid = np.array(shape, dtype=int)
        for i, (o, c) in enumerate(zip(offset, count)):
            valid[i] = max(0, min(c, shape[i] - o))
        return valid

    def count_is_valid(self, shape, offset, count):
        """Return True if ``offset + count`` fits into *shape* in every dimension."""
        return all(o + c <= s for s, o, c in zip(shape, offset, count))

    def request_data(self, entity_descriptor, offset=None, count=None):
        """Read a hyperslab from a previously buffered entity.

        Args:
            entity_descriptor: descriptor whose ``id`` refers to a buffered entity.
            offset: start index per dimension; defaults to all zeros.
            count: entries per dimension; defaults to the remainder after *offset*.
                Counts that exceed the data extent are clipped.

        Raises:
            ValueError: if the entity is not in the buffer.
        """
        entity = self._entity_buffer.get(entity_descriptor.id)
        if entity is None:
            raise ValueError("Entity is invalid: %s" % entity_descriptor)
        shape = entity.shape
        if offset is None:
            offset = [0] * len(shape)
        if count is None:
            count = [s - o for s, o in zip(shape, offset)]
        if not self.count_is_valid(shape, offset, count):
            count = self.valid_count(shape, offset, count)
        seg = tuple(slice(o, o + c) for o, c in zip(offset, count))
        return entity[seg]

    def _find_section(self, section_id):
        """Return the section with *section_id* from the buffer or the file, or None."""
        section = self._entity_buffer.get(section_id)
        if section is None:
            found = self._nix_file.find_sections(lambda s: s.id == section_id)
            section = found[0] if len(found) > 0 else None
        return section

    def request_section_descriptor(self, id):
        """Return an ``ItemDescriptor`` for the section with *id*, or None if not found."""
        fs = self._find_section(id)
        if fs is None:
            return None
        return ItemDescriptor(fs.name, fs.id, fs.type, definition=fs.definition, entity_type=NodeType.Section)

    def request_metadata(self, root_id=None, depth=1):
        """Describe the direct children of a metadata section.

        Args:
            root_id: id of the section to inspect. If None, the top-level
                sections of the file are returned (and no properties).
            depth (int, optional): currently unused. Defaults to 1.

        Returns:
            tuple: ``(sections, properties)``, two lists of ``ItemDescriptor``.
            Both are empty if *root_id* cannot be found.
        """
        def get_subsections(section):
            sub_sections = []
            for s in section.sections:
                self._entity_buffer.put(s)
                sub_sections.append(ItemDescriptor(s.name, s.id, s.type, definition=s.definition,
                                                   entity_type=NodeType.Section))
            return sub_sections

        def get_properties(section):
            props = []
            for p in section.props:
                # entries may be value objects (with .value) or plain values
                vals = [getattr(v, "value", v) for v in p.values]
                if self._file_version < (1, 1, 1) and len(vals) == 1:
                    # files before format 1.1.1 show single values without brackets
                    value = str(vals[0])
                else:
                    value = "[" + ",".join(map(str, vals)) + "]"
                if p.unit is not None:
                    value += " " + p.unit
                props.append(ItemDescriptor(p.name, p.id, value=value, entity_type=NodeType.Property))
            return props

        if root_id is None:
            return get_subsections(self._nix_file), []
        fs = self._find_section(root_id)
        if fs is None:
            return [], []
        return get_subsections(fs), get_properties(fs)

    def _entity_info(self, entities, block_id, entity_type):
        """Buffer *entities* and build one ``ItemDescriptor`` per entity.

        Optional attributes (metadata, data_type, timestamps, shape, source)
        are read only if the entity has them; otherwise they are set to None.
        """
        infos = []
        for e in entities:
            self._entity_buffer.put(e)
            itd = ItemDescriptor(e.name, e.id, e.type, definition=e.definition, entity_type=entity_type,
                                 block_id=block_id)
            section = e.metadata if hasattr(e, "metadata") else None
            itd.metadata_id = section.id if section is not None else None
            itd.data_type = e.data_type if hasattr(e, "data_type") else None
            itd.created_at = e.created_at if hasattr(e, "created_at") else None
            itd.updated_at = e.updated_at if hasattr(e, "updated_at") else None
            itd.shape = e.shape if hasattr(e, "shape") else None
            src = e.source if hasattr(e, "source") else None
            itd.source_id = src.id if src is not None else None
            infos.append(itd)
            if entity_type == NodeType.DataArray:
                itd.value = "%s %s entries" % (str(e.shape), e.dtype)
                itd.best_xdim = self.guess_best_xdim(e)
                itd.suggested_plotter = self.suggested_plotter(e)
            elif entity_type == NodeType.Tag:
                point_or_segment = "segment" if e.extent else "point"
                start = str(e.position)
                end = ("to " + str(tuple(np.array(e.position) + np.array(e.extent)))) if e.extent else ""
                itd.value = "tags %s %s %s" % (point_or_segment, start, end)
            # TODO set the value to something meaningful for the various entity types
        return infos

    def request_blocks(self):
        """Return descriptors for all blocks in the file."""
        return self._entity_info(self._nix_file.blocks, None, NodeType.Block)

    def get_block(self, id):
        """Return the block with *id*, preferring the entity buffer."""
        b = self._entity_buffer.get(id)
        if b is None:
            b = self._nix_file.blocks[id]
        return b

    def _get_tag(self, block_id, tag_id, is_mtag):
        """Return the tag (or multi tag if *is_mtag*) *tag_id*, preferring the buffer."""
        t = self._entity_buffer.get(tag_id)
        if t is None:
            b = self.get_block(block_id)
            t = b.multi_tags[tag_id] if is_mtag else b.tags[tag_id]
        return t

    def request_data_arrays(self, block_id):
        """Return descriptors for the data arrays of block *block_id*."""
        b = self.get_block(block_id)
        return self._entity_info(b.data_arrays, block_id, NodeType.DataArray)

    def request_tags(self, block_id):
        """Return descriptors for the tags followed by the multi tags of block *block_id*."""
        b = self.get_block(block_id)
        tags = self._entity_info(b.tags, block_id, NodeType.Tag)
        tags.extend(self._entity_info(b.multi_tags, block_id, NodeType.MultiTag))
        return tags

    def request_references(self, block_id, tag_id, is_mtag):
        """Return descriptors for the data arrays referenced by a (multi) tag."""
        t = self._get_tag(block_id, tag_id, is_mtag)
        return self._entity_info(t.references, block_id, NodeType.DataArray)

    def request_features(self, block_id, tag_id, is_mtag):
        """Return descriptors for the features of a (multi) tag.

        Each descriptor uses the feature's data name/definition and its link type.
        """
        t = self._get_tag(block_id, tag_id, is_mtag)
        return [ItemDescriptor(f.data.name, f.id, f.link_type, definition=f.data.definition,
                               block_id=block_id, entity_type=NodeType.Feature)
                for f in t.features]

    def request_dimensions(self, block_id, array_id):
        """Return one descriptor per dimension of data array *array_id*."""
        da = self._entity_buffer.get(array_id)
        if da is None:
            b = self.get_block(block_id)
            da = b.data_arrays[array_id]
        dimensions = []
        for i, d in enumerate(da.dimensions):
            dim_name = "%i. dim: %s" % (i + 1, d.label if hasattr(d, "label") else "")
            dim_type = "%s %s" % (d.dimension_type, "dimension")
            dimensions.append(ItemDescriptor(dim_name, type=dim_type, entity_type=NodeType.Dimension,
                                             block_id=block_id))
        return dimensions

    def request_data_frames(self, block_id):
        """Return descriptors for the data frames of block *block_id*.

        Only files of format version 1.2 or newer are queried. The full
        version tuple is compared so that e.g. 2.0.0 also counts as newer.
        """
        if tuple(self._nix_file.version) >= (1, 2, 0):
            b = self.get_block(block_id)
            if hasattr(b, "data_frames"):
                return self._entity_info(b.data_frames, block_id, NodeType.DataFrame)
        return []

    def request_groups(self, block_id):
        """Return descriptors for the groups of block *block_id*."""
        b = self.get_block(block_id)
        return self._entity_info(b.groups, block_id, NodeType.Group)

    def request_sources(self, block_id, parent_source_id=None):
        """Return source descriptors of block *block_id*.

        Args:
            block_id: id of the block.
            parent_source_id: if None, the block's top-level sources are
                returned. Otherwise the child sources of every source with
                this id are returned.
        """
        def get_subsources(src):
            sub_sources = []
            for s in src.sources:
                self._entity_buffer.put(s)
                sub_sources.append(ItemDescriptor(s.name, s.id, s.type, definition=s.definition,
                                                  entity_type=NodeType.Source, block_id=block_id))
            return sub_sources

        b = self.get_block(block_id)
        if parent_source_id is None:
            return self._entity_info(b.sources, block_id, NodeType.Source)
        sources = []
        for src in b.find_sources(lambda s: s.id == parent_source_id):
            sources.extend(get_subsources(src))
        return sources

    def guess_best_xdim(self, array):
        """Guess which dimension of a 1-D or 2-D data array to use as x-axis.

        Returns:
            int or None: 0 or 1, or None for data with more than two dimensions.
            A sampled dimension is preferred, then a range dimension paired with
            a set dimension. Otherwise the first dimension is used.
        """
        data_extent = array.shape
        if len(data_extent) > 2:
            print("Cannot handle more than 2D, sorry!")
            return None
        # scalar/1-D data, or arrays lacking two dimension descriptors
        if len(data_extent) < 2 or len(array.dimensions) < 2:
            return 0

        t1 = array.dimensions[0].dimension_type
        t2 = array.dimensions[1].dimension_type
        if t1 == nix.DimensionType.Sample:
            return 0
        if t2 == nix.DimensionType.Sample:
            return 1
        if t1 == nix.DimensionType.Set and t2 == nix.DimensionType.Range:
            return 1
        return 0

    def suggested_plotter(self, array):
        """Suggest a ``PlotterTypes`` value for *array* based on its dimension types.

        Returns None for unsupported dimensionalities or combinations.
        """
        if len(array.dimensions) > 3:
            print("cannot handle more than 3D")
            return None
        dim_types = [d.dimension_type for d in array.dimensions]
        dim_count = len(dim_types)
        continuous = (nix.DimensionType.Sample, nix.DimensionType.Range)
        if dim_count == 1:
            dt = dim_types[0]
            if dt == nix.DimensionType.Sample:
                return PlotterTypes.LinePlotter
            if dt == nix.DimensionType.Range:
                # alias range dimensions are plotted as events
                if array.dimensions[0].is_alias:
                    return PlotterTypes.EventPlotter
                return PlotterTypes.LinePlotter
            if dt == nix.DimensionType.Set:
                return PlotterTypes.CategoryPlotter
            return None
        if dim_count == 2:
            first, second = dim_types
            second_continuous = second in continuous
            if first in continuous:
                return PlotterTypes.ImagePlotter if second_continuous else PlotterTypes.LinePlotter
            if first == nix.DimensionType.Set:
                return PlotterTypes.LinePlotter if second_continuous else PlotterTypes.CategoryPlotter
            print("Sorry, not a supported combination of dimensions!")
            return None
        if dim_count == 3:
            return PlotterTypes.ImagePlotter
        return None