import os
import pathlib
import shutil
import sys

# PyQt5 is imported before pyqtgraph so that pyqtgraph binds to PyQt5.
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
import pyqtgraph as pg

from PIL import Image, ImageOps
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as F

from IPython import embed

# Text colours of the file-list labels, indexed by the 'checked' column of the
# file dict: 0 (not checked yet) -> 'red', 1 (checked/locked) -> 'black'.
NONE_CHECKED_CHECKER_COLORS = ['red', 'black']

class ImageWithBbox(pg.GraphicsLayoutWidget):
    """Show one image with its bounding boxes as editable, removable ROIs.

    Boxes are read from ``<image dir>/../labels/<image stem>.txt``, one row
    ``class x_center y_center width height`` per box, with coordinates
    normalized to the image size. A missing label file is treated as "no
    boxes". Double-clicking on the plot adds a new box.

    Attributes:
        img_path: path of the displayed image (as passed in).
        label_path: ``pathlib.Path`` of the matching label file.
        img_array: transposed grayscale image; axis 0 is the image width (x),
            axis 1 the image height (y).
        labels: 2-D array of the label rows read from ``label_path``.
        ROIs: the ``pg.RectROI`` items currently shown, in creation order.
    """

    def __init__(self, img_path, parent=None):
        super(ImageWithBbox, self).__init__(parent)
        self.img_path = img_path
        img_file = pathlib.Path(img_path)
        self.label_path = (img_file.parent.parent / 'labels' / img_file.name).with_suffix('.txt')

        # image setup
        # NOTE: the image is rendered in grayscale. The previous ``ColorMap=`` /
        # ``colorMap=`` keyword arguments were not recognised options of
        # ImageItem / PlotItem.addItem and had no effect.
        self.imgItem = pg.ImageItem()
        self.plot_widget = self.addPlot(title="")
        self.plot_widget.addItem(self.imgItem)
        self.plot_widget.scene().sigMouseClicked.connect(self.addROI)

        view_box = self.imgItem.getViewBox()
        view_box.invertY(True)  # y grows downwards, like image row indices
        view_box.setMouseEnabled(x=False, y=False)

        # Load the image as grayscale. Transposing puts the image width on
        # axis 0, which pyqtgraph's default (column-major) image axis order
        # draws along x.
        with Image.open(img_path) as img:
            img_gray = ImageOps.grayscale(img)
        self.img_array = np.array(img_gray).T
        self.imgItem.setImage(self.img_array)

        img_width, img_height = self.img_array.shape
        # x spans the image width (axis 0), y the image height (axis 1)
        self.plot_widget.setXRange(0, img_width, padding=0)
        self.plot_widget.setYRange(0, img_height, padding=0)

        self.labels = self._load_labels(self.label_path)

        # add ROIs
        self.ROIs = []
        for label in self.labels:
            # label = (class, x_center, y_center, width, height), normalized.
            # Only the geometry is kept on the ROI; the class id is not.
            x_center, width = label[1] * img_width, label[3] * img_width
            y_center, height = label[2] * img_height, label[4] * img_height
            self._add_rect_roi((x_center - width / 2, y_center - height / 2), (width, height))

    @staticmethod
    def _load_labels(label_path):
        """Return the label rows of *label_path* as a 2-D array (zero rows if the file is missing)."""
        if not pathlib.Path(label_path).exists():
            return np.zeros((0, 5))
        # ndmin=2 keeps a single-row file 2-D, so iteration yields rows
        return np.loadtxt(label_path, delimiter=' ', ndmin=2)

    def _add_rect_roi(self, pos, size):
        """Create a removable RectROI at *pos* with *size*, register it in ROIs and show it."""
        roi = pg.RectROI(pos, size=size, removable=True, sideScalers=True)
        roi.sigRemoveRequested.connect(self.removeROI)
        self.ROIs.append(roi)
        self.plot_widget.addItem(roi)
        return roi

    def removeROI(self, roi):
        """Forget *roi* and take it off the plot; unknown ROIs are ignored."""
        if roi in self.ROIs:
            self.ROIs.remove(roi)
            self.plot_widget.removeItem(roi)

    def addROI(self, event):
        """On a double-click, add a box whose corner sits at the clicked point.

        The new box measures 5% of the image width by 5% of the image height.
        Single clicks are ignored.
        """
        if not event.double():
            return
        # The scene signal fires after items have handled the click, so
        # event.pos() is relative to whichever item handled it last. The scene
        # position is unambiguous, so map it straight into data coordinates.
        pos_data = self.plot_widget.getViewBox().mapSceneToView(event.scenePos())
        img_width, img_height = self.img_array.shape
        self._add_rect_roi((pos_data.x(), pos_data.y()), (img_width * 0.05, img_height * 0.05))

