From 92275cfab9a1dbf66f9668068c81de0c15cbd4aa Mon Sep 17 00:00:00 2001
From: tillraab <till.raab@student.uni-tuebingen.de>
Date: Mon, 11 Jul 2022 16:29:58 +0200
Subject: [PATCH] files to analyse competition experiments

---
 LED_detect.py     |  8 ++---
 eval_LED.py       |  2 +-
 event_videos.py   | 92 +++++++++++++++++++++++++++++++++++++----------
 trail_analysis.py | 67 +++++++++++++++++++++++++---------
 4 files changed, 129 insertions(+), 40 deletions(-)

diff --git a/LED_detect.py b/LED_detect.py
index 9d16ed1..a8f34db 100644
--- a/LED_detect.py
+++ b/LED_detect.py
@@ -72,10 +72,10 @@ def main(file_path, check, x, y):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Detect frames of blinking LED in video recordings.')
     parser.add_argument('file', type=str, help='video file to be analyzed')
-    parser.add_argument("-c", action="store_true", help="check if LED pos is correct")
-    parser.add_argument('-x', type=int, nargs=2, default=[1272, 1282], help='x-borders of LED detect area (in pixels)')
-    parser.add_argument('-y', type=int, nargs=2, default=[1500, 1516], help='y-borders of LED area (in pixels)')
+    parser.add_argument("-c", '--check', action="store_true", help="check if LED pos is correct")
+    parser.add_argument('-x', type=int, nargs=2, default=[1240, 1250], help='x-borders of LED detect area (in pixels)')
+    parser.add_argument('-y', type=int, nargs=2, default=[1504, 1526], help='y-borders of LED area (in pixels)')
     args = parser.parse_args()
     import glob
 
-    main(args.file, args.c, args.x, args.y)
+    main(args.file, args.check, args.x, args.y)
diff --git a/eval_LED.py b/eval_LED.py
index e990a3a..fdd7111 100644
--- a/eval_LED.py
+++ b/eval_LED.py
@@ -15,7 +15,7 @@ def main(folder):
     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
     times = np.load(os.path.join(folder, 'times.npy'))
-    LED_idx = pd.read_csv(os.path.join(folder, 'led_idxs.csv'), sep=',')
+    LED_idx = pd.read_csv(os.path.join(folder, 'led_idxs.csv'), sep=',', encoding = "utf-7")
 
     led_idx = np.array(LED_idx).T[0]
     led_frame = np.load(os.path.join(folder, 'LED_frames.npy'))
diff --git a/event_videos.py b/event_videos.py
index 1466e03..372106b 100644
--- a/event_videos.py
+++ b/event_videos.py
@@ -8,6 +8,7 @@ import glob
 import argparse
 from IPython import embed
 from tqdm import tqdm
+from thunderfish.powerspectrum import decibel
 
 def main(folder, dt):
     video_path = glob.glob(os.path.join(folder, '2022*.mp4'))[0]
@@ -16,21 +17,23 @@ def main(folder, dt):
         os.mkdir(create_video_path)
     video = cv2.VideoCapture(video_path) #  was 'cap'
 
-    fish_freqs = np.load(os.path.join(folder, 'analysis', 'fish_freq_interp.npy'))
-    max_freq, min_freq = np.nanmax(fish_freqs), np.nanmin(fish_freqs)
+    # fish_freqs = np.load(os.path.join(folder, 'analysis', 'fish_freq_interp.npy'))
+    fish_freqs = np.load(os.path.join(folder, 'analysis', 'fish_freq.npy'))
     rise_idx = np.load(os.path.join(folder, 'analysis', 'rise_idx.npy'))
     frame_times = np.load(os.path.join(folder, 'analysis', 'frame_times.npy'))
     times = np.load(os.path.join(folder, 'times.npy'))
