import os 
import os 

import numpy as np
import matplotlib.pyplot as plt 

from IPython import embed
from pandas import read_csv
from modules.logger import makeLogger
from scipy.ndimage import gaussian_filter1d

# Module-level logger, created by the project's makeLogger helper and named after this module.
logger = makeLogger(__name__)

class Behavior:
    """Load the behavior annotations and chirp data of one recording folder.

    Every column of the csv file found in the folder is exposed as an
    attribute holding a numpy array. Column names are lower-cased, spaces
    become underscores and parentheses are dropped, e.g. ``Start (s)``
    becomes ``start_s``.

    Attributes
    ----------
    behavior: 0: chasing onset, 1: chasing offset, 2: physical contact
    behavior_type:
    behavioral_category:
    comment_start:
    comment_stop:
    dataframe: pandas dataframe with all the data
    duration_s:
    media_file:
    observation_date:
    observation_id:
    start_s: start time of the event in seconds (after time correction)
    stop_s:  stop time of the event in seconds (after time correction)
    total_length:
    time: content of ``times.npy``
    chirps: content of ``chirps.npy``
    chirps_ids: content of ``chirps_ids.npy``

    Raises
    ------
    FileNotFoundError
        If one of the ``.npy`` files is missing or the folder contains no
        csv file.
    """

    def __init__(self, folder_path: str) -> None:
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only point this at trusted recording folders.
        LED_on_time_BORIS = np.load(os.path.join(folder_path, 'LED_on_time.npy'), allow_pickle=True)
        self.time = np.load(os.path.join(folder_path, "times.npy"), allow_pickle=True)

        # os.listdir order is arbitrary, so sort to make the choice
        # deterministic when a folder holds more than one csv file.
        csv_filenames = sorted(f for f in os.listdir(folder_path) if f.endswith('.csv'))
        if not csv_filenames:
            raise FileNotFoundError(f'No csv file found in {folder_path}')
        self.dataframe = read_csv(os.path.join(folder_path, csv_filenames[0]))

        self.chirps = np.load(os.path.join(folder_path, 'chirps.npy'), allow_pickle=True)
        self.chirps_ids = np.load(os.path.join(folder_path, 'chirps_ids.npy'), allow_pickle=True)

        # Expose every csv column as an attribute with a cleaned-up name.
        for column in self.dataframe.columns:
            key = column.lower().replace(' ', '_').replace('(', '').replace(')', '')
            setattr(self, key, np.array(self.dataframe[column]))

        # Map the annotation times onto the time axis of `self.time`.
        # `factor` is an empirical scale between the two clocks (presumably
        # because the annotated video is not exactly 25 FPS -- confirm).
        last_LED_t_BORIS = LED_on_time_BORIS[-1]
        real_time_range = self.time[-1] - self.time[0]
        factor = 1.034141
        shift = last_LED_t_BORIS - real_time_range * factor
        self.start_s = (self.start_s - shift) / factor
        self.stop_s = (self.stop_s - shift) / factor
  
"""
1 - chasing onset
2 - chasing offset
3 - physical contact event

temporal encoding needs to be corrected ... not exactly 25 FPS.

### corresponding python code ###

    factor = 1.034141
    LED_on_time_BORIS = np.load(os.path.join(folder_path, 'LED_on_time.npy'), allow_pickle=True)
    last_LED_t_BORIS = LED_on_time_BORIS[-1]
    real_time_range = times[-1] - times[0]
    shift = last_LED_t_BORIS - real_time_range * factor

    data = pd.read_csv(os.path.join(folder_path, file[1:-7] + '.csv'))
    boris_times = data['Start (s)']
    data_times = []

    for Cevent_t in boris_times:
        Cevent_boris_times = (Cevent_t - shift) / factor
        data_times.append(Cevent_boris_times)

    data_times = np.array(data_times)
    behavior = data['Behavior']
"""

