import os
from PyQt5.QtCore import QAbstractItemModel, QModelIndex, Qt, QSize
from PyQt5.QtGui import QIcon
from PyQt5.QtWidgets import QTreeView, QTreeWidgetItem, QAbstractItemView, QHeaderView

import nixview.constants as cnst
from nixview.util.file_handler import NodeType
from nixview.util.descriptors import ItemDescriptor
from nixview.util.enums import TreeType

# Header labels in display order. NixTreeItem.data() maps column indices
# 0..4 onto node-descriptor attributes in this same order.
column_names = ['Name', 'Type', 'Value', 'Description', 'ID']


class NixTreeItem(QTreeWidgetItem):
    """Base node of the lazily populated NIX tree shown by ``TreeModel``.

    Each item wraps a node descriptor and keeps its own Python list of child
    items; subclasses fill that list in ``load_children``.

    Args:
        node_descriptor: Descriptor whose ``name``, ``type``, ``value``,
            ``definition``, ``id`` and ``entity_type`` attributes are read.
        file_handler: Handler that subclasses query for their children.
        parent: Parent ``NixTreeItem``, or ``None`` for a root item.
    """

    # Descriptor attribute shown in each column, in ``column_names`` order
    # ('Description' is backed by the descriptor's ``definition``).
    _column_attributes = ('name', 'type', 'value', 'definition', 'id')

    def __init__(self, node_descriptor, file_handler, parent=None):
        super().__init__(parent)
        self._node_descriptor = node_descriptor
        self._file_handler = file_handler

        self._parent_item = parent
        self._child_items = []
        # False until load_children() has populated _child_items.
        self._is_loaded = False

    @property
    def node_descriptor(self):
        """The descriptor this item displays."""
        return self._node_descriptor

    @property
    def entity_type(self):
        """Shortcut for ``node_descriptor.entity_type``."""
        return self._node_descriptor.entity_type

    def child(self, row):
        """Return the loaded child at *row*, or ``None`` if there is none."""
        # Reject negative rows explicitly; plain list indexing would wrap
        # around and hand back a child from the end of the list.
        if 0 <= row < len(self._child_items):
            return self._child_items[row]
        return None

    def childCount(self):
        """Return the number of child rows.

        While the item is not loaded, at least one (placeholder) row is
        reported - presumably so views draw an expand arrow before the
        children are fetched.
        """
        placeholder_rows = 0 if self._is_loaded else 1
        return max(placeholder_rows, len(self._child_items))

    def columnCount(self):
        """Return the number of displayed columns."""
        return len(column_names)

    def data(self, column):
        """Return the descriptor value shown in *column*, or ``None``."""
        if 0 <= column < len(self._column_attributes):
            return getattr(self._node_descriptor, self._column_attributes[column])
        return None

    def parent(self):
        """Return the parent item, or ``None`` for a root item."""
        return self._parent_item

    def row(self):
        """Return this item's position within its parent's children (0 for a root)."""
        if self._parent_item is not None:
            # Linear search over the siblings.
            return self._parent_item._child_items.index(self)
        return 0


class FileTreeItem(NixTreeItem):
    """Root item of a full tree: metadata sections first, then data blocks."""

    def __init__(self, node_descriptor, file_handler, parent=None):
        super().__init__(node_descriptor, file_handler=file_handler, parent=parent)
        self._is_loaded = False

    def load_children(self):
        """(Re)build the children.

        One SectionTreeItem per section returned by ``request_metadata()`` is
        added, followed by one BlockTreeItem per block returned by
        ``request_blocks()``.
        """
        self._child_items = []
        handler = self._file_handler
        found_sections, _ = handler.request_metadata()
        self._child_items.extend(
            SectionTreeItem(section, handler, parent=self) for section in found_sections
        )
        self._child_items.extend(
            BlockTreeItem(block, handler, parent=self) for block in handler.request_blocks()
        )
        self._is_loaded = True


class DataTreeItem(NixTreeItem):
    """Root item of a data-only tree: one BlockTreeItem per block descriptor."""

    def __init__(self, node_descriptor, file_handler, parent=None):
        super().__init__(node_descriptor, file_handler=file_handler, parent=parent)
        self._is_loaded = False

    def load_children(self):
        """(Re)build the children from ``request_blocks()``."""
        self._child_items = []
        for block_descriptor in self._file_handler.request_blocks():
            self._child_items.append(
                BlockTreeItem(block_descriptor, self._file_handler, parent=self)
            )
        self._is_loaded = True


