From 6fd3323c6cd2304a968e3d310ef0898e029a3693 Mon Sep 17 00:00:00 2001
From: sprause <sprause95@gmail.com>
Date: Wed, 25 Jan 2023 09:11:51 +0100
Subject: [PATCH] inserted two functions from CTCPTC stuff

---
 code/modules/behaviour_handling.py | 46 ++++++++++++++++++++++++++----
 1 file changed, 41 insertions(+), 5 deletions(-)

diff --git a/code/modules/behaviour_handling.py b/code/modules/behaviour_handling.py
index 90a18ab..6641309 100644
--- a/code/modules/behaviour_handling.py
+++ b/code/modules/behaviour_handling.py
@@ -8,6 +8,7 @@ from IPython import embed
 
 from pandas import read_csv
 from modules.logger import makeLogger
+from modules.datahandling import causal_kde1d, acausal_kde1d
 
 
 logger = makeLogger(__name__)
@@ -76,14 +77,14 @@ def correct_chasing_events(
     offset_ids = np.arange(
         len(category))[category == 1]
 
-    woring_bh = np.arange(len(category))[category!=2][:-1][np.diff(category[category!=2])==0]
+    wrong_bh = np.arange(len(category))[category!=2][:-1][np.diff(category[category!=2])==0]
     if onset_ids[0] > offset_ids[0]:
         offset_ids = np.delete(offset_ids, 0)
         help_index = offset_ids[0]
-        woring_bh = np.append(woring_bh, help_index)
+        wrong_bh = np.append(wrong_bh[help_index])
 
-    category = np.delete(category, woring_bh)
-    timestamps = np.delete(timestamps, woring_bh)
+    category = np.delete(category, wrong_bh)
+    timestamps = np.delete(timestamps, wrong_bh)
 
     # Check whether on- or offset is longer and calculate length difference
     if len(onset_ids) > len(offset_ids):
@@ -94,6 +95,41 @@ def correct_chasing_events(
         logger.info(f'Offsets are greater than onsets by {len_diff}')
     elif len(onset_ids) == len(offset_ids):
         logger.info('Chasing events are equal')
+    
+    return category, timestamps
 
 
-    return category, timestamps
\ No newline at end of file
+def event_triggered_chirps(
+    event: np.ndarray, 
+    chirps:np.ndarray,
+    time_before_event: int,
+    time_after_event: int,
+    dt: float,
+    width: float,
+    )-> tuple[np.ndarray, np.ndarray, np.ndarray]:
+
+    event_chirps = []   # chirps that are in specified window around event
+    centered_chirps = []    # timestamps of chirps around event centered on the event timepoint
+
+    for event_timestamp in event:
+        start = event_timestamp - time_before_event
+        stop = event_timestamp + time_after_event
+        chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)]
+        event_chirps.append(chirps_around_event)
+        if len(chirps_around_event) == 0:
+            continue
+        else: 
+            centered_chirps.append(chirps_around_event - event_timestamp)
+    
+    time = np.arange(-time_before_event, time_after_event, dt)
+    
+    # Kernel density estimation with some if's
+    if len(centered_chirps) == 0:
+        centered_chirps = np.array([])
+        centered_chirps_convolved = np.zeros(len(time))
+    else:
+        centered_chirps = np.concatenate(centered_chirps, axis=0)   # convert list of arrays to one array for plotting
+        centered_chirps_convolved = (acausal_kde1d(centered_chirps, time, width)) / len(event)
+
+    return event_chirps, centered_chirps, centered_chirps_convolved
+