def correct_chasing_events(
    category: np.ndarray, 
    timestamps: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
    """Remove chasing on-/offsets that break the onset -> offset alternation.

    Events are walked in array order (assumed to be chronological; the
    values in `timestamps` are not consulted). Category 0 is a chasing
    onset, 1 a chasing offset; every other category is left untouched.

    When the number of onsets equals the number of offsets the input is
    returned unchanged. Otherwise the following events are deleted:

    * an onset that is followed by another onset before any offset
      (the later onset is kept),
    * an offset that has no open onset before it (leading or repeated
      offset; the first offset after an onset is kept),
    * a trailing onset that is never closed by an offset.

    After the correction the remaining onsets and offsets alternate, so
    they can be paired one-to-one.

    Parameters
    ----------
    category : np.ndarray
        Event category per event.
    timestamps : np.ndarray
        Event time per event, same length as `category`.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        `category` and `timestamps` with the faulty events removed.
    """
    n_onsets = int(np.count_nonzero(category == 0))
    n_offsets = int(np.count_nonzero(category == 1))

    if n_onsets == n_offsets:
        logger.info('Chasing events are equal')
        return category, timestamps
    if n_onsets > n_offsets:
        logger.info(f'Onsets are greater than offsets by {n_onsets - n_offsets}')
    else:
        logger.info(f'Offsets are greater than onsets by {n_offsets - n_onsets}')

    # Single pass over the events, remembering the index of the onset that
    # is currently waiting for its offset (None when no chase is open).
    wrong_ids = []
    open_onset = None
    for idx, event_category in enumerate(category):
        if event_category == 0:
            if open_onset is not None:
                # two onsets in a row: the earlier one has no offset
                wrong_ids.append(open_onset)
            open_onset = idx
        elif event_category == 1:
            if open_onset is None:
                # offset without a preceding onset
                wrong_ids.append(idx)
            else:
                open_onset = None

    if open_onset is not None:
        # last onset was never closed
        wrong_ids.append(open_onset)

    category = np.delete(
        category, wrong_ids)
    timestamps = np.delete(
        timestamps, wrong_ids)
    return category, timestamps


def event_triggered_chirps(
    event: np.ndarray, 
    chirps:np.ndarray,
    time_before_event: int,
    time_after_event: int
    )-> tuple[list, np.ndarray]:
    """Collect the chirps that fall into a time window around each event.

    Parameters
    ----------
    event : np.ndarray
        Timestamps of the events.
    chirps : np.ndarray
        Timestamps of all chirps (need not be sorted).
    time_before_event : int
        Length of the window before each event, same unit as the timestamps.
    time_after_event : int
        Length of the window after each event, same unit as the timestamps.

    Returns
    -------
    event_chirps : list
        One list per event holding the chirp timestamps inside
        ``[event - time_before_event, event + time_after_event]``
        (both bounds inclusive); empty when the event has no chirps.
    centered_chirps : np.ndarray
        All windowed chirp timestamps relative to their event, concatenated
        over events. Empty array when no event has a chirp in its window.
    """
    chirps = np.asarray(chirps)

    event_chirps = []   # chirps that are in specified window around event
    centered_chirps = []    # timestamps of chirps around event centered on the event timepoint

    for event_timestamp in event:
        start = event_timestamp - time_before_event    # timepoint of window start
        stop = event_timestamp + time_after_event    # timepoint of window ending
        chirps_around_event = chirps[(chirps >= start) & (chirps <= stop)]
        event_chirps.append(list(chirps_around_event))
        if chirps_around_event.size:
            centered_chirps.append(chirps_around_event - event_timestamp)

    # np.concatenate raises on an empty list, so handle "no chirps near any
    # event" explicitly.
    if centered_chirps:
        centered = np.concatenate(centered_chirps, axis=0)
    else:
        centered = np.array([], dtype=float)

    return event_chirps, centered


def main(datapath: str) -> None:
    """Run the chirp / behavior analysis for one recording folder.

    Loads the recording with `Behavior`, cleans the chasing on-/offsets,
    plots the chirps around chasing onsets and physical contacts, counts
    the chasing events with and without chirps, and finally drops into an
    interactive IPython shell before exiting the interpreter.

    Parameters
    ----------
    datapath : str
        Folder holding the recording (npy files and the behavior csv).
    """

    # behavior is pandas dataframe with all the data
    bh = Behavior(datapath)
    
    # chirps are not sorted in time (presumably due to prior groupings)
    # get and sort chirps and corresponding fish_ids of the chirps
    sort_order = np.argsort(bh.chirps)
    chirps = bh.chirps[sort_order]
    chirps_fish_ids = bh.chirps_ids[sort_order]
    category = bh.behavior
    timestamps = bh.start_s

    # Correct for doubles in chasing on- and offsets to get the right on-/offset pairs
    # Get rid of tracking faults (two onsets or two offsets after another)
    category, timestamps = correct_chasing_events(category, timestamps)

    # split categories
    chasing_onset = timestamps[category == 0]
    chasing_offset = timestamps[category == 1]
    physical_contact = timestamps[category == 2]

    # First overview plot
    fig1, ax1 = plt.subplots()
    ax1.scatter(chirps, np.ones_like(chirps), marker='*', color='royalblue', label='Chirps')
    ax1.scatter(chasing_onset, np.ones_like(chasing_onset)*2, marker='.', color='forestgreen', label='Chasing onset')
    ax1.scatter(chasing_offset, np.ones_like(chasing_offset)*2.5, marker='.', color='firebrick', label='Chasing offset')
    ax1.scatter(physical_contact, np.ones_like(physical_contact)*3, marker='x', color='black', label='Physical contact')
    plt.legend()
    # plt.show()
    plt.close()

    # Get fish ids
    fish_ids = np.unique(chirps_fish_ids)

    ##### Chasing triggered chirps CTC #####
    # Evaluate how many chirps were emitted in specific time window around the chasing onset events

    # Iterate over chasing onsets (later over fish)
    time_around_event = 5    # time window around the event in which chirps are counted, 5 = -5 to +5 sec around event
    #### Loop crashes at concatenate in function ####
    # for i in range(len(fish_ids)):
    #     fish = fish_ids[i]
    #     chirps = chirps[chirps_fish_ids == fish]
    #     print(fish)

    chasing_chirps, centered_chasing_chirps = event_triggered_chirps(chasing_onset, chirps, time_around_event, time_around_event)
    physical_chirps, centered_physical_chirps = event_triggered_chirps(physical_contact, chirps, time_around_event, time_around_event)

    # Kernel density estimation ???
    # centered_chasing_chirps_convolved = gaussian_filter1d(centered_chasing_chirps, 5)
    
    # centered_chasing = chasing_onset[0] - chasing_onset[0]   ## get the 0 timepoint for plotting; set one chasing event to 0
    offsets = [0.5, 1]
    fig4, ax4 = plt.subplots(figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True)
    # Pass a plain list: the two arrays usually differ in length, and wrapping
    # them in np.array would build a ragged array (an error on NumPy >= 1.24).
    ax4.eventplot([centered_chasing_chirps, centered_physical_chirps], lineoffsets=offsets, linelengths=0.25, colors=['g', 'r'])
    ax4.vlines(0, 0, 1.5, colors='tab:grey', linestyles='dashed', label='Timepoint of event')
    # ax4.plot(centered_chasing_chirps_convolved)
    ax4.set_yticks(offsets)
    ax4.set_yticklabels(['Chasings', 'Physical \n contacts'])
    ax4.set_xlabel('Time[s]')
    ax4.set_ylabel('Type of event')
    plt.show()

    # Associate chirps to inidividual fish
    # NOTE(review): assumes at least two distinct fish ids in the recording;
    # fish_ids[1] raises IndexError otherwise.
    fish1 = chirps[chirps_fish_ids == fish_ids[0]]
    fish2 = chirps[chirps_fish_ids == fish_ids[1]]
    fish = [len(fish1), len(fish2)]

    ### Plots:
    # 1. All recordings, all fish, all chirps
        # One CTC, one PTC
    # 2. All recordings, only winners
        # One CTC, one PTC
    # 3. All recordings, all losers
        # One CTC, one PTC

    #### Chirp counts per fish general #####
    fig2, ax2 = plt.subplots()
    x = ['Fish1', 'Fish2']
    width = 0.35
    ax2.bar(x, fish, width=width)
    ax2.set_ylabel('Chirp count')
    # plt.show()
    plt.close()

 
    ##### Count chirps emitted during chasing events and chirps emitted out of chasing events #####
    # zip pairs the i-th onset with the i-th offset and stops at the shorter array
    chirps_in_chasings = []
    for onset, offset in zip(chasing_onset, chasing_offset):
        chirps_in_chasing = [c for c in chirps if (c > onset) & (c < offset)]
        chirps_in_chasings.append(chirps_in_chasing)

    # Count chasing events that contain at least one chirp versus those that
    # contain none. (The two counters used to be incremented the wrong way
    # round.)
    # NOTE(review): counts_chirps_chasings counts chasing events, not single
    # chirps -- confirm this is what the bar label below should show.
    counts_chirps_chasings = 0
    chasings_without_chirps = 0
    for i in chirps_in_chasings:
        if i:
            counts_chirps_chasings += 1
        else:
            chasings_without_chirps += 1

    # chirps in chasing events
    fig3 , ax3 = plt.subplots()
    ax3.bar(['Chirps in chasing events',  'Chasing events without Chirps'], [counts_chirps_chasings, chasings_without_chirps], width=width)
    plt.ylabel('Count')
    # plt.show()
    plt.close()  

    # comparison between chasing events with and without chirps


    
    # NOTE(review): debugging leftovers -- embed() opens an interactive
    # IPython shell with the local variables above, exit() then terminates
    # the interpreter, so main() never returns to its caller.
    embed()
    exit()



if __name__ == '__main__':
    # Recording folder to analyse (relative to the working directory)
    datapath = '../data/mount_data/2020-05-13-10_00/'
    main(datapath)