class MetadataTreeItem(NixTreeItem):
    """Root item for a metadata tree.

    Args:
        node_descriptor: Descriptor displayed for this root item.
        file_handler: Handler queried for section descriptors.
        parent: Parent item; ``None`` for a root.
        root_section_id: When given, the tree's only child is the section with
            this id. When ``None``, every section returned by
            ``request_metadata`` becomes a child.
    """

    def __init__(self, node_descriptor, file_handler, parent=None, root_section_id=None):
        super().__init__(node_descriptor, file_handler=file_handler, parent=parent)
        self._root_section_id = root_section_id
        self._is_loaded = False

    def load_children(self):
        """(Re)build the child list of SectionTreeItems."""
        self._child_items = []
        handler = self._file_handler
        if self._root_section_id is None:
            found_sections, _ = handler.request_metadata(root_id=self._root_section_id)
            self._child_items.extend(
                SectionTreeItem(section, handler, parent=self) for section in found_sections
            )
        else:
            root_section = handler.request_section_descriptor(self._root_section_id)
            # If the handler returns no descriptor, the tree simply stays empty.
            if root_section is not None:
                self._child_items.append(SectionTreeItem(root_section, handler, parent=self))
        self._is_loaded = True


class BlockTreeItem(NixTreeItem):
    """Tree item for a block.

    Its children are the block's data arrays, tags, data frames and sources,
    in that order.
    """

    def __init__(self, node_descriptor, file_handler, parent=None):
        super().__init__(node_descriptor, file_handler, parent=parent)
        self.setFlags(Qt.ItemIsEnabled | Qt.ItemIsSelectable | Qt.ItemIsEditable)
        self._is_loaded = False

    def load_children(self):
        """(Re)build the children by querying the handler with this block's id."""
        self._child_items = []
        handler = self._file_handler
        block_id = self._node_descriptor.id
        # (request function, item class) pairs, in display order.
        child_kinds = (
            (handler.request_data_arrays, DataArrayTreeItem),
            (handler.request_tags, TagTreeItem),
            (handler.request_data_frames, DataFrameTreeItem),
            (handler.request_sources, SourceTreeItem),
        )
        for request, item_class in child_kinds:
            for descriptor in request(block_id):
                self._child_items.append(item_class(descriptor, handler, parent=self))
        self._is_loaded = True


class SourceTreeItem(NixTreeItem):
    """Tree item for a source.

    Its children are the sources returned by
    ``request_sources(block_id, id)`` for this source - presumably the
    sources nested inside it.
    """

    def __init__(self, node_descriptor, file_handler, parent=None):
        super().__init__(node_descriptor, file_handler, parent=parent)
        self._is_loaded = False

    def load_children(self):
        """(Re)build the child list of SourceTreeItems."""
        # Start from an empty list, like every other load_children() in this
        # module, so that loading twice does not append duplicate children.
        self._child_items = []
        sources = self._file_handler.request_sources(self._node_descriptor.block_id, self._node_descriptor.id)
        for s in sources:
            self._child_items.append(SourceTreeItem(s, self._file_handler, parent=self))
        self._is_loaded = True
        

class GroupTreeItem(NixTreeItem):
    """Tree item for a group.

    Its children are data arrays, tags, data frames and sources, in that
    order, requested with this group's id.
    """

    def __init__(self, node_descriptor, file_handler, parent):
        super().__init__(node_descriptor, file_handler, parent=parent)
        self._is_loaded = False

    def load_children(self):
        """(Re)build the children by querying the handler with this group's id."""
        self._child_items = []
        handler = self._file_handler
        # NOTE(review): BlockTreeItem passes a block id to these same request
        # methods; confirm the handler also accepts a group id here.
        group_id = self._node_descriptor.id
        child_kinds = (
            (handler.request_data_arrays, DataArrayTreeItem),
            (handler.request_tags, TagTreeItem),
            (handler.request_data_frames, DataFrameTreeItem),
            (handler.request_sources, SourceTreeItem),
        )
        for request, item_class in child_kinds:
            for descriptor in request(group_id):
                self._child_items.append(item_class(descriptor, handler, parent=self))
        self._is_loaded = True


class DataFrameTreeItem(NixTreeItem):
    """Leaf item for a data frame; it has nothing to load lazily."""

    def __init__(self, node_descriptor, file_handler, parent=None):
        super(DataFrameTreeItem, self).__init__(node_descriptor, file_handler=file_handler, parent=parent)
        # Marked loaded up front, so no placeholder row is reported.
        self._is_loaded = True


class FeatureTreeItem(NixTreeItem):
    """Leaf item for a tag feature; it has nothing to load lazily."""

    def __init__(self, node_descriptor, file_handler, parent=None):
        super(FeatureTreeItem, self).__init__(node_descriptor, file_handler=file_handler, parent=parent)
        # Marked loaded up front, so no placeholder row is reported.
        self._is_loaded = True


