from PyQt5.QtCore import QAbstractItemModel, QFile, QModelIndex, Qt
from PyQt5.QtWidgets import QHeaderView, QTreeWidgetItem
from collections import OrderedDict
from enum import Enum

from IPython import embed

# Header labels of the tree columns. The position of each label matches the
# column index handled by NixTreeItem.data().
column_names = [
    'Name',
    'Type',
    'ID',
    'Value',
    'Description',
]

class NodeType(Enum):
    """Kinds of node that can appear in the NIX file tree.

    Each member's value is a lowercase string identifier.
    """

    Root = "root"
    Section = "section"
    Block = "block"
    DataArray = "data_array"
    Property = "property"
    Dimension = "dimension"
    Tag = "tag"


class NixTreeItem:
    """Base node of the NIX tree displayed by ``TreeModel``.

    A node stores the values shown in the tree columns, a link to its parent
    node and the list of child nodes loaded so far. Subclasses fill
    ``_child_items`` on demand and set ``_is_loaded`` once they have done so.

    Args:
        name: Value of the "Name" column.
        item_type: Value of the "Type" column.
        description: Value of the "Description" column.
        file_handler: Object that subclasses query for child entities.
        id: Value of the "ID" column.
        node_type: Kind of node (subclasses pass a ``NodeType`` member).
        value: Value of the "Value" column.
        parent: Parent node, or ``None`` for the root.
    """

    def __init__(self, name, item_type, description, file_handler, id=None, node_type=None, value=None, parent=None):
        # Values shown in the display columns.
        self._name = name
        self._item_type = item_type
        self._description = description
        self._id = id
        self._value = value
        # Tree bookkeeping.
        self._node_type = node_type
        self._file_handler = file_handler
        self._parent_item = parent
        self._child_items = []
        self._is_loaded = False

    def child(self, row):
        """Return the loaded child at position *row*, or ``None`` if there is none."""
        children = self._child_items
        return children[row] if row < len(children) else None

    def childCount(self):
        """Return the number of rows to report for this node.

        While the children are not loaded yet, at least one placeholder row is
        reported so that a view can offer to expand the node.
        """
        loaded = len(self._child_items)
        if self._is_loaded:
            return loaded
        return max(1, loaded)

    def columnCount(self):
        """Return the number of display columns."""
        return len(column_names)

    def data(self, column):
        """Return the value shown in *column*, or ``None`` for an unknown column."""
        fields = (self._name, self._item_type, self._id, self._value, self._description)
        for position, field_value in enumerate(fields):
            if column == position:
                return field_value
        return None

    def parent(self):
        """Return the parent node, or ``None`` for the root."""
        return self._parent_item

    def row(self):
        """Return this node's position among its parent's children (0 for the root)."""
        owner = self._parent_item
        if not owner:
            return 0
        return owner._child_items.index(self)


class FileTreeItem(NixTreeItem):
    """Root node of the tree, standing for the opened file.

    Its children are the top-level metadata sections followed by the blocks,
    both requested from the file handler.
    """

    def __init__(self, name, item_type, description, file_handler, id=None, node_type=NodeType.Root, value=None, parent=None):
        super().__init__(
            name,
            item_type,
            description,
            file_handler=file_handler,
            id=id,
            node_type=node_type,
            value=value,
            parent=parent,
        )
        self._is_loaded = False

    def load_children(self):
        """Replace the children with the top-level sections and the blocks."""
        handler = self._file_handler
        self._child_items = []
        # request_metadata() without an id yields (sections, properties);
        # only the sections are shown at file level.
        top_sections, _ = handler.request_metadata()
        self._child_items.extend(
            SectionTreeItem(sec["name"], sec["item_type"], sec["description"], handler, id=sec["id"], parent=self)
            for sec in top_sections
        )
        self._child_items.extend(
            BlockTreeItem(blk["name"], blk["item_type"], blk["description"], handler, id=blk["id"], parent=self)
            for blk in handler.request_blocks()
        )
        self._is_loaded = True