+
+    fill_freqs = np.load(os.path.join(folder, 'fill_freqs.npy'))
+    fill_times = np.load(os.path.join(folder, 'fill_times.npy'))
+    fill_spec_shape = np.load(os.path.join(folder, 'fill_spec_shape.npy'))
+    fill_spec = np.memmap(os.path.join(folder, 'fill_spec.npy'), dtype='float', mode='r',
+                               shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F')
     #######################################
     for fish_nr in np.arange(2)[::-1]:
 
         for idx_oi in tqdm(np.array(rise_idx[fish_nr][~np.isnan(rise_idx[fish_nr])], dtype=int)):
-            # idx_oi = int(rise_idx[1][10])
             time_oi = times[idx_oi]
 
-            # embed()
-            # quit()
-
             HH = int((time_oi / 3600) // 1)
             MM = int((time_oi - HH * 3600) // 60)
             SS =  int(time_oi - HH * 3600 - MM * 60)
@@ -38,40 +41,91 @@ def main(folder, dt):
             frames_oi = np.arange(len(frame_times))[np.abs(frame_times - time_oi) <= dt]
             idxs_oi = np.arange(len(times))[np.abs(times - time_oi) <= dt*3]
 
-            fig = plt.figure(figsize=(20/2.54, 20/2.54))
-            gs = gridspec.GridSpec(2, 1, left=0.1, bottom = 0.1, right=0.95, top=0.95, height_ratios=(4, 1))
+            fig = plt.figure(figsize=(16*2/2.54, 9*2/2.54))
+            gs = gridspec.GridSpec(6, 2, left=0.075, bottom=0.05, right=1, top=0.95, width_ratios=(1.5, 3), hspace=.3, wspace=0.05)
             ax = []
-            ax.append(fig.add_subplot(gs[0, 0]))
-            ax.append(fig.add_subplot(gs[1, 0]))
-            ax[1].plot(times[idxs_oi] - time_oi, fish_freqs[0][idxs_oi], marker='.', color='firebrick')
-            ax[1].plot(times[idxs_oi] - time_oi, fish_freqs[1][idxs_oi], marker='.', color='cornflowerblue')
-            ax[1].set_ylim(min_freq - (max_freq-min_freq)*0.25, max_freq + (max_freq-min_freq)*0.25)
+            ax.append(fig.add_subplot(gs[:, 1]))
+            ax.append(fig.add_subplot(gs[1:3, 0]))
+            ax.append(fig.add_subplot(gs[3:5, 0]))
+
+
+            y00, y01 = np.nanmin(fish_freqs[0][idxs_oi]), np.nanmax(fish_freqs[0][idxs_oi])
+            y10, y11 = np.nanmin(fish_freqs[1][idxs_oi]), np.nanmax(fish_freqs[1][idxs_oi])
+
+            if y01 - y00 < 20:
+                y01 = y00 + 20
+            if y11 - y10 < 20:
+                y11 = y10 + 20
+            freq_span1 = (y01) - (y00)
+            freq_span2 = (y11) - (y10)
+
+            yspan = freq_span1 if freq_span1 > freq_span2 else freq_span2
+
+            ax[1].plot(times[idxs_oi] - time_oi, fish_freqs[0][idxs_oi], marker='.', markersize=4, color='darkorange', lw=2, alpha=0.4)
+            ax[2].plot(times[idxs_oi] - time_oi, fish_freqs[1][idxs_oi], marker='.', markersize=4,color='forestgreen', lw=2, alpha=0.4)
+            ax[1].plot([0, 0], [y00 - yspan * 0.2, y00 + yspan * 1.3], '--', color='k')
+            ax[2].plot([0, 0], [y10 - yspan * 0.2, y10 + yspan * 1.3], '--', color='k')
+
+            ax[1].set_xticks([-30, -15, 0, 15, 30])
+            ax[2].set_xticks([-30, -15, 0, 15, 30])
+            plt.setp(ax[1].get_xticklabels(), visible=False)
+
+            # spectrograms
+            f_mask1 = np.arange(len(fill_freqs))[(fill_freqs >= y00 - yspan * 0.2) & (fill_freqs <= y00 + yspan * 1.3)]
+            f_mask2 = np.arange(len(fill_freqs))[(fill_freqs >= y10 - yspan * 0.2) & (fill_freqs <= y10 + yspan * 1.3)]
+            t_mask = np.arange(len(fill_times))[(fill_times >= time_oi-dt*4) & (fill_times <= time_oi+dt*4)]
+
+            ax[1].imshow(decibel(fill_spec[f_mask1[0]:f_mask1[-1], t_mask[0]:t_mask[-1]][::-1]),
+                                              extent=[-dt*4, dt*4, y00 - yspan * 0.2, y00 + yspan * 1.3],
+                                              aspect='auto',vmin = -100, vmax = -50, alpha=0.7, cmap='jet', interpolation='gaussian')
+            ax[2].imshow(decibel(fill_spec[f_mask2[0]:f_mask2[-1], t_mask[0]:t_mask[-1]][::-1]),
+                                              extent=[-dt*4, dt*4, y10 - yspan * 0.2, y10 + yspan * 1.3],
+                                              aspect='auto',vmin = -100, vmax = -50, alpha=0.7, cmap='jet', interpolation='gaussian')
+
+            ax[1].set_ylim(y00 - yspan * 0.1, y00 + yspan * 1.2)
             ax[1].set_xlim(-dt*3, dt*3)
+            ax[2].set_ylim(y10 - yspan * 0.1, y10 + yspan * 1.2)
+            ax[2].set_xlim(-dt*3, dt*3)
+
             ax[0].set_xticks([])
             ax[0].set_yticks([])
 
             ax[1].tick_params(labelsize=12)
-            ax[1].set_xlabel('time [s]', fontsize=14)
+            ax[2].tick_params(labelsize=12)
+
+            ax[2].set_xlabel('time [s]', fontsize=14)
+            fig.text(0.02, 0.5, 'frequency [Hz]', fontsize=14, va='center', rotation='vertical')
+
             # plt.ion()
             for i in tqdm(np.arange(len(frames_oi))):
                 video.set(cv2.CAP_PROP_POS_FRAMES, int(frames_oi[i]))
                 ret, frame = video.read()
 
+                if i == 250:
+                    dot, = ax[0].plot(0.05, 0.95, 'o', color='firebrick', transform = ax[0].transAxes, markersize=20)
+                if i == 280:
+                    dot.remove()
+
                 if i == 0:
                     img = ax[0].imshow(frame)
-                    line, = ax[1].plot([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
-                                       [min_freq - (max_freq-min_freq)*0.25, max_freq + (max_freq-min_freq)*0.25],
+                    line1, = ax[1].plot([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
+                                       [y00 - yspan * 0.15, y00 + yspan * 1.3],
+                                       color='k', lw=1)
+                    line2, = ax[2].plot([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
+                                       [y10 - yspan * 0.15, y10 + yspan * 1.3],
                                        color='k', lw=1)
                 else:
                     img.set_data(frame)
-                    line.set_data([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
-                                  [min_freq - (max_freq-min_freq)*0.25, max_freq + (max_freq-min_freq)*0.25])
+                    line1.set_data([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
+                                  [y00 - yspan * 0.15, y00 + yspan * 1.3])
+                    line2.set_data([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
+                                  [y10 - yspan * 0.15, y10 + yspan * 1.3])
 
-                # label = ('rise_video/frame%4.f.jpg' % len(glob.glob('rise_video/*.jpg'))).replace(' ', '0')
                 label = (os.path.join(create_video_path, 'frame%4.f.jpg' % len(glob.glob(os.path.join(create_video_path, '*.jpg'))))).replace(' ', '0')
                 plt.savefig(label, dpi=300)
                 # plt.pause(0.001)
 
+            # quit()
             win_lose_str = 'lose' if fish_nr == 1 else 'win'
             # video_name = ("./rise_video/%s_%2.f:%2.f:%2.f.mp4" % (win_lose_str, HH, MM, SS)).replace(' ', '0')
             # command = "ffmpeg -r 25 -i './rise_video/frame%4d.jpg' -vf 'pad=ceil(iw/2)*2:ceil(ih/2)*2' -vcodec libx264 -y -an"
diff --git a/trail_analysis.py b/trail_analysis.py
index 5438d0f..432030a 100644
--- a/trail_analysis.py
+++ b/trail_analysis.py
@@ -81,6 +81,7 @@ class Trial(object):
         bins = np.arange(-bw / 2, self.times[-1] + bw / 2, bw)
         self.baseline_freq_times = np.array(bins[:-1] + (bins[1] - bins[0])/2)
         self.baseline_freqs = np.full((2, len(self.baseline_freq_times)), np.nan)
+        self.pct95_freqs = np.full((2, len(self.baseline_freq_times)), np.nan)
 
         for enu, id in enumerate(self.ids):
             for i in range(len(bins) - 1):
@@ -89,6 +90,7 @@ class Trial(object):
                     continue
                 else:
                     self.baseline_freqs[enu][i] = np.nanpercentile(Cf, 5)
+                    self.pct95_freqs[enu][i] = np.nanpercentile(Cf, 75)
 
         self.fish_freq_val = [np.nanmean(x[self.baseline_freq_times > self.light_sec]) for x in self.baseline_freqs]
 
@@ -116,6 +118,30 @@ class Trial(object):
 
             return rise_size
 
+        def correct_rise_idx(rise_peak_idx):
+
+            rise_dt = np.diff(self.times[rise_peak_idx])
+            rise_dt[rise_dt >= 10] = 10
+            rise_dt[rise_dt < 10] = rise_dt[rise_dt < 10] - 1
+            rise_dt = np.append(np.array([10]), rise_dt)
+
+
+            freq_slope = np.full(np.shape(self.fish_freq)[1], np.nan)
+            non_nan_idx = np.arange(len(freq_slope))[~np.isnan(self.fish_freq[i])]
+            freq_slope[non_nan_idx[1:]] = np.diff(self.fish_freq[i][~np.isnan(self.fish_freq[i])])
+
+            corrected_rise_idxs = []
+            for enu, r_idx in enumerate(rise_peak_idx):
+                mask = np.arange(len(freq_slope))[(self.times <= self.times[r_idx]) & (self.times > self.times[r_idx] - rise_dt[enu]) & (~np.isnan(freq_slope))]
+                if len(mask) == 0:
+                    corrected_rise_idxs.append(np.nan)
+                else:
+                    corrected_rise_idxs.append(mask[np.argmax(freq_slope[mask])])
+
+            corrected_rise_idxs = np.array(corrected_rise_idxs)
+
+            return corrected_rise_idxs
+
         for i in range(len(self.fish_freq)):
             rise_peak_idx, trough = detect_peaks(self.fish_freq[i][~np.isnan(self.fish_freq[i])], rise_th)
             non_nan_idx = np.arange(len(self.fish_freq[i]))[~np.isnan(self.fish_freq[i])]
@@ -123,8 +149,11 @@ class Trial(object):
 
             rise_size = check_rises_size(rise_peak_idx)
 
-            self.rise_idxs.append(rise_peak_idx[rise_size >= rise_th])
-            self.rise_size.append(rise_size[rise_size >= rise_th])
+            rise_idx = correct_rise_idx(rise_peak_idx)
+            # print(np.min(np.diff(self.times[rise_peak_idx])))
+
+            self.rise_idxs.append(np.array(rise_idx[(rise_size >= rise_th) & (~np.isnan(rise_idx))], dtype=int))
+            self.rise_size.append(rise_size[(rise_size >= rise_th) & (~np.isnan(rise_idx))])
 
     def update_meta(self):
         entries = self.meta.index.tolist()
@@ -150,26 +179,31 @@ class Trial(object):
 
         for enu, id in enumerate(self.ids):
             c = 'firebrick' if self.winner == enu else 'forestgreen'
-            ax.plot(self.times, self.fish_freq[enu], marker='.', color=c, zorder=1)
-            ax.plot(self.times[np.isnan(self.fish_freq[enu])], self.fish_freq_interp[enu][np.isnan(self.fish_freq[enu])], '.', zorder=1, color=c, alpha=0.25)
-            ax.plot(self.baseline_freq_times, self.baseline_freqs[enu], '--', color='k', zorder=2)
+            ax.plot(self.times/3600, self.fish_freq[enu], marker='.', color=c, zorder=1)
+            ax.plot(self.times[np.isnan(self.fish_freq[enu])]/3600, self.fish_freq_interp[enu][np.isnan(self.fish_freq[enu])], '.', zorder=1, color=c, alpha=0.25)
+            ax.plot(self.baseline_freq_times/3600, self.baseline_freqs[enu], '--', color='k', zorder=2)
+            ax.plot(self.baseline_freq_times/3600, self.pct95_freqs[enu], '--', color='k', zorder=2)
 
-            ax.plot(self.times[self.rise_idxs[enu]], self.fish_freq_interp[enu][self.rise_idxs[enu]], 'o', color='k')
+            ax.plot(self.times[self.rise_idxs[enu]]/3600, self.fish_freq_interp[enu][self.rise_idxs[enu]], 'o', color='k')
 
 
             win_str = '(W)' if self.winner == enu else ''
 
-            ax.text(self.times[-1], self.fish_freq_val[enu]-10, '%.0f' % id + win_str, va ='center', ha='right')
+            ax.text(self.times[-1]/3600, self.fish_freq_val[enu]-10, '%.0f' % id + win_str, va ='center', ha='right')
 
-            ax.set_xlim(0, self.times[-1])
+            ax.set_xlim(0, self.times[-1]/3600)
 
             freq_range = (np.nanmin(self.fish_freq), np.nanmax(self.fish_freq))
             ax.set_ylim(freq_range[0] - 20, freq_range[1] + 10)
+        ax.set_title(self.folder)
         plt.show()
 
     def save(self):
         saveorder = -1 if self.winner == 1 else 1
 
+        if not os.path.exists(os.path.join(self.base_path, self.folder, 'analysis')):
+            os.mkdir(os.path.join(self.base_path, self.folder, 'analysis'))
+
         np.save(os.path.join(self.base_path, self.folder, 'analysis', 'fish_freq.npy'), self.fish_freq[::saveorder])
         np.save(os.path.join(self.base_path, self.folder, 'analysis', 'fish_freq_interp.npy'), self.fish_freq_interp[::saveorder])
 
@@ -196,7 +230,7 @@ class Trial(object):
 
     def frame_to_idx(self, event_frames):
         self.sr = 20000
-        LED_idx = pd.read_csv(os.path.join(self.folder, 'led_idxs.csv'), sep=',')
+        LED_idx = pd.read_csv(os.path.join(self.folder, 'led_idxs.csv'), sep=',', encoding = "utf-7")
 
         led_idx = np.array(LED_idx).T[0]
         led_frame = np.load(os.path.join(self.folder, 'LED_frames.npy'))
@@ -214,15 +248,15 @@ class Trial(object):
 def main():
     parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
     parser.add_argument('-f', type=str, help='single recording analysis', default='')
-    # parser.add_argument("-c", action="store_true", help="check if LED pos is correct")
+    parser.add_argument('-d', "--dev", action="store_true", help="developer mode; no data saved")
     # parser.add_argument('-x', type=int, nargs=2, default=[1272, 1282], help='x-borders of LED detect area (in pixels)')
     # parser.add_argument('-y', type=int, nargs=2, default=[1500, 1516], help='y-borders of LED area (in pixels)')
     args = parser.parse_args()
 
     base_path = '/home/raab/data/2022_competition'
 
-    if os.path.exists(os.path.join(base_path, 'meta.csv')):
-        meta = pd.read_csv(os.path.join(base_path, 'meta.csv'), sep=',', index_col=0)
+    if os.path.exists(os.path.join(base_path, 'meta.csv')) and not args.dev:
+        meta = pd.read_csv(os.path.join(base_path, 'meta.csv'), sep=',', index_col=0, encoding = "utf-7")
     else:
         meta = None
 
@@ -231,7 +265,7 @@ def main():
         folders = [x for x in folders if not '.' in x]
     else:
         folders= [os.path.split(os.path.normpath(args.f))[-1]]
-
+    folders = sorted(folders)
     trials = []
     for folder in folders:
         trial = Trial(folder, base_path, meta, fish_count=2)
@@ -247,9 +281,10 @@ def main():
         trial.rise_detection(rise_th=5)
 
         if meta is not None:
-            trial.update_meta()
-
-        trial.save()
+            if not args.dev:
+                trial.update_meta()
+        if not args.dev:
+            trial.save()
         trial.ilustrate()
         trials.append(trial)