class DataArrayTreeItem(NixTreeItem):
    """Tree item for a data array; its children are the array's dimensions."""

    def __init__(self, node_descriptor, file_handler, parent=None):
        super().__init__(node_descriptor, file_handler, parent=parent)
        self._is_loaded = False

    def load_children(self):
        """(Re)build the children from ``request_dimensions(block_id, id)``."""
        self._child_items = []
        descriptor = self._node_descriptor
        dims = self._file_handler.request_dimensions(descriptor.block_id, descriptor.id)
        self._child_items.extend(
            DimensionTreeItem(dim, self._file_handler, parent=self) for dim in dims
        )
        self._is_loaded = True


class DimensionTreeItem(NixTreeItem):
    """Leaf item for a data-array dimension; it has nothing to load lazily."""

    def __init__(self, node_descriptor, file_handler, parent=None):
        super(DimensionTreeItem, self).__init__(node_descriptor, file_handler=file_handler, parent=parent)
        # Marked loaded up front, so no placeholder row is reported.
        self._is_loaded = True


class SectionTreeItem(NixTreeItem):
    """Tree item for a metadata section.

    Its children are its subsections followed by its properties.
    """

    def __init__(self, node_descriptor, file_handler, parent=None):
        super().__init__(node_descriptor, file_handler=file_handler, parent=parent)
        self.setFlags(Qt.ItemIsSelectable)
        self._is_loaded = False

    def load_children(self):
        """(Re)build the children from ``request_metadata(section_id)``."""
        self._child_items = []
        handler = self._file_handler
        subsections, props = handler.request_metadata(self._node_descriptor.id)
        self._child_items.extend(
            SectionTreeItem(sub, handler, parent=self) for sub in subsections
        )
        self._child_items.extend(
            PropertyTreeItem(prop, handler, parent=self) for prop in props
        )
        self._is_loaded = True


class PropertyTreeItem(NixTreeItem):
    """Leaf item for a metadata property."""

    def __init__(self, node_descriptor, file_handler, parent=None):
        super(PropertyTreeItem, self).__init__(node_descriptor, file_handler=file_handler, parent=parent)
        self._is_loaded = True

    def childCount(self):
        """Properties never have child rows."""
        return 0


class TagTreeItem(NixTreeItem):
    """Tree item for a Tag or MultiTag.

    Its children are the referenced data arrays followed by the features.
    """

    def __init__(self, node_descriptor, file_handler, parent=None):
        super().__init__(node_descriptor, file_handler, parent=parent)
        self._is_loaded = False

    def load_children(self):
        """(Re)build the children from the tag's references and features."""
        self._child_items = []
        desc = self._node_descriptor
        handler = self._file_handler
        # Both requests take (block_id, tag_id, is_multi_tag).
        query = (desc.block_id, desc.id, desc.entity_type == NodeType.MultiTag)
        for ref in handler.request_references(*query):
            self._child_items.append(DataArrayTreeItem(ref, handler, parent=self))
        for feat in handler.request_features(*query):
            self._child_items.append(FeatureTreeItem(feat, handler, parent=self))
        self._is_loaded = True
        