class BlockTreeItem(NixTreeItem):
    """Tree node for a block; its children are its data arrays followed by its tags."""

    def __init__(self, name, item_type, description, file_handler, id=None, node_type=NodeType.Block, value=None, parent=None):
        super().__init__(
            name,
            item_type,
            description,
            file_handler=file_handler,
            id=id,
            node_type=node_type,
            value=value,
            parent=parent,
        )
        self._is_loaded = False

    def load_children(self):
        """Replace the children with this block's data arrays and tags."""
        handler = self._file_handler
        block_id = self._id
        self._child_items = []
        self._child_items.extend(
            DataArrayTreeItem(arr["name"], arr["item_type"], arr["description"], handler, id=arr["id"], parent=self)
            for arr in handler.request_data_arrays(block_id)
        )
        self._child_items.extend(
            TagTreeItem(tag["name"], tag["item_type"], tag["description"], handler, id=tag["id"], parent=self)
            for tag in handler.request_tags(block_id)
        )
        self._is_loaded = True


class DataArrayTreeItem(NixTreeItem):
    """Tree node for a data array; its children are the array's dimensions."""

    def __init__(self, name, item_type, description, file_handler, id, node_type=NodeType.DataArray, value=None, parent=None):
        super().__init__(name, item_type, description, file_handler, id=id, node_type=node_type, value=value, parent=parent)
        self._is_loaded = False

    def load_children(self):
        """Replace the children with the dimensions of this data array.

        The file handler is queried with the parent node's id and this
        array's id. The parent is expected to be the block node that created
        this item (``BlockTreeItem`` passes ``parent=self``).
        """
        # Reset first, like the other loaders, so that a repeated call does
        # not append a second copy of every dimension.
        self._child_items = []
        block_id = self._parent_item._id
        dimensions = self._file_handler.request_dimensions(block_id, self._id)
        for d in dimensions:
            self._child_items.append(DimensionTreeItem(d["name"], d["item_type"], d["description"], self._file_handler, id=d["id"], parent=self))
        self._is_loaded = True

class DimensionTreeItem(NixTreeItem):
    """Leaf node for one dimension of a data array.

    It is created already marked as loaded and without children, so it
    reports no child rows.
    """

    def __init__(self, name, item_type, description, file_handler, id, node_type=NodeType.Dimension, value=None, parent=None):
        super().__init__(
            name,
            item_type,
            description,
            file_handler=file_handler,
            id=id,
            node_type=node_type,
            value=value,
            parent=parent,
        )
        # Nothing to fetch for a dimension.
        self._is_loaded = True

class SectionTreeItem(NixTreeItem):
    """Tree node for a metadata section.

    Its children are the section's subsections followed by its properties.
    """

    def __init__(self, name, item_type, description, file_handler, id=None, node_type=NodeType.Section, value=None, parent=None):
        super().__init__(name, item_type, description, file_handler=file_handler, id=id, node_type=node_type, value=value, parent=parent)
        self._is_loaded = False

    def load_children(self):
        """Replace the children with this section's subsections and properties.

        Properties are displayed with an empty type, their unit in the
        description column and the placeholder value ``"unset"``.
        """
        # No debug print here: writing to stdout on every expand is noise.
        self._child_items = []
        sections, properties = self._file_handler.request_metadata(self._id)
        for s in sections:
            self._child_items.append(SectionTreeItem(s["name"], s["item_type"], s["description"], self._file_handler, id=s["id"], parent=self))
        for p in properties:
            self._child_items.append(PropertyTreeItem(p["name"], "", p["unit"], self._file_handler, id=p["id"], value="unset", parent=self))
        self._is_loaded = True

class PropertyTreeItem(NixTreeItem):
    """Leaf node for a metadata property; it never reports child rows."""

    def __init__(self, name, item_type, description, file_handler, id=None, node_type=NodeType.Property, value=None, parent=None):
        super().__init__(
            name,
            item_type,
            description,
            file_handler=file_handler,
            id=id,
            node_type=node_type,
            value=value,
            parent=parent,
        )
        self._is_loaded = True

    def childCount(self):
        """Report zero rows: properties are leaves."""
        return 0

class TagTreeItem(NixTreeItem):
    """Tree node for a tag.

    TODO: no children are loaded for tags yet; ``load_children`` only marks
    the node as loaded, after which it reports no child rows.
    """

    def __init__(self, name, item_type, description, file_handler, id, node_type=NodeType.Tag, value=None, parent=None):
        super().__init__(
            name,
            item_type,
            description,
            file_handler=file_handler,
            id=id,
            node_type=node_type,
            value=value,
            parent=parent,
        )
        self._is_loaded = False

    def load_children(self):
        """Mark the tag as loaded without adding any children."""
        self._is_loaded = True
        
