diff --git a/LED_detect.py b/LED_detect.py index 9d16ed1..a8f34db 100644 --- a/LED_detect.py +++ b/LED_detect.py @@ -72,10 +72,10 @@ def main(file_path, check, x, y): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Detect frames of blinking LED in video recordings.') parser.add_argument('file', type=str, help='video file to be analyzed') - parser.add_argument("-c", action="store_true", help="check if LED pos is correct") - parser.add_argument('-x', type=int, nargs=2, default=[1272, 1282], help='x-borders of LED detect area (in pixels)') - parser.add_argument('-y', type=int, nargs=2, default=[1500, 1516], help='y-borders of LED area (in pixels)') + parser.add_argument("-c", '--check', action="store_true", help="check if LED pos is correct") + parser.add_argument('-x', type=int, nargs=2, default=[1240, 1250], help='x-borders of LED detect area (in pixels)') + parser.add_argument('-y', type=int, nargs=2, default=[1504, 1526], help='y-borders of LED area (in pixels)') args = parser.parse_args() import glob - main(args.file, args.c, args.x, args.y) + main(args.file, args.check, args.x, args.y) diff --git a/eval_LED.py b/eval_LED.py index e990a3a..fdd7111 100644 --- a/eval_LED.py +++ b/eval_LED.py @@ -15,7 +15,7 @@ def main(folder): frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) times = np.load(os.path.join(folder, 'times.npy')) - LED_idx = pd.read_csv(os.path.join(folder, 'led_idxs.csv'), sep=',') + LED_idx = pd.read_csv(os.path.join(folder, 'led_idxs.csv'), sep=',', encoding = "utf-7") led_idx = np.array(LED_idx).T[0] led_frame = np.load(os.path.join(folder, 'LED_frames.npy')) diff --git a/event_videos.py b/event_videos.py index 1466e03..372106b 100644 --- a/event_videos.py +++ b/event_videos.py @@ -8,6 +8,7 @@ import glob import argparse from IPython import embed from tqdm import tqdm +from thunderfish.powerspectrum import decibel def main(folder, dt): video_path = glob.glob(os.path.join(folder, '2022*.mp4'))[0] @@ -16,21 +17,23 @@ def main(folder, dt): os.mkdir(create_video_path) video = cv2.VideoCapture(video_path) # was 'cap' - fish_freqs = np.load(os.path.join(folder, 'analysis', 'fish_freq_interp.npy')) - max_freq, min_freq = np.nanmax(fish_freqs), np.nanmin(fish_freqs) + # fish_freqs = np.load(os.path.join(folder, 'analysis', 'fish_freq_interp.npy')) + fish_freqs = np.load(os.path.join(folder, 'analysis', 'fish_freq.npy')) rise_idx = np.load(os.path.join(folder, 'analysis', 'rise_idx.npy')) frame_times = np.load(os.path.join(folder, 'analysis', 'frame_times.npy')) times = np.load(os.path.join(folder, 'times.npy')) + + fill_freqs = np.load(os.path.join(folder, 'fill_freqs.npy')) + fill_times = np.load(os.path.join(folder, 'fill_times.npy')) + fill_spec_shape = np.load(os.path.join(folder, 'fill_spec_shape.npy')) + fill_spec = np.memmap(os.path.join(folder, 'fill_spec.npy'), dtype='float', mode='r', + shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F') ####################################### for fish_nr in np.arange(2)[::-1]: for idx_oi in tqdm(np.array(rise_idx[fish_nr][~np.isnan(rise_idx[fish_nr])], dtype=int)): - # idx_oi = int(rise_idx[1][10]) time_oi = times[idx_oi] - # embed() - # quit() - HH = int((time_oi / 3600) // 1) MM = int((time_oi - HH * 3600) // 60) SS = int(time_oi - HH * 3600 - MM * 60) @@ -38,40 +41,91 @@ def main(folder, dt): frames_oi = np.arange(len(frame_times))[np.abs(frame_times - time_oi) <= dt] idxs_oi = np.arange(len(times))[np.abs(times - time_oi) <= dt*3] - fig = plt.figure(figsize=(20/2.54, 20/2.54)) - gs = gridspec.GridSpec(2, 1, left=0.1, bottom = 0.1, right=0.95, top=0.95, height_ratios=(4, 1)) + fig = plt.figure(figsize=(16*2/2.54, 9*2/2.54)) + gs = gridspec.GridSpec(6, 2, left=0.075, bottom=0.05, right=1, top=0.95, width_ratios=(1.5, 3), hspace=.3, wspace=0.05) ax = [] - ax.append(fig.add_subplot(gs[0, 0])) - ax.append(fig.add_subplot(gs[1, 0])) - ax[1].plot(times[idxs_oi] - time_oi, fish_freqs[0][idxs_oi], marker='.', color='firebrick') - ax[1].plot(times[idxs_oi] - time_oi, fish_freqs[1][idxs_oi], marker='.', color='cornflowerblue') - ax[1].set_ylim(min_freq - (max_freq-min_freq)*0.25, max_freq + (max_freq-min_freq)*0.25) + ax.append(fig.add_subplot(gs[:, 1])) + ax.append(fig.add_subplot(gs[1:3, 0])) + ax.append(fig.add_subplot(gs[3:5, 0])) + + + y00, y01 = np.nanmin(fish_freqs[0][idxs_oi]), np.nanmax(fish_freqs[0][idxs_oi]) + y10, y11 = np.nanmin(fish_freqs[1][idxs_oi]), np.nanmax(fish_freqs[1][idxs_oi]) + + if y01 - y00 < 20: + y01 = y00 + 20 + if y11 - y10 < 20: + y11 = y10 + 20 + freq_span1 = (y01) - (y00) + freq_span2 = (y11) - (y10) + + yspan = freq_span1 if freq_span1 > freq_span2 else freq_span2 + + ax[1].plot(times[idxs_oi] - time_oi, fish_freqs[0][idxs_oi], marker='.', markersize=4, color='darkorange', lw=2, alpha=0.4) + ax[2].plot(times[idxs_oi] - time_oi, fish_freqs[1][idxs_oi], marker='.', markersize=4,color='forestgreen', lw=2, alpha=0.4) + ax[1].plot([0, 0], [y00 - yspan * 0.2, y00 + yspan * 1.3], '--', color='k') + ax[2].plot([0, 0], [y10 - yspan * 0.2, y10 + yspan * 1.3], '--', color='k') + + ax[1].set_xticks([-30, -15, 0, 15, 30]) + ax[2].set_xticks([-30, -15, 0, 15, 30]) + plt.setp(ax[1].get_xticklabels(), visible=False) + + # spectrograms + f_mask1 = np.arange(len(fill_freqs))[(fill_freqs >= y00 - yspan * 0.2) & (fill_freqs <= y00 + yspan * 1.3)] + f_mask2 = np.arange(len(fill_freqs))[(fill_freqs >= y10 - yspan * 0.2) & (fill_freqs <= y10 + yspan * 1.3)] + t_mask = np.arange(len(fill_times))[(fill_times >= time_oi-dt*4) & (fill_times <= time_oi+dt*4)] + + ax[1].imshow(decibel(fill_spec[f_mask1[0]:f_mask1[-1], t_mask[0]:t_mask[-1]][::-1]), + extent=[-dt*4, dt*4, y00 - yspan * 0.2, y00 + yspan * 1.3], + aspect='auto',vmin = -100, vmax = -50, alpha=0.7, cmap='jet', interpolation='gaussian') + ax[2].imshow(decibel(fill_spec[f_mask2[0]:f_mask2[-1], t_mask[0]:t_mask[-1]][::-1]), + extent=[-dt*4, dt*4, y10 - yspan * 0.2, y10 + yspan * 1.3], + aspect='auto',vmin = -100, vmax = -50, alpha=0.7, cmap='jet', interpolation='gaussian') + + ax[1].set_ylim(y00 - yspan * 0.1, y00 + yspan * 1.2) ax[1].set_xlim(-dt*3, dt*3) + ax[2].set_ylim(y10 - yspan * 0.1, y10 + yspan * 1.2) + ax[2].set_xlim(-dt*3, dt*3) + ax[0].set_xticks([]) ax[0].set_yticks([]) ax[1].tick_params(labelsize=12) - ax[1].set_xlabel('time [s]', fontsize=14) + ax[2].tick_params(labelsize=12) + + ax[2].set_xlabel('time [s]', fontsize=14) + fig.text(0.02, 0.5, 'frequency [Hz]', fontsize=14, va='center', rotation='vertical') + # plt.ion() for i in tqdm(np.arange(len(frames_oi))): video.set(cv2.CAP_PROP_POS_FRAMES, int(frames_oi[i])) ret, frame = video.read() + if i == 250: + dot, = ax[0].plot(0.05, 0.95, 'o', color='firebrick', transform = ax[0].transAxes, markersize=20) + if i == 280: + dot.remove() + if i == 0: img = ax[0].imshow(frame) - line, = ax[1].plot([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi], - [min_freq - (max_freq-min_freq)*0.25, max_freq + (max_freq-min_freq)*0.25], + line1, = ax[1].plot([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi], + [y00 - yspan * 0.15, y00 + yspan * 1.3], + color='k', lw=1) + line2, = ax[2].plot([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi], + [y10 - yspan * 0.15, y10 + yspan * 1.3], color='k', lw=1) else: img.set_data(frame) - line.set_data([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi], - [min_freq - (max_freq-min_freq)*0.25, max_freq + (max_freq-min_freq)*0.25]) + line1.set_data([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi], + [y00 - yspan * 0.15, y00 + yspan * 1.3]) + line2.set_data([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi], + [y10 - yspan * 0.15, y10 + yspan * 1.3]) - # label = ('rise_video/frame%4.f.jpg' % len(glob.glob('rise_video/*.jpg'))).replace(' ', '0') label = (os.path.join(create_video_path, 'frame%4.f.jpg' % len(glob.glob(os.path.join(create_video_path, '*.jpg'))))).replace(' ', '0') plt.savefig(label, dpi=300) # plt.pause(0.001) + # quit() win_lose_str = 'lose' if fish_nr == 1 else 'win' # video_name = ("./rise_video/%s_%2.f:%2.f:%2.f.mp4" % (win_lose_str, HH, MM, SS)).replace(' ', '0') # command = "ffmpeg -r 25 -i './rise_video/frame%4d.jpg' -vf 'pad=ceil(iw/2)*2:ceil(ih/2)*2' -vcodec libx264 -y -an" diff --git a/trail_analysis.py b/trail_analysis.py index 5438d0f..432030a 100644 --- a/trail_analysis.py +++ b/trail_analysis.py @@ -81,6 +81,7 @@ class Trial(object): bins = np.arange(-bw / 2, self.times[-1] + bw / 2, bw) self.baseline_freq_times = np.array(bins[:-1] + (bins[1] - bins[0])/2) self.baseline_freqs = np.full((2, len(self.baseline_freq_times)), np.nan) + self.pct95_freqs = np.full((2, len(self.baseline_freq_times)), np.nan) for enu, id in enumerate(self.ids): for i in range(len(bins) - 1): @@ -89,6 +90,7 @@ class Trial(object): continue else: self.baseline_freqs[enu][i] = np.nanpercentile(Cf, 5) + self.pct95_freqs[enu][i] = np.nanpercentile(Cf, 75) self.fish_freq_val = [np.nanmean(x[self.baseline_freq_times > self.light_sec]) for x in self.baseline_freqs] @@ -116,6 +118,30 @@ class Trial(object): return rise_size + def correct_rise_idx(rise_peak_idx): + + rise_dt = np.diff(self.times[rise_peak_idx]) + rise_dt[rise_dt >= 10] = 10 + rise_dt[rise_dt < 10] = rise_dt[rise_dt < 10] - 1 + rise_dt = np.append(np.array([10]), rise_dt) + + + freq_slope = np.full(np.shape(self.fish_freq)[1], np.nan) + non_nan_idx = np.arange(len(freq_slope))[~np.isnan(self.fish_freq[i])] + freq_slope[non_nan_idx[1:]] = np.diff(self.fish_freq[i][~np.isnan(self.fish_freq[i])]) + + corrected_rise_idxs = [] + for enu, r_idx in enumerate(rise_peak_idx): + mask = np.arange(len(freq_slope))[(self.times <= self.times[r_idx]) & (self.times > self.times[r_idx] - rise_dt[enu]) & (~np.isnan(freq_slope))] + if len(mask) == 0: + corrected_rise_idxs.append(np.nan) + else: + corrected_rise_idxs.append(mask[np.argmax(freq_slope[mask])]) + + corrected_rise_idxs = np.array(corrected_rise_idxs) + + return corrected_rise_idxs + for i in range(len(self.fish_freq)): rise_peak_idx, trough = detect_peaks(self.fish_freq[i][~np.isnan(self.fish_freq[i])], rise_th) non_nan_idx = np.arange(len(self.fish_freq[i]))[~np.isnan(self.fish_freq[i])] @@ -123,8 +149,11 @@ class Trial(object): rise_size = check_rises_size(rise_peak_idx) - self.rise_idxs.append(rise_peak_idx[rise_size >= rise_th]) - self.rise_size.append(rise_size[rise_size >= rise_th]) + rise_idx = correct_rise_idx(rise_peak_idx) + # print(np.min(np.diff(self.times[rise_peak_idx]))) + + self.rise_idxs.append(np.array(rise_idx[(rise_size >= rise_th) & (~np.isnan(rise_idx))], dtype=int)) + self.rise_size.append(rise_size[(rise_size >= rise_th) & (~np.isnan(rise_idx))]) def update_meta(self): entries = self.meta.index.tolist() @@ -150,26 +179,31 @@ class Trial(object): for enu, id in enumerate(self.ids): c = 'firebrick' if self.winner == enu else 'forestgreen' - ax.plot(self.times, self.fish_freq[enu], marker='.', color=c, zorder=1) - ax.plot(self.times[np.isnan(self.fish_freq[enu])], self.fish_freq_interp[enu][np.isnan(self.fish_freq[enu])], '.', zorder=1, color=c, alpha=0.25) - ax.plot(self.baseline_freq_times, self.baseline_freqs[enu], '--', color='k', zorder=2) + ax.plot(self.times/3600, self.fish_freq[enu], marker='.', color=c, zorder=1) + ax.plot(self.times[np.isnan(self.fish_freq[enu])]/3600, self.fish_freq_interp[enu][np.isnan(self.fish_freq[enu])], '.', zorder=1, color=c, alpha=0.25) + ax.plot(self.baseline_freq_times/3600, self.baseline_freqs[enu], '--', color='k', zorder=2) + ax.plot(self.baseline_freq_times/3600, self.pct95_freqs[enu], '--', color='k', zorder=2) - ax.plot(self.times[self.rise_idxs[enu]], self.fish_freq_interp[enu][self.rise_idxs[enu]], 'o', color='k') + ax.plot(self.times[self.rise_idxs[enu]]/3600, self.fish_freq_interp[enu][self.rise_idxs[enu]], 'o', color='k') win_str = '(W)' if self.winner == enu else '' - ax.text(self.times[-1], self.fish_freq_val[enu]-10, '%.0f' % id + win_str, va ='center', ha='right') + ax.text(self.times[-1]/3600, self.fish_freq_val[enu]-10, '%.0f' % id + win_str, va ='center', ha='right') - ax.set_xlim(0, self.times[-1]) + ax.set_xlim(0, self.times[-1]/3600) freq_range = (np.nanmin(self.fish_freq), np.nanmax(self.fish_freq)) ax.set_ylim(freq_range[0] - 20, freq_range[1] + 10) + ax.set_title(self.folder) plt.show() def save(self): saveorder = -1 if self.winner == 1 else 1 + if not os.path.exists(os.path.join(self.base_path, self.folder, 'analysis')): + os.mkdir(os.path.join(self.base_path, self.folder, 'analysis')) + np.save(os.path.join(self.base_path, self.folder, 'analysis', 'fish_freq.npy'), self.fish_freq[::saveorder]) np.save(os.path.join(self.base_path, self.folder, 'analysis', 'fish_freq_interp.npy'), self.fish_freq_interp[::saveorder]) @@ -196,7 +230,7 @@ class Trial(object): def frame_to_idx(self, event_frames): self.sr = 20000 - LED_idx = pd.read_csv(os.path.join(self.folder, 'led_idxs.csv'), sep=',') + LED_idx = pd.read_csv(os.path.join(self.folder, 'led_idxs.csv'), sep=',', encoding = "utf-7") led_idx = np.array(LED_idx).T[0] led_frame = np.load(os.path.join(self.folder, 'LED_frames.npy')) @@ -214,15 +248,15 @@ class Trial(object): def main(): parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') parser.add_argument('-f', type=str, help='single recording analysis', default='') - # parser.add_argument("-c", action="store_true", help="check if LED pos is correct") + parser.add_argument('-d', "--dev", action="store_true", help="developer mode; no data saved") # parser.add_argument('-x', type=int, nargs=2, default=[1272, 1282], help='x-borders of LED detect area (in pixels)') # parser.add_argument('-y', type=int, nargs=2, default=[1500, 1516], help='y-borders of LED area (in pixels)') args = parser.parse_args() base_path = '/home/raab/data/2022_competition' - if os.path.exists(os.path.join(base_path, 'meta.csv')): - meta = pd.read_csv(os.path.join(base_path, 'meta.csv'), sep=',', index_col=0) + if os.path.exists(os.path.join(base_path, 'meta.csv')) and not args.dev: + meta = pd.read_csv(os.path.join(base_path, 'meta.csv'), sep=',', index_col=0, encoding = "utf-7") else: meta = None @@ -231,7 +265,7 @@ def main(): folders = [x for x in folders if not '.' in x] else: folders= [os.path.split(os.path.normpath(args.f))[-1]] - + folders = sorted(folders) trials = [] for folder in folders: trial = Trial(folder, base_path, meta, fish_count=2) @@ -247,9 +281,10 @@ def main(): trial.rise_detection(rise_th=5) if meta is not None: - trial.update_meta() - - trial.save() + if not args.dev: + trial.update_meta() + if not args.dev: + trial.save() trial.ilustrate() trials.append(trial)