class TreeModel(QAbstractItemModel):
    """Qt item model exposing a NIX file as a lazily loaded tree.

    Every valid index stores its ``NixTreeItem`` as the internal pointer. The
    invisible root item is represented by an invalid ``QModelIndex``.

    Args:
        file_handler: Handler for the open file. Its ``filename`` labels the
            root item, and it is passed on to every tree item.
        tree_type: ``TreeType.Full`` (metadata sections and blocks),
            ``TreeType.Metadata`` (sections only), or anything else for
            blocks only.
        parent: Optional Qt parent object.
        root_section_id: Only used for metadata trees; restricts the tree to
            the section with this id.
    """

    def __init__(self, file_handler, tree_type=TreeType.Full, parent=None, root_section_id=None):
        super(TreeModel, self).__init__(parent)
        nd = ItemDescriptor(file_handler.filename, type="Root item")
        # Icon file names, relative to cnst.ICONS_FOLDER, per entity type.
        icon_files = {
            NodeType.Block: "nix_block_1d.png",
            NodeType.Source: "nix_source.png",
            NodeType.DataArray: "nix_data_array.png",
            NodeType.Dimension: "nix_dimension.png",
            NodeType.DataFrame: "nix_data_frame.png",
            NodeType.Section: "nix_section.png",
            NodeType.Property: "nix_property.png",
            NodeType.Tag: "nix_tag.png",
            NodeType.MultiTag: "nix_tag.png",
            # Same folder as every other icon (this entry used to carry a
            # stray "icons/" prefix).
            NodeType.Group: "nix_group.png",
            NodeType.Feature: "nix_feature.png",
        }
        self.type_icons = {
            node_type: QIcon(os.path.join(cnst.ICONS_FOLDER, file_name))
            for node_type, file_name in icon_files.items()
        }

        if tree_type == TreeType.Full:
            self.root_item = FileTreeItem(nd, file_handler, parent=None)
        elif tree_type == TreeType.Metadata:
            self.root_item = MetadataTreeItem(nd, file_handler, parent=None, root_section_id=root_section_id)
        else:
            self.root_item = DataTreeItem(nd, file_handler, parent=None)
        # The top level is loaded eagerly; deeper levels load via fetchMore().
        self.root_item.load_children()

    def columnCount(self, parent=QModelIndex()):
        """Every item has the same fixed set of columns."""
        return len(column_names)

    def data(self, index, role=Qt.DisplayRole):
        """Return the display text, or the entity-type icon for column 0."""
        if not index.isValid():
            return None
        item = index.internalPointer()
        if role == Qt.DisplayRole:
            return item.data(index.column())
        if role == Qt.DecorationRole and index.column() == 0:
            return self.type_icons.get(item.entity_type)
        return None

    def canFetchMore(self, index):
        """True while the item at *index* has not loaded its children yet."""
        if not index.isValid():
            return False
        item = index.internalPointer()
        return not item._is_loaded

    def fetchMore(self, index):
        """Load the children of the item at *index*, at most once.

        NOTE(review): no beginInsertRows()/endInsertRows() notification is
        sent, so attached views are not told that the placeholder row was
        replaced by real children. Confirm this is acceptable for every view
        using this model.
        """
        if not index.isValid():
            return
        item = index.internalPointer()
        # Reloading would replace the child objects that existing indexes
        # point to, so never reload an already-loaded item.
        if item._is_loaded:
            return
        item.load_children()

    def flags(self, index):
        if not index.isValid():
            return Qt.NoItemFlags

        return Qt.ItemIsEnabled | Qt.ItemIsUserCheckable

    def headerData(self, section, orientation, role=Qt.DisplayRole):
        """Return the horizontal header label for *section*."""
        if (orientation == Qt.Horizontal and role == Qt.DisplayRole
                and 0 <= section < len(column_names)):
            return column_names[section]
        return None

    def index(self, row, column, parent_index=QModelIndex()):
        """Return the index of child *row*/*column* under *parent_index*."""
        if not self.hasIndex(row, column, parent_index):
            return QModelIndex()

        if parent_index.isValid():
            parent_item = parent_index.internalPointer()
        else:
            parent_item = self.root_item

        child_item = parent_item.child(row)
        if child_item is None:
            # Placeholder row of an item whose children are not loaded yet.
            return QModelIndex()
        return self.createIndex(row, column, child_item)

    def parent(self, child_index):
        """Return the index of the parent of *child_index*.

        Top-level items have no parent index: the root item is never exposed
        as an index, so an invalid ``QModelIndex`` is returned for them.
        """
        if not child_index.isValid():
            return QModelIndex()

        parent_item = child_index.internalPointer().parent()
        if parent_item is None or parent_item is self.root_item:
            return QModelIndex()

        return self.createIndex(parent_item.row(), 0, parent_item)

    def rowCount(self, parent_index=QModelIndex()):
        """Return the number of child rows under *parent_index*."""
        # Only column 0 has children, following the usual tree-model convention.
        if parent_index.column() > 0:
            return 0

        if parent_index.isValid():
            parent_item = parent_index.internalPointer()
        else:
            parent_item = self.root_item

        return parent_item.childCount()


class NixTreeView(QTreeView):
    """Single-selection tree view with alternating row colours.

    All columns are re-fitted to their contents whenever a branch is
    expanded or collapsed.
    """
    icon_size = QSize(30, 30)

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        # Re-fit the columns whenever the set of visible rows changes.
        for visibility_signal in (self.expanded, self.collapsed):
            visibility_signal.connect(self.columnResize)
        self.setAlternatingRowColors(True)
        # All rows share one height, which lets Qt optimise scrolling.
        self.setUniformRowHeights(True)
        self.setWindowTitle("Data Tree")
        self.setIconSize(self.icon_size)
        self.setSelectionBehavior(QAbstractItemView.SelectItems)
        self.setSelectionMode(QAbstractItemView.SingleSelection)

        hdr = self.header()
        hdr.setStretchLastSection(True)
        hdr.setFirstSectionMovable(False)
        hdr.setSectionResizeMode(QHeaderView.ResizeToContents)

    def columnResize(self, index):
        """Resize every column to fit its contents; *index* is ignored."""
        for column, _label in enumerate(column_names):
            self.resizeColumnToContents(column)