class Bbox_correct_UI(QMainWindow):
    """Main window for reviewing and correcting the bounding boxes of an image folder.

    Every ``*images/*.png`` file below ``data_path`` is listed on the right.
    Clicking an entry shows that image with its boxes on the left. Space marks
    the shown file as checked and advances to the next one. The per-file state
    lives in ``self.file_dict``, a DataFrame with columns ``name``, ``files`` and
    ``checked``, which is saved to ``<data_path>/file_dict.csv``.
    """

    def __init__(self, data_path, parent=None):
        super(Bbox_correct_UI, self).__init__(parent)

        self.data_path = data_path
        self.files = sorted(pathlib.Path(self.data_path).absolute().rglob('*images/*.png'))

        # Set by export_validated_data: the shown files may have been moved
        # away, so closing must not write their labels back to the old place.
        self.close_without_saving_rois = False

        screen_rect = QApplication.primaryScreen().geometry()
        self.resize(int(.8 * screen_rect.width()), int(.8 * screen_rect.height()))

        # widget and layout
        self.central_widget = QWidget(self)
        self.central_Layout = QHBoxLayout()
        self.central_widget.setLayout(self.central_Layout)
        self.setCentralWidget(self.central_widget)

        self.highlighted_label = None
        self.all_labels = []  # one QLabel per file_dict row, same (positional) order
        self.load_or_create_file_dict()
        if self.file_dict.empty:
            raise FileNotFoundError(f'no *images/*.png files found below {self.data_path}')

        first_name, first_file, _ = self._file_entry(0)
        self._update_window_title(first_name)

        # image widget
        self.current_img = ImageWithBbox(first_file, parent=self)
        self.central_Layout.addWidget(self.current_img, 4)

        # image select widget: one clickable label per file inside a scroll area
        self.scroll = QScrollArea()
        self.widget = QWidget()  # holds the vertical list of labels
        self.vbox = QVBoxLayout()

        for i, (name, checked) in enumerate(zip(self.file_dict['name'], self.file_dict['checked'])):
            label = QLabel(f'{name}')
            label.setAlignment(Qt.AlignRight)
            # label is bound as a default argument so each lambda keeps its own label
            label.mousePressEvent = lambda event, label=label: self.label_clicked(label)
            self._set_label_style(label, checked, highlighted=(i == 0))
            if i == 0:
                self.highlighted_label = label
            self.vbox.addWidget(label)
            self.all_labels.append(label)

        self.widget.setLayout(self.vbox)

        self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.scroll.setWidgetResizable(True)
        self.scroll.setWidget(self.widget)

        self.central_Layout.addWidget(self.scroll, 1)

        self.add_actions()
        self.init_MenuBar()

    def _file_entry(self, idx):
        """Return (name, file, checked) of the file_dict row at position *idx*."""
        return (self.file_dict['name'].iloc[idx],
                self.file_dict['files'].iloc[idx],
                self.file_dict['checked'].iloc[idx])

    def _update_window_title(self, name):
        """Show *name* and the checked/total file count in the title bar."""
        self.setWindowTitle(f'efishSignalTracker | {name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}')

    @staticmethod
    def _set_label_style(label, checked, highlighted=False):
        """Style a file-list label: thick black border if highlighted, text colour by checked state."""
        border = "2px solid black" if highlighted else "1px solid gray"
        label.setStyleSheet(f"border: {border}; color: {NONE_CHECKED_CHECKER_COLORS[checked]};")

    def load_or_create_file_dict(self):
        """Set ``self.file_dict`` from an existing file_dict CSV, or build a fresh one.

        ``<data_path>/file_dict.csv`` (where save_file_dict writes) is preferred.
        Otherwise the first ``*file_dict.csv`` found below data_path is used. If
        there is none, every file in ``self.files`` starts unchecked.
        """
        root = pathlib.Path(self.data_path).absolute()
        csv_path = root / 'file_dict.csv'
        if not csv_path.exists():
            found = sorted(root.rglob('*file_dict.csv'))
            csv_path = found[0] if found else None

        if csv_path is None:
            self.file_dict = pd.DataFrame(
                {'name': [f.name for f in self.files],
                 'files': self.files,
                 'checked': np.zeros(len(self.files), dtype=int)}  # change this to locked
            )
        else:
            self.file_dict = pd.read_csv(csv_path, sep=',')

    def save_file_dict(self):
        """Write file_dict to ``<data_path>/file_dict.csv`` (without the index)."""
        self.file_dict.to_csv(pathlib.Path(self.data_path) / 'file_dict.csv', sep=',', index=False)

    def label_clicked(self, clicked_label):
        """Highlight *clicked_label* and show its file (the current boxes are saved first)."""
        if self.highlighted_label is not None:
            old_idx = self.all_labels.index(self.highlighted_label)
            self._set_label_style(self.highlighted_label, self.file_dict['checked'].iloc[old_idx])

        new_idx = self.all_labels.index(clicked_label)
        new_name, new_file, new_checked = self._file_entry(new_idx)
        self._set_label_style(clicked_label, new_checked, highlighted=True)
        self.highlighted_label = clicked_label

        self.switch_to_new_file(new_file, new_name)

    def lock_file(self):
        """Mark the highlighted file as checked and move on to the next file (wrapping around)."""
        idx = self.all_labels.index(self.highlighted_label)
        self.file_dict.at[self.file_dict.index[idx], 'checked'] = 1
        self._set_label_style(self.highlighted_label, 1)

        new_idx = (idx + 1) % len(self.file_dict)
        new_name, new_file, new_checked = self._file_entry(new_idx)
        self.highlighted_label = self.all_labels[new_idx]
        self._set_label_style(self.highlighted_label, new_checked, highlighted=True)

        self.switch_to_new_file(new_file, new_name)

    def switch_to_new_file(self, new_file, new_name):
        """Save the boxes of the shown image, then replace the image widget with *new_file*."""
        self.readout_rois()
        self._update_window_title(new_name)

        # Build the new widget first, so a failing load leaves the old one intact.
        new_img = ImageWithBbox(new_file, parent=self)
        old_img = self.current_img
        self.central_Layout.removeWidget(old_img)
        # removeWidget only detaches the widget; delete it so replaced image
        # widgets (and their scenes) do not accumulate.
        old_img.deleteLater()
        self.central_Layout.insertWidget(0, new_img, 4)
        self.current_img = new_img

    def readout_rois(self):
        """Write the boxes of the shown image to its label file.

        Each ROI becomes a row ``1 x_center y_center width height``, normalized
        by the image width/height. Does nothing once close_without_saving_rois
        is set.
        """
        if self.close_without_saving_rois:
            return
        img_width, img_height = self.current_img.img_array.shape
        new_labels = []
        for roi in self.current_img.ROIs:
            x0, y0 = roi.pos()
            w, h = roi.size()
            w /= img_width
            h /= img_height
            # NOTE(review): the class id is always written as 1; class ids read
            # from the label file are not preserved - confirm this is intended.
            new_labels.append([1, x0 / img_width + w / 2, y0 / img_height + h / 2, w, h])
        label_path = pathlib.Path(self.current_img.label_path)
        label_path.parent.mkdir(parents=True, exist_ok=True)
        np.savetxt(label_path, np.array(new_labels))

    def export_validated_data(self):
        """Move all checked images and their label files into a user-selected directory.

        Files go to ``<export>/images`` and ``<export>/labels``. Files whose image
        name already exists in the target are left in place and stay listed.
        Moved rows are dropped from file_dict, which is saved, and then the
        window closes. Cancelling the dialog does nothing.
        """
        export_path = QFileDialog.getExistingDirectory(self, 'Select Directory')
        if not export_path:
            return  # dialog cancelled

        # Save edits of the shown image before its files may be moved away,
        # then stop further label writes to the (possibly stale) old location.
        self.readout_rois()
        self.close_without_saving_rois = True

        target_image_dir = pathlib.Path(export_path) / 'images'
        target_label_dir = pathlib.Path(export_path) / 'labels'
        target_image_dir.mkdir(parents=True, exist_ok=True)
        target_label_dir.mkdir(parents=True, exist_ok=True)

        checked_files = self.file_dict.loc[self.file_dict['checked'] == 1, 'files']
        drop_idxs = []
        for df_idx, file_path in checked_files.items():
            image_path = pathlib.Path(file_path)
            label_path = (image_path.parent.parent / 'labels' / image_path.name).with_suffix('.txt')
            target_image_path = target_image_dir / image_path.name
            if target_image_path.exists():
                continue  # never overwrite an already exported file
            # shutil.move also works across file systems, unlike os.rename
            shutil.move(str(image_path), str(target_image_path))
            shutil.move(str(label_path), str(target_label_dir / label_path.name))
            drop_idxs.append(df_idx)

        self.file_dict = self.file_dict.drop(drop_idxs)
        self.save_file_dict()
        self.close()

    def open(self):
        """Placeholder for the File > Open action (not implemented yet)."""
        pass

    def init_MenuBar(self):
        """Build the menu bar: File (Open, Export, Exit) and an empty Help menu."""
        menubar = self.menuBar()
        file_menu = menubar.addMenu('&File')  # reachable with Alt+F
        file_menu.addActions([self.Act_open, self.Act_export, self.Act_exit])
        menubar.addMenu('&Help')

    def add_actions(self):
        """Create the window actions: lock file (Space), Open, Export and Exit (Q)."""
        self.lock_file_act = QAction('loc', self)
        self.lock_file_act.triggered.connect(self.lock_file)
        self.lock_file_act.setShortcut(Qt.Key_Space)
        self.addAction(self.lock_file_act)  # window-wide shortcut, not in any menu

        self.Act_open = QAction('&Open', self)
        self.Act_open.triggered.connect(self.open)  # ToDo: implement this fn

        self.Act_export = QAction('&Export', self)
        self.Act_export.triggered.connect(self.export_validated_data)

        self.Act_exit = QAction('&Exit', self)  # mnemonic E inside the File menu
        self.Act_exit.setShortcut(Qt.Key_Q)
        self.Act_exit.triggered.connect(self.close)

    def closeEvent(self, event):
        """On close, save the shown image's boxes (unless disabled) and file_dict.csv."""
        super(Bbox_correct_UI, self).closeEvent(event)
        self.readout_rois()
        self.save_file_dict()

def main_UI():
    """Start the bounding-box correction GUI for the data directory given as first CLI argument.

    Exits with a usage message (status 1) when no data directory is given.
    """
    # check arguments before creating any Qt object
    if len(sys.argv) < 2:
        sys.exit(f'usage: {pathlib.Path(sys.argv[0]).name} <data_path>')
    data_path = sys.argv[1]
    app = QApplication(sys.argv)  # create application
    w = Bbox_correct_UI(data_path)  # create window
    w.show()
    sys.exit(app.exec_())

# Launch the GUI only when run as a script, not when imported.
if __name__ == '__main__':
    main_UI()