class TreeModel(QAbstractItemModel):
    """Read-only Qt item model exposing a NIX file as a lazily loaded tree.

    The invisible root is a ``FileTreeItem`` whose children (top-level
    sections and blocks) are loaded in the constructor. Deeper levels are
    loaded on demand through ``canFetchMore``/``fetchMore``.

    NOTE(review): until a node is loaded, its ``childCount()`` reports one
    placeholder row, for which ``index()`` returns an invalid index.
    ``fetchMore`` changes the row count without
    ``beginInsertRows``/``endInsertRows`` notifications. Confirm that the
    attached views (and any proxy models) cope with this.

    Args:
        file_handler: Object providing ``filename`` and the ``request_*``
            methods used by the tree items.
        parent: Optional parent ``QObject``.
    """

    def __init__(self, file_handler, parent=None):
        super().__init__(parent)
        self.root_item = FileTreeItem(file_handler.filename, "Root Item", "", file_handler, parent=None)
        self.root_item.load_children()

    def columnCount(self, parent=QModelIndex()):
        """Return the number of columns; it is the same for every parent."""
        return len(column_names)

    def data(self, index, role=Qt.DisplayRole):
        """Return the display value for *index*; other roles yield ``None``."""
        if not index.isValid() or role != Qt.DisplayRole:
            return None
        return index.internalPointer().data(index.column())

    def canFetchMore(self, index):
        """Return ``True`` if the item at *index* has not loaded its children yet."""
        if not index.isValid():
            return False
        return not index.internalPointer()._is_loaded

    def fetchMore(self, index):
        """Load the children of the item at *index*.

        Does nothing for an invalid index or an already loaded item.
        Reloading would rebuild the child list and replace item objects that
        existing indexes may still point to.
        """
        if not self.canFetchMore(index):
            return
        index.internalPointer().load_children()

    def flags(self, index):
        """Return the item flags: enabled and user-checkable for valid indexes."""
        if not index.isValid():
            return Qt.NoItemFlags
        return Qt.ItemIsEnabled | Qt.ItemIsUserCheckable

    def headerData(self, section, orientation, role=Qt.DisplayRole):
        """Return the horizontal header label for *section*, else ``None``."""
        if (orientation == Qt.Horizontal and role == Qt.DisplayRole
                and 0 <= section < len(column_names)):
            return column_names[section]
        return None

    def index(self, row, column, parent_index=QModelIndex()):
        """Return the index of the child at (*row*, *column*) under *parent_index*.

        An invalid index is returned for out-of-range positions and for the
        placeholder row of a not-yet-loaded item.
        """
        if not self.hasIndex(row, column, parent_index):
            return QModelIndex()

        if parent_index.isValid():
            parent_item = parent_index.internalPointer()
        else:
            parent_item = self.root_item

        child_item = parent_item.child(row)
        if child_item is None:
            return QModelIndex()
        return self.createIndex(row, column, child_item)

    def parent(self, child_index=None):
        """Return the parent index of *child_index*.

        Top-level items (children of the invisible root) get an invalid
        index, as Qt requires. Called without an argument, this returns the
        model's parent ``QObject``, because overriding the ``QModelIndex``
        overload hides the ``QObject.parent()`` overload in Python.
        """
        if child_index is None:
            return super().parent()
        if not child_index.isValid():
            return QModelIndex()

        parent_item = child_index.internalPointer().parent()
        # The root item is never exposed through an index, so its children
        # must report an invalid parent rather than an index for the root.
        if parent_item is None or parent_item is self.root_item:
            return QModelIndex()

        return self.createIndex(parent_item.row(), 0, parent_item)

    def rowCount(self, parent_index=QModelIndex()):
        """Return the number of rows under *parent_index*.

        Only column 0 has children. Items that are not loaded yet report a
        placeholder row.
        """
        if parent_index.column() > 0:
            return 0

        if parent_index.isValid():
            parent_item = parent_index.internalPointer()
        else:
            parent_item = self.root_item

        return parent_item.childCount()