created tex file and moved code and figures to different folder

This commit is contained in:
2023-08-14 07:29:57 +02:00
parent 7cda7755af
commit 284d30ad18
361 changed files with 10539 additions and 0 deletions

82
code/LED_detect.py Normal file
View File

@@ -0,0 +1,82 @@
import os
import cv2
import argparse
import numpy as np
import matplotlib.pyplot as plt
def check_LED(cap, frame_count, x0, x1, y0, y1):
    """Interactively verify the LED detection window.

    Plays the second half of the video in a matplotlib window with the
    detection rectangle drawn in red and prints the mean pixel-sum brightness
    inside the window for every frame. Terminates the program when the video
    ends or on Ctrl-C.

    Parameters
    ----------
    cap: cv2.VideoCapture of the video to inspect.
    frame_count: total number of frames in the video.
    x0, x1: left/right border of the detection window (pixels).
    y0, y1: top/bottom border of the detection window (pixels).
    """
    fig, ax = plt.subplots()
    # draw the rectangular LED detection window
    ax.plot([x0, x0], [y0, y1], 'r')
    ax.plot([x1, x1], [y0, y1], 'r')
    ax.plot([x0, x1], [y0, y0], 'r')
    ax.plot([x0, x1], [y1, y1], 'r')
    plt.ion()
    # start in the middle of the recording (LED assumed active there)
    cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_count / 2))
    f = None
    try:
        for i in np.arange(int(frame_count / 2), frame_count):
            ret, frame = cap.read()
            if not ret:
                # video ended earlier than the reported frame count
                break
            if f is None:
                f = ax.imshow(frame)
            else:
                f.set_data(frame)
            # brightness = mean over the summed color channels in the window
            sum_frame = np.sum(frame, axis=2)
            LED_v = np.mean(sum_frame[y0:y1, x0:x1])
            print('%.0f: %.1f \n' % (i, LED_v))
            plt.pause(0.001)
    except KeyboardInterrupt:
        plt.close()
        quit()
    quit()
def main(file_path, check, x, y):
    """Detect the frames in which a blinking LED switches on.

    Reads the whole video, records the mean brightness inside the LED window
    for every frame, and detects upward threshold crossings. Saves
    'LED_val.npy' (per-frame brightness) and 'LED_frames.npy' (onset frame
    indices) next to the video and shows a diagnostic plot.

    Parameters
    ----------
    file_path: path to the video file.
    check: if True, only run the interactive window check (terminates).
    x: (x0, x1) horizontal borders of the LED window in pixels.
    y: (y0, y1) vertical borders of the LED window in pixels.
    """
    folder, filename = os.path.split(file_path)
    cap = cv2.VideoCapture(file_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    x0, x1 = x
    y0, y1 = y
    if check:
        # interactive preview only; check_LED terminates the program
        check_LED(cap, frame_count, x0, x1, y0, y1)
    light_th = 100  # brightness threshold separating LED off/on
    LED_val = np.zeros(frame_count)
    print('Frame_count: %.0f' % frame_count)
    for i in range(frame_count):
        if i % 1000 == 0:
            print('progress: %.1f' % ((i / frame_count) * 100) + '%')
        ret, frame = cap.read()
        if not ret:
            # reported frame count can exceed the readable frames; stop early
            break
        sum_frame = np.sum(frame, axis=2)
        LED_val[i] = np.mean(sum_frame[y0:y1, x0:x1])
    cap.release()
    np.save(os.path.join(folder, 'LED_val.npy'), LED_val)
    # LED onset = brightness crosses the threshold upwards between frames
    LED_frames = np.arange(len(LED_val) - 1)[(LED_val[:-1] < light_th) & (LED_val[1:] > light_th)]
    np.save(os.path.join(folder, 'LED_frames.npy'), LED_frames)
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(LED_val)), LED_val, color='k')
    ax.plot(LED_frames, np.ones(len(LED_frames)) * light_th, 'o', color='firebrick')
    plt.show()
if __name__ == '__main__':
    # Command-line interface: detect LED blink frames in a video recording.
    parser = argparse.ArgumentParser(description='Detect frames of blinking LED in video recordings.')
    parser.add_argument('file', type=str, help='video file to be analyzed')
    parser.add_argument("-c", '--check', action="store_true", help="check if LED pos is correct")
    parser.add_argument('-x', type=int, nargs=2, default=[675, 695], help='x-borders of LED detect area (in pixels)')
    parser.add_argument('-y', type=int, nargs=2, default=[350, 360], help='y-borders of LED area (in pixels)')
    args = parser.parse_args()
    main(args.file, args.check, args.x, args.y)

53
code/README.md Normal file
View File

@@ -0,0 +1,53 @@
# How to run the competition experiment
__Workflow__ (Python-scripts/applications):
1) wavetracker.trackingGUI
2) wavetracker.EODsorter
3) LED_detect.py
4) eval_LED.py
5) trial_analysis.py
6) event_videos.py (optional)
## Raw data analysis using the wavetracker module
### trackingGUI.py
__Frequency extraction and tracking__
- open Raw-file (traces-grid1.raw)
- 'Spectrogram'-settings:
- overlap fraction: 0.8
- frequency resolution: 1
- check 'Auto-save'; press 'Run'
__Fine spectrogram__
- repeat steps above but press 'Calc. fine spec' instead of Run
- fine spec data saved in /home/"user"/analysis/"filename"
### EODsorter.py
- load dataset/folder
- correct tracked EOD traces
- fill EOD traces
- fine spec data needs to be manually added to the dataset-folder
## Competition trial analysis
### trial_analysis.py
- Detection of winners, their EODf traces, rises, etc. Results stored in "data-path"/analysis.
- (optional) Meta.csv file in base-path of analyzed data. Creates entries for each
analyzed recording (index = file names) and stores meta-data. Manual completion suggested.
## Video analysis
### LED_detect.py
- Detect blinking LED (emitted by electric recording setup). Used for synchronization.
- "-c" argument to identify correct detection area for LED
- '-x' (tuple) borders of LED detection window on X-axis (in pixels)
- '-y' (tuple) borders of LED detection window on Y-axis (in pixels)
### eval_LED.py
- creates time vector to synchronize electric and video recording
- for each frame contains a time-point (in s) that corresponds to the electric recordings.
## Rise videos (optional)
- generates for each detected rise a short video showing the fish's behavior around the rise event.
- sorted in 'base-path'/rise_video.

410
code/complete_analysis.py Normal file
View File

@@ -0,0 +1,410 @@
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import itertools
from event_time_correlations import load_and_converete_boris_events
from tqdm import tqdm
import numpy as np
import pandas as pd
import os
import sys
import glob
from IPython import embed
def load_frame_times(trial_path):
    """Load per-frame timestamps of a trial's video recording.

    Reads the first '*.dat' file in *trial_path*; each line holds one
    'HH:MM:SS'-style timestamp which is converted to seconds.

    Parameters
    ----------
    trial_path: folder of one trial recording.

    Returns
    -------
    1D np.ndarray of frame times in seconds; empty array if no .dat file exists.
    """
    t_filepath = glob.glob(os.path.join(trial_path, '*.dat'))
    if len(t_filepath) == 0:
        return np.array([])
    frame_t = []
    # 'with' guarantees the file handle is closed (the original leaked it)
    with open(t_filepath[0], 'r') as f:
        for line in f:
            # HH:MM:SS(.fraction) -> seconds
            t = sum(x * float(t) for x, t in zip([3600, 60, 1], line.replace('\n', '').split(":")))
            frame_t.append(t)
    return np.array(frame_t)
# def load_and_converete_boris_events(trial_path, recording, sr, video_stated_FPS=25):
# def converte_video_frames_to_grid_idx(event_frames, led_frames, led_idx):
# event_idx_grid = (event_frames - led_frames[0]) / (led_frames[-1] - led_frames[0]) * (led_idx[-1] - led_idx[0]) + led_idx[0]
# return event_idx_grid
#
# # idx in grid-recording
# led_idx = pd.read_csv(os.path.join(trial_path, 'led_idxs.csv'), header=None).iloc[:, 0].to_numpy()
# # frames where LED gets switched on
# led_frames = np.load(os.path.join(trial_path, 'LED_frames.npy'))
#
# times, behavior, t_ag_on_off, t_contact, video_FPS = load_boris(trial_path, recording)
#
# contact_frame = np.array(np.round(t_contact * video_FPS), dtype=int)
# ag_on_off_frame = np.array(np.round(t_ag_on_off * video_FPS), dtype=int)
#
# # led_t_GRID = led_idx / sr
# contact_t_GRID = converte_video_frames_to_grid_idx(contact_frame, led_frames, led_idx) / sr
# ag_on_off_t_GRID = converte_video_frames_to_grid_idx(ag_on_off_frame, led_frames, led_idx) / sr
#
# return contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames
# def load_boris(trial_path, recording):
# boris_file = '-'.join(recording.split('-')[:3]) + '.csv'
#
# data = pd.read_csv(os.path.join(trial_path, boris_file))
# times = data['Start (s)']
# behavior = data['Behavior']
#
# t_ag_on = times[behavior == 0]
# t_ag_off = times[behavior == 1]
#
# t_ag_on_off = []
# for t in t_ag_on:
# t1 = np.array(t_ag_off)[t_ag_off > t]
# if len(t1) >= 1:
# t_ag_on_off.append(np.array([t, t1[0]]))
#
# t_contact = times[behavior == 2]
#
# return times, behavior, np.array(t_ag_on_off), t_contact.to_numpy(), data['FPS'][0]
def get_baseline_freq(fund_v, idx_v, times, ident_v, idents=None, binwidth=300):
    """Compute baseline EOD frequencies per identity in fixed time bins.

    The baseline in each bin is the 5th percentile of all frequency values of
    an identity within that bin (robust against upward rise excursions).

    Parameters
    ----------
    fund_v: 1D array of fundamental frequencies (Hz).
    idx_v: 1D int array of indices into *times*, aligned with *fund_v*.
    times: 1D array of time stamps (s).
    ident_v: 1D array of identity labels aligned with *fund_v* (NaN = unassigned).
    idents: identities to evaluate; defaults to all non-NaN identities.
    binwidth: bin width in seconds.

    Returns
    -------
    (base_freqs, bin_centers): 2D array (n_idents x n_bins, NaN for empty bins)
    and 1D array of bin-center times.
    """
    if not hasattr(idents, '__len__'):
        idents = np.unique(ident_v[~np.isnan(ident_v)])
    # bin edges are identical for all identities -- compute them once
    bins = np.arange(-binwidth / 2, times[-1] + binwidth / 2, binwidth)
    base_freqs = []
    for ident in idents:
        f = fund_v[ident_v == ident]
        t = times[idx_v[ident_v == ident]]
        base_f = np.full(len(bins) - 1, np.nan)
        for i in range(len(bins) - 1):
            Cf = f[(t > bins[i]) & (t <= bins[i + 1])]
            if len(Cf) > 0:
                # 5th percentile: robust baseline below rise excursions
                base_f[i] = np.percentile(Cf, 5)
        base_freqs.append(base_f)
    return np.array(base_freqs), np.array(bins[:-1] + (bins[1] - bins[0]) / 2)
def q10(f1, f2, t1, t2):
    """Return the Q10 temperature coefficient of a rate that is f1 at
    temperature t1 and f2 at temperature t2."""
    rate_ratio = f2 / f1
    exponent = 10 / (t2 - t1)
    return rate_ratio ** exponent
def frequency_q10_compensation(baseline_freqs : np.ndarray,
                               baseline_freq_times : np.ndarray,
                               temp : np.ndarray,
                               temp_t : np.ndarray,
                               light_start_sec : float):
    """
    Compute baseline frequency at 25 degree Celsius using Q10 formula. Q10 values are computed between each frequency-
    temperature pair after light_start_sec (since frequency modulations can be assumed minimal during light). Q10-
    compensated baseline freqs are computed for all values in baseline_freqs using the literature Q10 value (1.56).

    Parameters
    ----------
    baseline_freqs: 2D-array: For each fish and each time in baseline_freq_times a corresponding frequency in Hz.
    baseline_freq_times: 1D-array: Time stamps corresponding to baseline_freq.
    temp: 1D-array: temperature values detected at timestamps temp_t.
    temp_t: 1D-array: corresponding time stamps.
    light_start_sec: time when light is switched on and frequency modulations can be assumed to be minimal. Q10 values
    only calculated for timestamps after light_start_sec.

    Returns
    -------
    q10_comp_freq: list of 1D-arrays: Q10-compensated frequencies (at 25 deg C) per fish.
    q10_vals: list of floats: median empirical Q10 per fish (NaN if no valid pair existed).
    """
    q10_lit = 1.56  # literature Q10 value used for the compensation
    # nearest temperature sample for every baseline time -- identical for all
    # fish, so compute it once instead of per fish
    Ctemp = np.array([temp[np.argmin(np.abs(temp_t - blt))] for blt in baseline_freq_times])
    q10_comp_freq = []
    q10_vals = []
    for bf in baseline_freqs:
        Cbf = np.copy(bf)
        q10s = []
        for i, j in itertools.combinations(range(len(Cbf)), r=2):
            if Cbf[i] == Cbf[j] or Ctemp[i] == Ctemp[j]:
                # q10 with same values is useless (and divides by zero)
                continue
            if baseline_freq_times[i] < light_start_sec or baseline_freq_times[j] < light_start_sec:
                # too much frequency change due to rises in first part of rec
                continue
            # empirical Q10 between the two frequency/temperature samples
            q10s.append((Cbf[j] / Cbf[i]) ** (10 / (Ctemp[j] - Ctemp[i])))
        # compensate with the literature value (empirical median kept for reporting)
        q10_comp_freq.append(Cbf * q10_lit ** ((25 - Ctemp) / 10))
        # guard: median of an empty list would raise a warning and yield nan noisily
        q10_vals.append(np.median(q10s) if len(q10s) > 0 else np.nan)
    # original hard-coded two fish and crashed otherwise; report all values
    print('Q10-values: ' + ' '.join(f'{v:.2f}' for v in q10_vals))
    return q10_comp_freq, q10_vals
def get_temperature(folder_path):
    """Load the water temperature log of one trial.

    Reads 'temperatures.csv' (semicolon separated; first column time stamps,
    second column temperature). If the file ends with a non-numeric trailing
    entry, that last row is dropped.

    Parameters
    ----------
    folder_path: folder containing temperatures.csv.

    Returns
    -------
    (temp_t, temp): 1D arrays of time stamps and temperatures.
    """
    temp_file = pd.read_csv(os.path.join(folder_path, 'temperatures.csv'), sep=';')
    temp_t = np.array(temp_file[temp_file.keys()[0]])
    temp = np.array(temp_file[temp_file.keys()[1]])
    # a non-numeric last entry leaves the column as strings; drop it and
    # convert the rest (isinstance is robust, unlike comparing type names)
    if isinstance(temp[-1], str):
        temp = np.array(temp[:-1], dtype=float)
        temp_t = np.array(temp_t[:-1], dtype=int)
    return np.array(temp_t), np.array(temp)
def main(base_path=None):
    """Build a per-trial summary of the competition experiment.

    For every trial listed in <base_path>/order_meta.csv this loads the tracked
    EOD frequency data, the temperature log, communication signals (chirps,
    rises) and -- where available -- BORIS-scored physical behavior, collects
    one summary row per trial, renders an overview figure per trial, and
    finally writes <base_path>/trial_summary.csv.

    Parameters
    ----------
    base_path: folder holding order_meta.csv, id_meta.csv and one sub-folder
        per recording. NOTE(review): the default None crashes in os.path.join
        -- confirm all callers pass a path.
    """
    colors = ['#BA2D22', '#53379B', '#F47F17', '#3673A4', '#AAB71B', '#DC143C', '#1E90FF']
    female_color, male_color = '#e74c3c', '#3498db'
    Wc, Lc = 'darkgreen', '#3673A4'  # winner / loser trace colors
    # output folder for per-trial overview figures
    if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'example_trials')):
        os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'example_trials'))
    # trials_meta = pd.read_csv('order_meta.csv')
    trials_meta = pd.read_csv(os.path.join(base_path, 'order_meta.csv'))
    # fish_meta = pd.read_csv('id_meta.csv')
    fish_meta = pd.read_csv(os.path.join(base_path, 'id_meta.csv'))
    # mean weight/length over up to three measurements per fish
    fish_meta['mean_w'] = np.nanmean(fish_meta.loc[:, ['w1', 'w2', 'w3']], axis=1)
    fish_meta['mean_l'] = np.nanmean(fish_meta.loc[:, ['l1', 'l2', 'l3']], axis=1)
    video_stated_FPS = 25  # cap.get(cv2.CAP_PROP_FPS)
    sr = 20_000  # electric grid recording sample rate (Hz) -- TODO confirm
    light_start_sec = 3*60*60  # light switched on 3 h into the recording -- TODO confirm
    trial_summary = pd.DataFrame(columns=['recording', 'group', 'win_fish', 'lose_fish', 'win_ID', 'lose_ID',
                                          'sex_win', 'sex_lose', 'size_win', 'size_lose', 'dsize', 'EODf_win', 'EODf_lose', 'dEODf',
                                          'exp_win', 'exp_lose', 'chirps_win', 'chirps_lose', 'rises_win', 'rises_lose',
                                          'chase_count', 'contact_count', 'med_chase_dur', 'comp_dur0', 'comp_dur1',
                                          'draw'])
    # template row (all None) used to append before filling
    trial_summary_row = {f'{s}':None for s in trial_summary.keys()}
    for trial_idx in tqdm(np.arange(len(trials_meta)), desc='Trials'):
        video_eval = True
        group = trials_meta['group'][trial_idx]
        # recording names are stored quoted; strip first/last character
        recording = trials_meta['recording'][trial_idx][1:-1]
        print('')
        print(recording)
        rec_id1 = trials_meta['rec_id1'][trial_idx]
        rec_id2 = trials_meta['rec_id2'][trial_idx]
        if group < 3:
            continue
        trial_path = os.path.join(base_path, recording)
        if not os.path.exists(trial_path):
            continue
        # video-based evaluation needs group >= 5 and LED sync files
        if group < 5:
            video_eval = False
        if not os.path.exists(os.path.join(trial_path, 'led_idxs.csv')):
            video_eval = False
        if not os.path.exists(os.path.join(trial_path, 'LED_frames.npy')):
            video_eval = False
        #############################################################################################################
        ### meta collect
        # winner == -1 encodes a draw; NaN means the trial was not scored
        if (winner_fish := trials_meta['winner'][trial_idx]) == -1:
            pass
        elif np.isnan(winner_fish):
            continue
        elif winner_fish != trials_meta['fish1'][trial_idx] and winner_fish != trials_meta['fish2'][trial_idx]:
            embed()
            quit()
            # NOTE(review): unreachable -- quit() above terminates first
            print(f'not participating winner in {recording}!!!')
            continue
        win_id = rec_id1 if trials_meta['fish1'][trial_idx] == trials_meta['winner'][trial_idx] else rec_id2
        lose_id = rec_id2 if trials_meta['fish1'][trial_idx] == trials_meta['winner'][trial_idx] else rec_id1
        f1_length = float(fish_meta['mean_l'][(fish_meta['group'] == trials_meta['group'][trial_idx]) &
                                              (fish_meta['fish'] == trials_meta['fish1'][trial_idx])])
        f2_length = float(fish_meta['mean_l'][(fish_meta['group'] == trials_meta['group'][trial_idx]) &
                                              (fish_meta['fish'] == trials_meta['fish2'][trial_idx])])
        win_l = f1_length if trials_meta['fish1'][trial_idx] == trials_meta['winner'][trial_idx] else f2_length
        lose_l = f2_length if trials_meta['fish1'][trial_idx] == trials_meta['winner'][trial_idx] else f1_length
        win_exp = trials_meta['exp1'][trial_idx] if trials_meta['winner'][trial_idx] == trials_meta['fish1'][trial_idx] else trials_meta['exp2'][trial_idx]
        lose_exp = trials_meta['exp2'][trial_idx] if trials_meta['winner'][trial_idx] == trials_meta['fish1'][trial_idx] else trials_meta['exp1'][trial_idx]
        #############################################################################################################
        # wavetracker output: frequencies, identities, index and time vectors
        fund_v = np.load(os.path.join(trial_path, 'fund_v.npy'))
        ident_v = np.load(os.path.join(trial_path, 'ident_v.npy'))
        idx_v = np.load(os.path.join(trial_path, 'idx_v.npy'))
        times = np.load(os.path.join(trial_path, 'times.npy'))
        # diagnostic output when more identities were tracked than expected
        if len(uid:=np.unique(ident_v[~np.isnan(ident_v)])) >2:
            print(f'to many ids: {len(uid)}')
            print(f'ids in recording: {uid[0]:.0f} {uid[1]:.0f}')
            print(f'ids in meta: {rec_id1:.0f} {rec_id2:.0f}')
        # skip trials whose meta ids are missing among the tracked ids
        meta_id_in_uid = list(map(lambda x: x in uid, [rec_id1, rec_id2]))
        if ~np.all(meta_id_in_uid):  # ~ acts as logical-not on the numpy bool
            continue
        ids = np.load(os.path.join(trial_path, 'analysis', 'ids.npy'))
        # reverse stored per-fish arrays so index 0 is always the winner
        sorter = -1 if win_id != ids[0] else 1
        temp_t, temp = get_temperature(trial_path)
        baseline_freqs = np.load(os.path.join(trial_path, 'analysis', 'baseline_freqs.npy'))[::sorter]
        baseline_freq_times = np.load(os.path.join(trial_path, 'analysis', 'baseline_freq_times.npy'))
        q10_comp_freq, q10_vals = frequency_q10_compensation(baseline_freqs, baseline_freq_times, temp, temp_t, light_start_sec=light_start_sec)
        #############################################################################################################
        ### communication
        got_chirps = False
        if os.path.exists(os.path.join(trial_path, 'chirp_times_cnn.npy')):
            chirp_t = np.load(os.path.join(trial_path, 'chirp_times_cnn.npy'))
            chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
            # [winner chirps, loser chirps]
            chirp_times = [chirp_t[chirp_ids == win_id], chirp_t[chirp_ids == lose_id]]
            got_chirps = True
        # NOTE(review): if no prior trial set chirp_times, the summary row below
        # raises NameError when chirp files are missing -- verify data coverage
        rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
        rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]
        #############################################################################################################
        ### physical behavior
        # -1 marks "no video evaluation available"
        med_chase_dur = contact_n = chase_n = comp_dur0 = comp_dur1 = -1
        if video_eval:
            contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames = load_and_converete_boris_events(trial_path, recording, sr)
            # mask contacts that happen outside any chase (on/off) episode
            only_contact_mask = np.ones_like(contact_t_GRID, dtype=bool)
            for enu, ct in enumerate(contact_t_GRID):
                for Cag_on_off_t in ag_on_off_t_GRID:
                    if Cag_on_off_t[0] <= ct <= Cag_on_off_t[1]:
                        only_contact_mask[enu] = 0
                        break
                    elif ct < Cag_on_off_t[0]:
                        # episodes sorted in time; no later one can contain ct
                        break
            contact_t_solely = contact_t_GRID[only_contact_mask]
            # sorted end times of all agonistic events (contacts + chase offsets)
            ag_offs = np.concatenate((contact_t_GRID, ag_on_off_t_GRID[:, 1]))
            ag_offs = ag_offs[np.argsort(ag_offs)]
            med_chase_dur = np.median(ag_on_off_t_GRID[:,1] - ag_on_off_t_GRID[:,0])
            contact_n = len(contact_t_GRID)
            chase_n = len(ag_on_off_t_GRID)
            comp_dur0 = ag_offs[2]  # time of third agonistic event -- TODO confirm intent
            comp_dur1 = ag_offs[2] - ag_offs[0]  # span first-to-third event
        win_fish_no = trials_meta['fish1'][trial_idx] if trials_meta['fish1'][trial_idx] == trials_meta['winner'][trial_idx] else trials_meta['fish2'][trial_idx]
        lose_fish_no = trials_meta['fish2'][trial_idx] if trials_meta['fish1'][trial_idx] == trials_meta['winner'][trial_idx] else trials_meta['fish1'][trial_idx]
        # append an empty row, then overwrite it with this trial's values
        trial_summary.loc[len(trial_summary)] = trial_summary_row
        trial_summary.iloc[-1] = {'recording': recording,
                                  'group': trials_meta['group'][trial_idx],
                                  'win_fish': win_fish_no,
                                  'lose_fish': lose_fish_no,
                                  'win_ID': win_id,
                                  'lose_ID': lose_id,
                                  'sex_win': 'n',
                                  'sex_lose': 'n',
                                  'size_win': win_l,
                                  'size_lose': lose_l,
                                  'dsize': win_l - lose_l,
                                  'EODf_win': np.nanmedian(q10_comp_freq[0]),
                                  'EODf_lose': np.nanmedian(q10_comp_freq[1]),
                                  'dEODf': np.nanmedian(q10_comp_freq[0]) - np.nanmedian(q10_comp_freq[1]),
                                  'exp_win': win_exp,
                                  'exp_lose': lose_exp,
                                  'chirps_win': len(chirp_times[0]),
                                  'chirps_lose': len(chirp_times[1]),
                                  'rises_win': len(rise_idx_int[0]),
                                  'rises_lose': len(rise_idx_int[1]),
                                  'draw': 1 if trials_meta['winner'][trial_idx] == -1 else 0,
                                  'chase_count': chase_n,
                                  'contact_count': contact_n,
                                  'med_chase_dur': med_chase_dur,
                                  'comp_dur0': comp_dur0,
                                  'comp_dur1': comp_dur1
                                  }
        # embed()
        ###############################################################################
        # overview figure: behavior raster (top) + frequency traces (bottom)
        fig = plt.figure(figsize=(30/2.54, 18/2.54))
        gs = gridspec.GridSpec(2, 1, left = 0.1, bottom = 0.1, right=0.95, top=0.95, height_ratios=[1, 3], hspace=0)
        ax = []
        ax.append(fig.add_subplot(gs[0, 0]))
        ax.append(fig.add_subplot(gs[1, 0], sharex=ax[0]))
        ####################################################
        ### traces
        ax[1].plot(times[idx_v[ident_v == win_id]] / 3600, fund_v[ident_v == win_id], color=Wc, label=f'ID {win_id} {np.nanmedian(q10_comp_freq[0]):.2f}Hz')
        ax[1].plot(times[idx_v[ident_v == lose_id]] / 3600, fund_v[ident_v == lose_id], color=Lc, label=f'ID {lose_id} {np.nanmedian(q10_comp_freq[1]):.2f}Hz')
        # ax[1].plot(baseline_freq_times / 3600, q10_comp_freq[0], '--', color=Wc, lw=1)
        # ax[1].plot(baseline_freq_times / 3600, q10_comp_freq[1], '--', color=Lc, lw=1)
        # ax[1].plot(times[idx_v[ident_v == lose_id]] / 3600, fund_v[ident_v == lose_id], color=Lc)
        min_f, max_f = np.min(fund_v[~np.isnan(ident_v)]), np.nanmax(fund_v[~np.isnan(ident_v)])
        ax[1].set_ylim(min_f-50, max_f+50)
        ax[1].set_xlim(times[0]/3600, times[-1]/3600)
        plt.setp(ax[0].get_xticklabels(), visible=False)
        # temperature on a twin axis, squeezed into a narrow band at the median
        ax_m = ax[1].twinx()
        ax_m.plot(temp_t/3600, temp, '--', lw=2, color='tab:red')
        ylim0, ylim1 = ax[1].get_ylim()
        ax_m.set_ylim(np.nanmedian(temp) - (ylim1-ylim0) / 40 / 2, np.nanmedian(temp) + (ylim1-ylim0) / 40 / 2)
        ax[1].legend(loc='upper right', bbox_to_anchor=(1, 1), title=r'EODf$_{25}$')
        ####################################################
        ### behavior
        if video_eval:
            ax[0].plot(contact_t_GRID / 3600, np.ones_like(contact_t_GRID) , '|', markersize=10, color='k')
            ax[0].plot(ag_on_off_t_GRID[:, 0] / 3600, np.ones_like(ag_on_off_t_GRID[:, 0]) * 2, '|', markersize=10, color='firebrick')
        ax[0].plot(times[rise_idx_int[0]] / 3600, np.ones_like(rise_idx_int[0]) * 4, '|', markersize=10, color=Wc)
        ax[0].plot(times[rise_idx_int[1]] / 3600, np.ones_like(rise_idx_int[1]) * 5, '|', markersize=10, color=Lc)
        if got_chirps:
            ax[0].plot(chirp_times[0] / 3600, np.ones_like(chirp_times[0]) * 7, '|', markersize=10, color=Wc)
            ax[0].plot(chirp_times[1] / 3600, np.ones_like(chirp_times[1]) * 8, '|', markersize=10, color=Lc)
        ax[0].set_ylim(0, 9)
        ax[0].set_yticks([1, 2, 4, 5, 7, 8])
        ax[0].set_yticklabels(['contact', 'chase', r'rise$_{win}$', r'rise$_{lose}$', r'chirp$_{win}$', r'chirp$_{lose}$'])
        fig.suptitle(f'{recording}')
        plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'example_trials', f'{recording}.png'), dpi=300)
        # plt.savefig(os.path.join(os.path.join(os.path.split(__file__)[0], 'figures', f'{recording}.png')), dpi=300)
        plt.close()
    # infer sex from Q10-corrected EODf: below 730 Hz -> female, else male
    for g in pd.unique(trial_summary['group']):
        fish_no = np.unique(np.concatenate((trial_summary['win_fish'][trial_summary['group'] == g],
                                            trial_summary['lose_fish'][trial_summary['group'] == g])))
        for f in fish_no:
            fish_EODf25 = np.concatenate((trial_summary['EODf_lose'][(trial_summary['group'] == g) & (trial_summary['lose_fish'] == f)],
                                          trial_summary['EODf_win'][(trial_summary['group'] == g) & (trial_summary['win_fish'] == f)]))
            if np.nanmedian(fish_EODf25) < 730:
                sex = 'f'
            else:
                sex = 'm'
            # NOTE(review): chained assignment may not write through on newer
            # pandas versions -- prefer .loc; verify before upgrading pandas
            trial_summary['sex_win'][(trial_summary['group'] == g) & (trial_summary['win_fish'] == f)] = sex
            trial_summary['sex_lose'][(trial_summary['group'] == g) & (trial_summary['lose_fish'] == f)] = sex
    trial_summary.to_csv(os.path.join(base_path, 'trial_summary.csv'))
    pass
if __name__ == '__main__':
    # NOTE(review): data path is hard-coded -- consider an argparse argument
    # main("/home/raab/data/mount_data/")
    main("/home/raab/data/2020_competition_mount")

450
code/ethogram.py Normal file
View File

@@ -0,0 +1,450 @@
import os
import sys
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import pandas as pd
import scipy.stats as scp
import networkx as nx
from thunderfish.powerspectrum import decibel
from IPython import embed
from event_time_correlations import load_and_converete_boris_events
# shared categorical palette used by the plotting helpers below; presumably the
# trailing 'k' (black) is reserved for the 'void' state -- TODO confirm
glob_colors = ['#BA2D22', '#53379B', '#F47F17', '#3673A4', '#AAB71B', '#DC143C', '#1E90FF', 'k']
def plot_transition_matrix(matrix, labels):
    """Render the behavior transition-count matrix as a labeled heatmap and
    save it to figures/markov/event_counts.png."""
    fig = plt.figure(figsize=(20/2.54, 20/2.54))
    gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.925, top=0.95)
    ax = fig.add_subplot(gs[0, 0])
    # colorbar axes attached to the right of the heatmap
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    heat = ax.imshow(matrix)
    tick_positions = list(range(len(matrix)))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(labels, rotation=45)
    ax.set_yticklabels(labels)
    fig.colorbar(heat, cax=cax, orientation='vertical')
    for axes in (ax, cax):
        axes.tick_params(labelsize=10)
    plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'event_counts' + '.png'), dpi=300)
    plt.close()
def plot_transition_diagram(matrix, labels, node_size, ax, threshold=5,
                            color_by_origin=False, color_by_target=False, title=''):
    """Draw a directed behavioral-transition graph onto *ax*.

    Parameters
    ----------
    matrix: 2D array of transition weights; entries <= threshold are dropped.
        The array is copied, so the caller's data stays untouched.
    labels: node names, one per matrix row/column.
    node_size: sequence of node sizes; also selects the node colors.
    ax: matplotlib axes to draw into.
    threshold: minimum weight for an edge to be drawn.
    color_by_origin / color_by_target: color edges by source / target node.
    title: axes title.
    """
    # copy first -- the previous version zeroed the caller's matrix in place
    matrix = np.array(matrix, dtype=float)
    matrix[matrix <= threshold] = 0
    matrix = np.around(matrix, decimals=1)
    Graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
    node_labels = dict(zip(Graph, labels))
    edge_labels = nx.get_edge_attributes(Graph, 'weight')
    positions = nx.circular_layout(Graph)
    positions2 = nx.circular_layout(Graph)
    for p in positions:
        # push the text labels slightly outward from the node circle
        positions2[p][0] *= 1.2
        positions2[p][1] *= 1.2
    nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size, ax=ax, alpha=0.5, node_color=np.array(glob_colors)[:len(node_size)])
    nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels, ax=ax)
    # edge width scales with weight, capped at 6
    edge_width = np.array([x / 5 for x in [*edge_labels.values()]])
    if color_by_origin:
        edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]]
    elif color_by_target:
        edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 1]]
    else:
        edge_colors = 'k'
    edge_width[edge_width >= 6] = 6
    nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
                           arrows=True, arrowsize=20,
                           min_target_margin=25, min_source_margin=25, connectionstyle="arc3, rad=0.025",
                           ax=ax, edge_color=edge_colors)
    nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=True)
    for side in ("top", "bottom", "right", "left"):
        ax.spines[side].set_visible(False)
    ax.set_xlim(-1.3, 1.3)
    ax.set_ylim(-1.3, 1.3)
    ax.set_title(title, fontsize=12)
def create_marcov_matrix(individual_event_times, individual_event_labels):
    """Count transitions between behavioral events that follow each other
    within 5 seconds.

    Parameters
    ----------
    individual_event_times: list of per-behavior arrays of event times (s).
    individual_event_labels: list of behavior names aligned with the times.
        NOTE(review): mutated in place -- 'void' is appended as an extra label.

    Returns
    -------
    (n_labels+1) x (n_labels+1) array; entry [i, j] counts how often behavior j
    followed behavior i within 5 s. The extra last row/column represents the
    'void' state (no event within the preceding 5 s).
    """
    # pool all events into one time-sorted sequence of (time, label)
    event_times = []
    event_labels = []
    for ll, t in zip(individual_event_labels, individual_event_times):
        event_times.extend(t)
        event_labels.extend(np.full(len(t), ll))
    time_sorter = np.argsort(event_times)
    event_times = np.array(event_times)[time_sorter]
    event_labels = np.array(event_labels)[time_sorter]
    marcov_matrix = np.zeros((len(individual_event_labels) + 1, len(individual_event_labels) + 1))
    # transitions between consecutive events no more than 5 s apart
    for enu_ori, label_ori in enumerate(individual_event_labels):
        for enu_tar, label_tar in enumerate(individual_event_labels):
            n = len(event_times[:-1][(event_labels[:-1] == label_ori) & (event_labels[1:] == label_tar) & (
                    np.diff(event_times) <= 5)])
            marcov_matrix[enu_ori, enu_tar] = n
    # events whose predecessor is more than 5 s away originate from 'void'
    for enu_tar, label_tar in enumerate(individual_event_labels):
        n = len(event_times[:-1][(event_labels[1:] == label_tar) & (np.diff(event_times) > 5)])
        marcov_matrix[-1, enu_tar] = n
    # NOTE(review): indices 4 and 5 are hard-coded as chase-on / chase-off --
    # the labels argument must keep exactly that ordering; confirm callers
    marcov_matrix[-1, 5] = 0
    individual_event_labels.append('void')
    ### add chase_on -> chase_off transitions for chases where neither endpoint
    ### linked to another event within 5 s (they would otherwise be lost)
    chase_on_idx = np.where(event_labels == individual_event_labels[4])[0]
    chase_off_idx = np.where(event_labels == individual_event_labels[5])[0]
    helper_mask = np.ones_like(chase_on_idx)
    # NOTE(review): np.diff(event_times) is one element shorter than
    # event_times -- if the last event is a chase-on this indexing raises
    # IndexError; presumably every chase-on has a later chase-off -- verify
    helper_mask[np.diff(event_times)[chase_on_idx] <= 5] = 0
    helper_mask[np.diff(event_times)[chase_off_idx - 1] <= 5] = 0
    marcov_matrix[4, 5] += np.sum(helper_mask)
    return marcov_matrix
def fine_spec_plot(ax, example_1_path, trial_summary, example_ag_on_off):
    """Draw two fine-spectrogram snippets of the losing fish's frequency band
    into *ax*: one around the onset and one around the offset of the first
    example chase episode, separated by a white gap.

    Parameters
    ----------
    ax: matplotlib axes to draw into.
    example_1_path: folder of the example recording (wavetracker output).
    trial_summary: DataFrame with one row per trial (provides 'lose_ID').
    example_ag_on_off: list of (chase_on, chase_off) time pairs in seconds;
        only the first pair is used.
    """
    # summary row / loser identity of the example recording
    ex1_df_idx = trial_summary[trial_summary['recording'] == os.path.split(example_1_path)[-1]].index.to_numpy()[0]
    lose_id = trial_summary.iloc[ex1_df_idx]['lose_ID']
    # memory-mapped fine spectrogram, column-major on disk (order='F')
    fine_spec_shape = np.load(os.path.join(example_1_path, 'fill_spec_shape.npy'))
    fine_spec = np.memmap(os.path.join(example_1_path, 'fill_spec.npy'), dtype='float', mode='r', shape=(fine_spec_shape[0], fine_spec_shape[1]), order='F')
    fine_times = np.load(os.path.join(example_1_path, 'fill_times.npy'))
    spec_freqs = np.load(os.path.join(example_1_path, 'fill_freqs.npy'))
    times = np.load(os.path.join(example_1_path, 'times.npy'))
    fund_v = np.load(os.path.join(example_1_path, 'fund_v.npy'))
    ident_v = np.load(os.path.join(example_1_path, 'ident_v.npy'))
    idx_v = np.load(os.path.join(example_1_path, 'idx_v.npy'))
    # artificial_t_axis = np.linspace(times[0], times[-1], spec.shape[1])
    # artificial_f_axis = np.linspace(0, 2000, spec.shape[0])
    # plt.pcolormesh(artificial_t_axis, artificial_f_axis, decibel(spec), vmin=-100, vmax=-50)
    # frequency band the loser occupies around the chase episode (+/- 5 s, 25 Hz margin)
    lose_freq_in_snippet = fund_v[(ident_v == lose_id) & (times[idx_v] > example_ag_on_off[0][0]-5) & (times[idx_v] < example_ag_on_off[0][1]+5)]
    max_f, min_f = np.max(lose_freq_in_snippet) + 25, np.min(lose_freq_in_snippet) - 25
    f_idx0 = np.where(spec_freqs <= min_f)[0][-1]
    f_idx1 = np.where(spec_freqs >= max_f)[0][0]
    # snippet 1: around chase onset, plotted relative to the onset time
    t_idx0 = np.where(fine_times <= example_ag_on_off[0][0] - 5)[0][-1]
    t_idx1 = np.where(fine_times >= example_ag_on_off[0][0] + 4)[0][0]
    ax.pcolormesh(fine_times[t_idx0:t_idx1+1] - example_ag_on_off[0][0], spec_freqs[f_idx0:f_idx1+1],
                  decibel(fine_spec[f_idx0:f_idx1+1, t_idx0:t_idx1+1]))
    # snippet 2: around chase offset, shifted by +10 s so it sits to the right
    t_idx0 = np.where(fine_times <= example_ag_on_off[0][1] - 5)[0][-1]
    t_idx1 = np.where(fine_times >= example_ag_on_off[0][1] + 5)[0][0]
    ax.pcolormesh(fine_times[t_idx0:t_idx1+1] - example_ag_on_off[0][1] + 10, spec_freqs[f_idx0:f_idx1+1],
                  decibel(fine_spec[f_idx0:f_idx1+1, t_idx0:t_idx1+1]))
    # white separator between the two snippets
    ax.fill_between([4, 5], [spec_freqs[f_idx0], spec_freqs[f_idx0]], [spec_freqs[f_idx1], spec_freqs[f_idx1]], color='white')
def main(base_path):
    """Markov-style transition analysis of dyadic behaviors across trials.

    For every valid trial (good chirp detection, group >= 5, BORIS events and
    CNN chirp detections present, no draw) the function collects the times of
    loser/winner chirps and rises, chase on/offsets and physical contacts,
    builds a per-trial transition ("marcov") matrix, classifies each chase
    into one of four agonistic categories (rise before chase x chirp around
    chase end), picks one illustrative example per category, and saves
    category and transition figures to ``figures/markov``.

    Parameters
    ----------
    base_path : str
        Folder containing one sub-folder per recording plus the summary
        tables 'trial_summary.csv' and 'chirp_notes.csv'.
    """
    # make sure the output folder for all figures exists
    if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')):
        os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'markov'))

    trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
    chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
    # trial_summary = trial_summary[chirp_notes['good'] == 1]
    # only keep trials whose chirp detection was manually validated
    trial_mask = chirp_notes['good'] == 1

    all_marcov_matrix = []        # one transition matrix per valid trial
    all_event_counts = []         # per-trial event counts (one per event type + total transitions)
    all_agonistic_categorie = []  # per-trial array of chase categories (1-4)

    # agonistic categorie plot
    # fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
    # gs = gridspec.GridSpec(2, 1, left=0.1, bottom=0.1, right=0.9, top=0.95, height_ratios=[1, 4], hspace=0)
    # ax = fig.add_subplot(gs[1, 0])
    # ax_spec = fig.add_subplot(gs[0, 0], sharex=ax)
    # plt.setp(ax_spec.get_xticklabels(), visible=False)
    #
    # for i in range(1, 5):
    #     ax.fill_between([0, 4], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey')
    #     ax.fill_between([5, 10], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey')
    #
    #     fill_dots = np.arange(4, 5.1, 0.125)
    #     ax.plot(fill_dots, np.ones_like(fill_dots)*i, '.', color='tab:grey', markersize=3)

    # bookkeeping for picking one example chase per category 1-3 for the figure
    got_examples = [False, False, False]  # example already found for category 1/2/3?
    example_ag_on_off = [[], [], []]      # chase on/off times of the chosen examples
    example_chirp_times = [[], [], []]    # chirp times belonging to the chosen examples
    example_rise_times = [[], [], []]     # rise times belonging to the chosen examples
    example_1_path = ''                   # recording path of the category-1 example (used for the spectrogram)
    example_skips = [15, 4, 3] #3, 5, 9, 15, 19  # skip this many candidate chases per category before accepting one

    for index, trial in trial_summary.iterrows():
        trial_path = os.path.join(base_path, trial['recording'])

        # skip trials without validated chirps, early groups, missing BORIS/LED data, or draws
        if not trial_mask[index]:
            continue
        if trial['group'] < 5:
            continue
        if not os.path.exists(os.path.join(trial_path, 'led_idxs.csv')):
            continue
        if not os.path.exists(os.path.join(trial_path, 'LED_frames.npy')):
            continue
        if trial['draw'] == 1:
            continue

        ids = np.load(os.path.join(trial_path, 'analysis', 'ids.npy'))
        times = np.load(os.path.join(trial_path, 'times.npy'))
        # reverse per-fish arrays when the stored order is not (winner, loser)
        sorter = -1 if trial['win_ID'] != ids[0] else 1

        ### event times --> BORIS behavior
        contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames = \
            load_and_converete_boris_events(trial_path, trial['recording'], sr=20_000)

        ### communication
        if not os.path.exists(os.path.join(trial_path, 'chirp_times_cnn.npy')):
            continue

        chirp_t = np.load(os.path.join(trial_path, 'chirp_times_cnn.npy'))
        chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
        # chirp_times[0]: winner chirps, chirp_times[1]: loser chirps
        chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]]

        rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
        rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]
        # rise_times[0]: winner rises, rise_times[1]: loser rises
        rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]]

        # trial marcov matrix
        individual_event_times = [chirp_times[1], rise_times[1], chirp_times[0], rise_times[0], ag_on_off_t_GRID[:, 0],
                                  ag_on_off_t_GRID[:, 1], contact_t_GRID]
        individual_event_labels = [r'chirp$_{lose}$', r'rise$_{lose}$', r'chirp$_{win}$', r'rise$_{win}$',
                                   r'chace$_{on}$', r'chace$_{off}$', 'contact']
        marcov_matrix = create_marcov_matrix(individual_event_times, individual_event_labels)
        all_marcov_matrix.append(marcov_matrix)

        # compute and store trial event counts
        event_counts = np.array(list(map(lambda x: len(x), individual_event_times)))
        # last entry: sum of the matrix' last row — presumably total transitions; verify against create_marcov_matrix
        event_counts = np.append(event_counts, marcov_matrix[-1].sum())
        all_event_counts.append(event_counts)

        # agonistic categories
        # 1: rise before chase AND chirp around chase end; 2: only rise before;
        # 3: only chirp around end; 4: neither
        agonitic_categorie = np.zeros(len(ag_on_off_t_GRID))
        for enu, (chase_on_time, chase_off_time) in enumerate(ag_on_off_t_GRID):
            chase_dur = chase_off_time - chase_on_time
            # window before chase end in which chirps count (capped at 5 s)
            chirp_dt = chase_dur if chase_dur < 5 else 5
            max_dt = 5  # window length [s] before chase onset / after chase end

            # check if rise before chase / chirp at end
            rise_before, chirp_arround_end = False, False
            # loser rises within max_dt before chase onset?
            if np.any(((chase_on_time - rise_times[1]) > 0) & ((chase_on_time - rise_times[1]) < max_dt)):
                rise_times_oi = rise_times[1][((chase_on_time - rise_times[1]) > 0) & ((chase_on_time - rise_times[1]) < max_dt)]
                rise_before = True

            # loser chirps within chirp_dt before chase end or max_dt after?
            if np.any( ((chase_off_time - chirp_times[1]) < chirp_dt) & ((chirp_times[1] - chase_off_time) < max_dt)):
                # NOTE(review): the selection below uses chase_dur while the test above uses chirp_dt —
                # for chases longer than 5 s these differ; confirm which bound is intended
                chirp_time_oi = chirp_times[1][((chase_off_time - chirp_times[1]) < chase_dur) & ((chirp_times[1] - chase_off_time) < max_dt)]
                chirp_arround_end = True

            # define agonistic categorie based on rise/chirp occurance
            if rise_before:
                if chirp_arround_end:
                    agonitic_categorie[enu] = 1
                else:
                    agonitic_categorie[enu] = 2
            else:
                if chirp_arround_end:
                    agonitic_categorie[enu] = 3
                else:
                    agonitic_categorie[enu] = 4

            # pick one example chase per category (only categories 1-3; category 4 has no events to show);
            # candidates must last > 10 s, and example_skips counts down before one is accepted
            if agonitic_categorie[enu] == 1 and not got_examples[0]:
                if chase_dur > 10:
                    # require chirps both before and after chase end
                    if np.any((chirp_time_oi - chase_off_time) < 0) and np.any((chirp_time_oi - chase_off_time) > 0):
                        if example_skips[int(agonitic_categorie[enu] - 1)] == 0:
                            example_ag_on_off[int(agonitic_categorie[enu] - 1)].extend([chase_on_time, chase_off_time])
                            example_chirp_times[int(agonitic_categorie[enu] - 1)].extend(chirp_time_oi)
                            example_rise_times[int(agonitic_categorie[enu] - 1)].extend(rise_times_oi)
                            example_1_path = trial_path
                            got_examples[0] = True
                        else:
                            example_skips[int(agonitic_categorie[enu] - 1)] -= 1
            elif agonitic_categorie[enu] == 2 and not got_examples[1]:
                if chase_dur > 10:
                    if example_skips[int(agonitic_categorie[enu] - 1)] == 0:
                        example_ag_on_off[int(agonitic_categorie[enu] - 1)].extend([chase_on_time, chase_off_time])
                        example_rise_times[int(agonitic_categorie[enu] - 1)].extend(rise_times_oi)
                        got_examples[1] = True
                    else:
                        example_skips[int(agonitic_categorie[enu] - 1)] -= 1
            elif agonitic_categorie[enu] == 3 and not got_examples[2]:
                if chase_dur > 10:
                    if np.any((chirp_time_oi - chase_off_time) < 0) and np.any((chirp_time_oi - chase_off_time) > 0):
                        if example_skips[int(agonitic_categorie[enu] - 1)] == 0:
                            example_ag_on_off[int(agonitic_categorie[enu] - 1)].extend([chase_on_time, chase_off_time])
                            example_chirp_times[int(agonitic_categorie[enu] - 1)].extend(chirp_time_oi)
                            got_examples[2] = True
                        else:
                            example_skips[int(agonitic_categorie[enu] - 1)] -= 1
            else:
                pass
        all_agonistic_categorie.append(agonitic_categorie)

    ### agonistic categorie example figure
    # x-axis convention: chase onset at 0 s, chase end displayed at 10 s (gap 4-5 s drawn dotted)
    fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
    gs = gridspec.GridSpec(2, 1, left=0.1, bottom=0.1, right=0.9, top=0.9, height_ratios=[1, 4], hspace=0)
    ax = fig.add_subplot(gs[1, 0])
    ax_spec = fig.add_subplot(gs[0, 0], sharex=ax)
    plt.setp(ax_spec.get_xticklabels(), visible=False)

    # grey bars: schematic chase period for each category row
    for i in range(1, 5):
        ax.fill_between([0, 4], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey')
        ax.fill_between([5, 10], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey')

        fill_dots = np.arange(4, 5.1, 0.125)
        ax.plot(fill_dots, np.ones_like(fill_dots)*i, '.', color='tab:grey', markersize=3)

    # draw the example events; assumes an example was found for each of categories 1-3
    # (ag_on_off unpack fails otherwise — TODO confirm examples always exist)
    for enu, (chirp_time_oi, rise_times_oi, ag_on_off) in enumerate(zip(example_chirp_times, example_rise_times, example_ag_on_off)):
        chase_on_time, chase_off_time = ag_on_off
        for ct in chirp_time_oi:
            # chirps are plotted relative to chase end (displayed at 10 s)
            ax.plot([ct - chase_off_time + 10, ct - chase_off_time + 10], [enu + .8, enu + 1.2], color='k', lw=2)
        for rt in rise_times_oi:
            # rises are plotted relative to chase onset (displayed at 0 s)
            ax.plot([rt - chase_on_time, rt - chase_on_time], [enu + .8, enu + 1.2], color='firebrick', lw=2)

    # fraction of all chases in each category, annotated right of the rows
    stacked_agonistic_categories = np.hstack(all_agonistic_categorie)
    pct_each_categorie = np.zeros(4)
    for enu, cat in enumerate(range(1, 5)):
        pct_each_categorie[enu] = len(stacked_agonistic_categories[stacked_agonistic_categories == cat]) / len(stacked_agonistic_categories)
        ax.text(15.2, enu + 1, f'{pct_each_categorie[enu] * 100:.1f}' + ' $\%$', clip_on=False, fontsize=14, ha='left', va='center')

    # plot correct spectrogram
    fine_spec_plot(ax_spec, example_1_path, trial_summary, example_ag_on_off)
    ##########################################
    ax.plot([0, 0], [0.5, 5], '--', color='k', lw=1)
    ax.plot([10, 10], [0.5, 5], '--', color='k', lw=1)
    ax.set_ylim(0.5, 4.5)
    ax.set_xlim(-5, 15)
    ax.set_yticks([1, 2, 3, 4])
    # ax.set_yticklabels([r'rise$_{pre}$ $&$ chirp$_{end}$', r'only rise$_{pre}$', r'only chirp$_{end}$', 'no communication'])
    ax.set_yticklabels(['A ', 'B ', 'C ', 'D '])
    ax.invert_yaxis()
    ax.set_xlabel('time [s]', fontsize=12)
    ax.tick_params(axis='y', labelsize=20)
    ax.tick_params(axis = 'x', labelsize=10)

    legend_elements = [Line2D([0], [0], color='firebrick', lw=2, label=r'rise$_{lose}$'),
                       Line2D([0], [0], color='k', lw=2, label=r'chirp$_{lose}$'),
                       Patch(facecolor='tab:grey', edgecolor='w', label= 'chase event')]
    ax_spec.legend(handles=legend_elements, loc='lower right', ncol=3, bbox_to_anchor=(1, 1), frameon=False, fontsize=10, facecolor='white')
    ax_spec.set_ylabel('EODf [Hz]', fontsize=12)
    ax.spines[['right', 'top']].set_visible(False)
    plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'agonistic_categories' + '.png'), dpi=300)
    plt.show()

    ### bar plot - agonistic categories counts/pct #####################################################################
    # absolute chase counts per category (pooled over trials)
    fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54))
    ax.bar(np.arange(4),
           [len(stacked_agonistic_categories[stacked_agonistic_categories == 1]),
            len(stacked_agonistic_categories[stacked_agonistic_categories == 2]),
            len(stacked_agonistic_categories[stacked_agonistic_categories == 3]),
            len(stacked_agonistic_categories[stacked_agonistic_categories == 4])])
    ax.set_xticks(np.arange(4))
    ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _'])
    plt.show()

    # pct
    # per-trial category fractions (mean +/- std across trials)
    pct_agon_categorie = np.zeros(shape=(len(all_agonistic_categorie), 4))
    for enu, agonitic_categorie in enumerate(all_agonistic_categorie):
        for cat in np.arange(4):
            pct_agon_categorie[enu, cat] = len(agonitic_categorie[agonitic_categorie == cat+1]) / len(agonitic_categorie)

    fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54))
    ax.bar(np.arange(4), pct_agon_categorie.mean(0))
    ax.errorbar(np.arange(4), pct_agon_categorie.mean(0), yerr=pct_agon_categorie.std(0), fmt='', color='k', linestyle='None')
    ax.set_xticks(np.arange(4))
    ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _'])
    plt.show()

    ### marcov models plots ############################################################################################
    all_marcov_matrix = np.array(all_marcov_matrix)
    all_event_counts = np.array(all_event_counts)

    # pooled transition matrix and event counts over all trials
    # NOTE: individual_event_labels stems from the last loop iteration; requires >= 1 valid trial
    collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0)
    collective_event_counts = np.sum(all_event_counts, axis=0)

    plot_transition_matrix(collective_marcov_matrix, individual_event_labels)

    # transitions normalized per origin event (row-wise)
    fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
    fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
    plot_transition_diagram(
        collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100,
        individual_event_labels, collective_event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]')
    plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_destination' + '.png'), dpi=300)
    plt.close()

    # transitions normalized per target event (column-wise broadcast)
    fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
    fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
    plot_transition_diagram(collective_marcov_matrix / collective_event_counts * 100,
                            individual_event_labels, collective_event_counts, ax, threshold=5, color_by_target=True,
                            title='target triggered by origin [%]')
    plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_origin' + '.png'), dpi=300)
    plt.close()

    # same two diagrams, once per individual trial
    for i, (marcov_matrix, event_counts) in enumerate(zip(all_marcov_matrix, all_event_counts)):
        fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
        fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
        plot_transition_diagram(
            marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100,
            individual_event_labels, event_counts, ax, threshold=5, color_by_origin=True,
            title='origin triggers target [%]')
        plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_destination' + '.png'),
                    dpi=300)
        plt.close()

        fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
        fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
        plot_transition_diagram(marcov_matrix / event_counts * 100,
                                individual_event_labels, event_counts, ax, threshold=5, color_by_target=True,
                                title='target triggered by origin [%]')
        plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_origin' + '.png'),
                    dpi=300)
        plt.close()
    ####################################################################################################################
    # drop into IPython for interactive inspection, then exit
    embed()
    quit()
    pass
# command-line entry point: first argument is the base data path
if __name__ == '__main__':
    main(sys.argv[1])

59
code/eval_LED.py Normal file
View File

@@ -0,0 +1,59 @@
import pandas as pd
import numpy as np
import sys
import os
import matplotlib.pyplot as plt
from IPython import embed
import cv2
import glob
def main(folder):
    """Map video frames onto the electrode-grid time axis via shared LED events.

    Uses the LED onset indices detected in the grid recording ('led_idxs.csv')
    and the LED onset frames detected in the video ('LED_frames.npy') to fit a
    linear frame -> sample mapping, converts every video frame index to a time
    in seconds on the grid clock, saves it as 'analysis/frame_times.npy', and
    shows two diagnostic plots of the alignment.

    Parameters
    ----------
    folder : str
        Recording folder containing the video ('2022*.mp4'), 'times.npy',
        'led_idxs.csv', 'LED_frames.npy' and 'LED_val.npy'.
    """
    sr = 20000  # grid sampling rate [Hz] — TODO confirm it matches the recording
    # first (assumed only) video of the recording; filename pattern pins year 2022
    video_path = glob.glob(os.path.join(folder, '2022*.mp4'))[0]
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    times = np.load(os.path.join(folder, 'times.npy'))
    # NOTE(review): encoding "utf-7" looks like a typo for "utf-8" — it happens to
    # decode plain ASCII csv content, but confirm how led_idxs.csv was written
    LED_idx = pd.read_csv(os.path.join(folder, 'led_idxs.csv'), sep=',', encoding = "utf-7")
    led_idx = np.array(LED_idx).T[0]  # LED onsets as grid sample indices (first column)
    led_frame = np.load(os.path.join(folder, 'LED_frames.npy'))  # LED onsets as video frame numbers
    led_vals = np.load(os.path.join(folder, 'LED_val.npy'))      # per-frame LED brightness trace

    # linear mapping between LED events in video frames and grid samples,
    # anchored at the first and last detected LED event
    led_idx_span = led_idx[-1] - led_idx[0]
    led_frame_span = led_frame[-1] - led_frame[0]
    led_frame_to_idx = ((led_frame-led_frame[0]) / led_frame_span) * led_idx_span + led_idx[0]

    frame_idxs = np.arange(frame_count)
    # time [s] of every video frame on the grid clock
    frame_times = (((frame_idxs - led_frame[0]) / led_frame_span) * led_idx_span + led_idx[0]) / sr

    if not os.path.exists(os.path.join(folder, 'analysis')):
        os.mkdir(os.path.join(folder, 'analysis'))
    np.save(os.path.join(folder, 'analysis', 'frame_times.npy'), frame_times)
    ########################################################################################
    # diagnostic 1: LED brightness trace with detected LED frames overlaid
    fig, ax = plt.subplots()
    ax.plot(led_vals)
    ax.plot(led_frame, np.ones(len(led_frame))*100, '.', color='firebrick')
    ########################################################################################
    # diagnostic 2: grid LED times vs. mapped video LED times and the full frame-time axis
    fig, ax = plt.subplots()
    ax.plot(led_idx / sr, np.ones(len(led_idx)), '.', color='k')
    ax.plot(led_frame_to_idx / sr, np.ones(len(led_frame_to_idx))+.1, '.', color='firebrick')
    ax.plot([times[0], times[0]], [0.5, 1.5], 'k', lw=1)
    ax.plot([times[-1], times[-1]], [0.5, 1.5], 'k', lw=1)
    ax.plot(frame_times, np.ones(len(frame_times))*0.5)
    ax.set_ylim(0, 2)
    plt.show()
    # drop into IPython for interactive inspection, then exit
    embed()
    quit()
    pass
# command-line entry point: first argument is the recording folder
if __name__ == '__main__':
    main(sys.argv[1])

807
code/event_time_analysis.py Normal file
View File

@@ -0,0 +1,807 @@
import os
import sys
import itertools
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import pandas as pd
import scipy.stats as scp
from IPython import embed
from event_time_correlations import load_and_converete_boris_events, kde, gauss
female_color, male_color = '#e74c3c', '#3498db'
def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''):
    """Analyse inter-event intervals (IEIs) per trial and save a summary figure.

    For each trial, the diffs of the event times give the IEIs; events within
    the first 3 h additionally yield a duration-weighted mean and a median IEI.
    A KDE of each trial's IEIs (colored/styled by winner/loser sex) and a
    boxplot of the per-trial summary statistics are saved to
    ``figures/event_meta/<title>_iei.png``.

    Parameters
    ----------
    event_times : list of 1-D arrays
        Event times [s] per trial, presumably sorted ascending — TODO confirm.
    win_sex, lose_sex : sequences of str
        Sex ('m'/'f') of winner and loser per trial, same order as event_times.
    kernal_w : float
        Kernel width passed through to the project `kde` helper.
    title : str
        Figure title, also used in the output file name.

    Returns
    -------
    list of 1-D arrays
        The per-trial inter-event intervals.
    """
    iei = []
    weighted_mean_iei = []
    median_iei = []
    for i in range(len(event_times)):
        # intervals of events within the first 3 h ("night") only
        night_iei = np.diff(event_times[i][event_times[i] <= 3600*3])
        iei.append(np.diff(event_times[i]))
        if len(night_iei) == 0:
            weighted_mean_iei.append(np.nan)
            median_iei.append(np.nan)
        else:
            # mean IEI weighted by interval duration (long intervals count more)
            weighted_mean_iei.append(np.sum((night_iei) * night_iei) / np.sum(night_iei))
            median_iei.append(np.median(night_iei))
    weighted_mean_iei = np.array(weighted_mean_iei)
    median_iei = np.array(median_iei)

    fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
    gs = gridspec.GridSpec(1, 2, left=0.1, bottom=0.2, right=0.95, top=0.9, width_ratios=[5, 1], wspace=.3)
    ax = []
    ax.append(fig.add_subplot(gs[0, 0]))
    ax.append(fig.add_subplot(gs[0, 1]))
    for i in range(len(iei)):
        # color by winner sex, linestyle solid for same-sex pairings, dashed for mixed
        if win_sex[i] == 'm':
            if lose_sex[i] == 'm':
                color, linestyle = male_color, '-'
                sp = 0  # NOTE(review): sp is assigned but never used
            else:
                color, linestyle = male_color, '--'
                sp = 1
        else:
            if lose_sex[i] == 'm':
                color, linestyle = female_color, '--'
                sp = 2
            else:
                color, linestyle = female_color, '-'
                sp = 3
        # KDE support: 0 up to the 80th percentile of all pooled IEIs
        # NOTE(review): conv_y is loop-invariant and could be computed once before the loop
        conv_y = np.arange(0, np.percentile(np.hstack(iei), 80), .5)
        kde_array = kde(iei[i], conv_y, kernal_w=kernal_w, kernal_h=1)
        # kde_array /= np.sum(kde_array)
        ax[0].plot(conv_y, kde_array, zorder=2, color=color, linestyle=linestyle, lw=2)

    # per-trial summary statistics (NaN trials excluded)
    ax[1].boxplot([weighted_mean_iei[~np.isnan(weighted_mean_iei)],
                   median_iei[~np.isnan(median_iei)]], positions=[0, 1], sym='', widths=0.5)
    ax[0].set_xlim(conv_y[0], conv_y[-1])
    ax[0].set_ylabel('KDE', fontsize=12)
    ax[0].set_xlabel('inter event interval [s]', fontsize=12)
    fig.suptitle(title, fontsize=12)
    for a in ax:
        a.tick_params(labelsize=10)
    ax[1].set_xticks(np.arange(2))
    ax[1].set_xticklabels([r'weighted$_{time}$', 'median'], rotation=45)
    ax[1].set_ylabel('inter event interval [s]', fontsize=12)
    # ax[0]
    # plt.setp(ax[1].get_yticklabels(), visible=False)
    # plt.setp(ax[3].get_yticklabels(), visible=False)
    #
    # plt.setp(ax[0].get_xticklabels(), visible=False)
    # plt.setp(ax[1].get_xticklabels(), visible=False)
    plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta', f'{title}_iei.png'), dpi=300)
    plt.close()

    # all_r = []
    # all_p = []
    # for lag in np.arange(1, 6):
    #     fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
    #     gs = gridspec.GridSpec(1, 1, left=.1, bottom=.1, right=0.95, top=0.95)
    #     ax = fig.add_subplot(gs[0, 0])
    #     plot_x = []
    #     plot_y = []
    #     for trial_iei in iei:
    #         plot_x.extend(trial_iei[:-lag])
    #         plot_y.extend(trial_iei[lag:])
    #     ax.plot(plot_x, plot_y, '.', color='k')
    #     r, p = scp.pearsonr(plot_x, plot_y)
    #     all_r.append(r)
    #     all_p.append(p)
    #
    #     ax.set_xlim(-1, 120)
    #     ax.set_ylim(-1, 120)
    #     plt.show()
    # plt.show()
    return iei
def relative_rate_progression(all_event_t, title=''):
    """Plot how each trial's event rate develops over the first 3 hours.

    The 3 h analysis window is divided into 15 min snippets. For each trial the
    event count per snippet is normalized by the count expected under a uniform
    rate, giving a relative rate per snippet. All trials plus their median are
    plotted and saved to ``figures/event_meta/<title>_progression.png``, and a
    Pearson correlation between snippet time and relative rate is printed.

    Parameters
    ----------
    all_event_t : list of 1-D arrays
        Event times [s] per trial. Trials without events inside the 3 h window
        are skipped.
    title : str
        Plot title, also used in the output file name.
    """
    stop_t = 3*60*60       # analysis window: first 3 h of each trial [s]
    snippet_len = 15*60    # bin width: 15 min [s]
    snippet_starts = np.arange(0, stop_t, snippet_len)

    all_snippet_ratio = []
    for event_t in all_event_t:
        if len(event_t) == 0:
            continue
        # events per snippet expected under a uniform rate within the window
        expected_snippet_count = len(event_t[event_t <= stop_t]) / (stop_t / snippet_len)
        if expected_snippet_count == 0:
            # trial has events, but none inside the 3 h window: the normalization
            # is undefined (would divide by zero) -> skip this trial
            continue
        snippet_ratio = []
        for s0 in snippet_starts:
            snippet_count = len(event_t[(event_t >= s0) & (event_t < s0 + snippet_len)])
            snippet_ratio.append(snippet_count/expected_snippet_count)
        all_snippet_ratio.append(snippet_ratio)
    all_snippet_ratio = np.array(all_snippet_ratio)

    fig = plt.figure(figsize=(20/2.54, 12/2.54))
    gs = gridspec.GridSpec(1, 1, left=.1, bottom=.1, right=0.95, top=0.95)
    ax = fig.add_subplot(gs[0, 0])
    # step-style x axis: each snippet start/end duplicated
    plot_t = np.repeat(snippet_starts, 2)
    plot_t[1::2] += snippet_len
    for event_ratios in all_snippet_ratio:
        plot_ratios = np.repeat(event_ratios, 2)
        ax.plot(plot_t / 3600, plot_ratios, color='grey', lw=1, alpha=0.5)
        # ax.plot(snippet_starts + snippet_len/2, event_ratios)
    # median across trials as the bold central tendency
    mean_ratio = np.median(all_snippet_ratio, axis=0)
    plot_mean_ratio = np.repeat(mean_ratio, 2)
    ax.plot(plot_t / 3600, plot_mean_ratio, color='k', lw=3)
    # dotted line at 1: the uniform-rate expectation
    ax.plot(plot_t / 3600, np.ones_like(plot_t), linestyle='dotted', lw=2, color='k')
    ax.set_xlabel('time [h]', fontsize=12)
    ax.set_ylabel('norm. event rate', fontsize=12)
    ax.set_title(title)
    ax.tick_params(labelsize=10)
    ax.set_xlim(0, 3)
    ax.set_ylim(0, 5)
    plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta', f'{title}_progression.png'), dpi=300)
    plt.close()
    # plt.show()

    # does the relative rate change systematically with time?
    x = np.hstack(all_snippet_ratio)
    y = np.hstack(np.tile(snippet_starts, (all_snippet_ratio.shape[0], 1)))
    r, p = scp.pearsonr(x, y)
    print(f'Progression {title}: pearson-r={r:.2f} p={p:.3f}')
def chase_time_progression(all_ag_on_t, all_ag_off_t):
    """Plot how chase duration develops over the first 3 hours of each trial.

    The 3 h window is divided into 15 min snippets; for each snippet the mean
    chase duration (chases starting in that snippet) is normalized by the
    trial's overall mean chase duration. All trials plus their nan-mean are
    plotted, a Pearson correlation of normalized duration vs. snippet time is
    printed, and the figure is saved to
    ``figures/event_meta/chase_duration_progression.png``.

    Parameters
    ----------
    all_ag_on_t, all_ag_off_t : lists of 1-D arrays
        Chase onset and offset times [s] per trial; paired element-wise —
        assumes equal lengths per trial, TODO confirm.
    """
    stop_t = 3*60*60       # analysis window: first 3 h [s]
    snippet_len = 15*60    # bin width: 15 min [s]
    snippet_starts = np.arange(0, stop_t, snippet_len)

    all_snippet_chase_dur = []
    for a_on, a_off in zip(all_ag_on_t, all_ag_off_t):
        if len(a_on) == 0:
            continue
        # trial-wide mean chase duration used for normalization
        mean_chase_dur = np.mean(a_off - a_on)
        snippet_chase_dur = []
        for s0 in snippet_starts:
            # chases whose onset falls into this snippet
            snippet_mask = (a_on > s0) & (a_on <= s0+snippet_len)
            if np.any(snippet_mask):
                snippet_chase_dur.append(np.mean(a_off[snippet_mask] - a_on[snippet_mask]))
            else:
                snippet_chase_dur.append(np.nan)  # no chases in this snippet
        all_snippet_chase_dur.append(np.array(snippet_chase_dur) / mean_chase_dur)
    all_snippet_chase_dur = np.array(all_snippet_chase_dur)

    fig = plt.figure(figsize=(20/2.54, 12/2.54))
    gs = gridspec.GridSpec(1, 1, left=.1, bottom=.1, right=0.95, top=0.95)
    ax = fig.add_subplot(gs[0, 0])
    # step-style x axis: each snippet start/end duplicated
    plot_t = np.repeat(snippet_starts, 2)
    plot_t[1::2] += snippet_len
    for trial_snippet_chase_dur in all_snippet_chase_dur:
        plot_ratios = np.repeat(trial_snippet_chase_dur, 2)
        ax.plot(plot_t / 3600, plot_ratios, color='grey', lw=1, alpha=0.5)
        # ax.plot(snippet_starts + snippet_len/2, event_ratios)
    # nan-aware mean across trials as the bold central tendency
    mean_ratio = np.nanmean(all_snippet_chase_dur, axis=0)
    plot_mean_ratio = np.repeat(mean_ratio, 2)
    ax.plot(plot_t / 3600, plot_mean_ratio, color='k', lw=3)
    # dotted line at 1: each trial's own mean chase duration
    ax.plot(plot_t / 3600, np.ones_like(plot_t), linestyle='dotted', lw=2, color='k')
    ax.set_xlabel('time [h]', fontsize=12)
    ax.set_ylabel('chase duration / mean(chase duration)', fontsize=12)
    ax.set_title('progression chase duration ')
    ax.tick_params(labelsize=10)
    ax.set_xlim(0, 3)
    ax.set_ylim(0, 5)

    # correlation of normalized duration with time (NaN snippets excluded)
    x = np.hstack(all_snippet_chase_dur)
    y = np.hstack(np.tile(snippet_starts, (all_snippet_chase_dur.shape[0], 1)))
    r, p = scp.pearsonr(x[~np.isnan(x)], y[~np.isnan(x)])
    print(f'Progression chase duration: pearson-r={r:.2f} p={p:.3f}')
    plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta', 'chase_duration_progression.png'), dpi=300)
    plt.close()
def event_category_signal(all_event_t, all_contact_t, all_ag_on_t, all_ag_off_t, win_sex, lose_sex, event_name):
    """Quantify how one event type distributes across behavioral contexts.

    For every trial with video annotations, each event (e.g. a loser chirp) is
    flagged as falling into one or more time contexts: 5 s before chase onset,
    during a chase, 5 s before chase end, 5 s after chase end, and 5 s
    before/after a physical contact. The per-trial event fraction in each
    context is compared against the fraction of trial time the context covers
    (paired t-tests printed), visualized as a boxplot with sex-coded per-trial
    markers plus a nested pie chart (outer ring: event fractions, inner ring:
    time fractions), and saved to
    ``figures/event_time_corr/<event_name>_categories.png``.

    Parameters
    ----------
    all_event_t : list of 1-D arrays
        Times [s] of the analysed event type per trial.
    all_contact_t, all_ag_on_t, all_ag_off_t : lists of 1-D arrays
        Contact times and chase on-/offset times per trial.
    win_sex, lose_sex : sequences of str
        Sex ('m'/'f') of winner and loser per trial.
    event_name : str
        Label used in titles/file names; must contain 'lose' or 'win'
        (anything else drops into an embed/quit error branch).
    """
    print('')
    # per-trial boolean masks (as float arrays) marking events in each context
    all_pre_chase_event_mask = []
    all_chase_event_mask = []
    all_end_chase_event_mask = []
    all_after_chase_event_mask = []
    all_before_contact_event_mask = []
    all_after_contact_event_mask = []

    # per-trial total time [s] covered by each context
    all_pre_chase_time = []
    all_chase_time = []
    all_end_chase_time = []
    all_after_chase_time = []
    all_before_contact_time = []
    all_after_contact_time = []

    # sexes of the trials that actually enter the analysis (have chases + events)
    video_trial_win_sex = []
    video_trial_lose_sex = []

    time_tol = 5  # context window length [s] around chase/contact events
    for enu, contact_t, ag_on_t, ag_off_t, event_times in zip(
            np.arange(len(all_contact_t)), all_contact_t, all_ag_on_t, all_ag_off_t, all_event_t):
        if len(ag_on_t) == 0:
            continue
        if len(event_times) == 0:
            continue
        pre_chase_event_mask = np.zeros_like(event_times)
        chase_event_mask = np.zeros_like(event_times)
        end_chase_event_mask = np.zeros_like(event_times)
        after_chase_event_mask = np.zeros_like(event_times)

        video_trial_win_sex.append(win_sex[enu])
        video_trial_lose_sex.append(lose_sex[enu])

        # flag events in the chase-related context windows
        for chase_on_t, chase_off_t in zip(ag_on_t, ag_off_t):
            pre_chase_event_mask[(event_times >= chase_on_t - time_tol) & (event_times < chase_on_t)] = 1
            # "during chase" excludes the final time_tol seconds (they count as "end")
            chase_event_mask[(event_times >= chase_on_t) & (event_times < chase_off_t - time_tol)] = 1
            end_chase_event_mask[(event_times >= chase_off_t - time_tol) & (event_times < chase_off_t)] = 1
            after_chase_event_mask[(event_times >= chase_off_t) & (event_times < chase_off_t + time_tol)] = 1

        all_pre_chase_event_mask.append(pre_chase_event_mask)
        all_chase_event_mask.append(chase_event_mask)
        all_end_chase_event_mask.append(end_chase_event_mask)
        all_after_chase_event_mask.append(after_chase_event_mask)

        # total time covered by each chase context in this trial
        all_pre_chase_time.append(len(ag_on_t) * time_tol)
        chasing_dur = (ag_off_t - ag_on_t) - time_tol  # chase time minus the "end" window
        chasing_dur[chasing_dur < 0] = 0               # short chases contribute no "during" time
        all_chase_time.append(np.sum(chasing_dur))
        all_end_chase_time.append(len(ag_on_t) * time_tol)
        all_after_chase_time.append(len(ag_on_t) * time_tol)

        # flag events around physical contacts
        before_countact_event_mask = np.zeros_like(event_times)
        after_countact_event_mask = np.zeros_like(event_times)
        for ct in contact_t:
            before_countact_event_mask[(event_times >= ct - time_tol) & (event_times < ct)] = 1
            after_countact_event_mask[(event_times >= ct) & (event_times < ct + time_tol)] = 1
        all_before_contact_event_mask.append(before_countact_event_mask)
        all_after_contact_event_mask.append(after_countact_event_mask)
        all_before_contact_time.append(len(contact_t) * time_tol)
        all_after_contact_time.append(len(contact_t) * time_tol)

    all_pre_chase_time = np.array(all_pre_chase_time)
    all_chase_time = np.array(all_chase_time)
    all_end_chase_time = np.array(all_end_chase_time)
    all_after_chase_time = np.array(all_after_chase_time)
    all_before_contact_time = np.array(all_before_contact_time)
    all_after_contact_time = np.array(all_after_contact_time)

    video_trial_win_sex = np.array(video_trial_win_sex)
    video_trial_lose_sex = np.array(video_trial_lose_sex)

    # fraction of the assumed 3 h trial duration covered by each context
    all_pre_chase_time_ratio = all_pre_chase_time / (3 * 60 * 60)
    all_chase_time_ratio = all_chase_time / (3 * 60 * 60)
    all_end_chase_time_ratio = all_end_chase_time / (3 * 60 * 60)
    all_after_chase_time_ratio = all_after_chase_time / (3 * 60 * 60)
    all_before_countact_time_ratio = all_before_contact_time / (3 * 60 * 60)
    all_after_countact_time_ratio = all_after_contact_time / (3 * 60 * 60)

    # fraction of each trial's events falling into each context
    all_pre_chase_event_ratio = np.array(list(map(lambda x: np.sum(x) / len(x), all_pre_chase_event_mask)))
    all_chase_event_ratio = np.array(list(map(lambda x: np.sum(x) / len(x), all_chase_event_mask)))
    all_end_chase_event_ratio = np.array(list(map(lambda x: np.sum(x) / len(x), all_end_chase_event_mask)))
    all_after_chase_event_ratio = np.array(list(map(lambda x: np.sum(x) / len(x), all_after_chase_event_mask)))
    all_before_countact_event_ratio = np.array(list(map(lambda x: np.sum(x) / len(x), all_before_contact_event_mask)))
    all_after_countact_event_ratio = np.array(list(map(lambda x: np.sum(x) / len(x), all_after_contact_event_mask)))

    # paired t-test per context: event fraction vs. time fraction
    for x, y, name in [[all_pre_chase_event_ratio, all_pre_chase_time_ratio, 'pre chase'],
                       [all_chase_event_ratio, all_chase_time_ratio, 'while chase'],
                       [all_end_chase_event_ratio, all_end_chase_time_ratio, 'end chase'],
                       [all_after_chase_event_ratio, all_after_chase_time_ratio, 'after chase'],
                       [all_before_countact_event_ratio, all_before_countact_time_ratio, 'pre contact'],
                       [all_after_countact_event_ratio, all_after_countact_time_ratio, 'post contact']]:
        t, p = scp.ttest_rel(x, y)
        print(f'{event_name} {name}: t={t:.2f} p={p:.3f}')

    fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
    gs = gridspec.GridSpec(1, 2, left=0.1, bottom=0.15, right=0.95, top=0.9)
    ax = fig.add_subplot(gs[0, 0])
    ax_pie = fig.add_subplot(gs[0, 1])
    # ratio of ratios per context: > 1 means events occur more often than expected by chance
    ax.boxplot([all_pre_chase_event_ratio / all_pre_chase_time_ratio,
                all_chase_event_ratio / all_chase_time_ratio,
                all_end_chase_event_ratio / all_end_chase_time_ratio,
                all_after_chase_event_ratio / all_after_chase_time_ratio,
                all_before_countact_event_ratio / all_before_countact_time_ratio,
                all_after_countact_event_ratio / all_after_countact_time_ratio], positions=np.arange(6), sym='',
               zorder=2)
    # clamp y-limits to at least [-0.1, 1.1]
    ylim = list(ax.get_ylim())
    ylim[0] = -.1 if ylim[0] < -.1 else ylim[0]
    ylim[1] = 1.1 if ylim[1] < 1.1 else ylim[1]
    ##############################################################################
    # overlay per-trial values, marker shape by analysed fish (o=loser, p=winner),
    # color by that fish's sex, black edge for same-sex pairings
    for sex_w, sex_l in itertools.product(['m', 'f'], repeat=2):
        mec = 'k' if sex_w == sex_l else 'None'
        if 'lose' in event_name:
            marker = 'o'
            c = male_color if sex_l == 'm' else female_color
        elif "win" in event_name:
            marker = 'p'
            c = male_color if sex_w == 'm' else female_color
        else:
            # event_name must mention 'lose' or 'win'; anything else is a caller bug
            print('error')
            embed()
            quit()
        values = np.array(all_pre_chase_event_ratio / all_pre_chase_time_ratio)[
            (video_trial_win_sex == sex_w) & (video_trial_lose_sex == sex_l)]
        ax.plot(np.ones_like(values) * 0, values, marker=marker, linestyle='None', color=c, mec=mec, markersize=8,
                zorder=1)
        values = np.array(all_chase_event_ratio / all_chase_time_ratio)[
            (video_trial_win_sex == sex_w) & (video_trial_lose_sex == sex_l)]
        ax.plot(np.ones_like(values) * 1, values, marker=marker, linestyle='None', color=c, mec=mec, markersize=8,
                zorder=1)
        values = np.array(all_end_chase_event_ratio / all_end_chase_time_ratio)[
            (video_trial_win_sex == sex_w) & (video_trial_lose_sex == sex_l)]
        ax.plot(np.ones_like(values) * 2, values, marker=marker, linestyle='None', color=c, mec=mec, markersize=8,
                zorder=1)
        values = np.array(all_after_chase_event_ratio / all_after_chase_time_ratio)[
            (video_trial_win_sex == sex_w) & (video_trial_lose_sex == sex_l)]
        ax.plot(np.ones_like(values) * 3, values, marker=marker, linestyle='None', color=c, mec=mec, markersize=8,
                zorder=1)
        values = np.array(all_before_countact_event_ratio / all_before_countact_time_ratio)[
            (video_trial_win_sex == sex_w) & (video_trial_lose_sex == sex_l)]
        ax.plot(np.ones_like(values) * 4, values, marker=marker, linestyle='None', color=c, mec=mec, markersize=8,
                zorder=1)
        values = np.array(all_after_countact_event_ratio / all_after_countact_time_ratio)[
            (video_trial_win_sex == sex_w) & (video_trial_lose_sex == sex_l)]
        ax.plot(np.ones_like(values) * 5, values, marker=marker, linestyle='None', color=c, mec=mec, markersize=8,
                zorder=1)
    ##############################################################################
    # dotted line at 1: chance level (events proportional to time)
    ax.plot(np.arange(7) - 1, np.ones(7), linestyle='dotted', lw=2, color='k')
    ax.set_xlim(-0.5, 5.5)
    ax.set_ylim(ylim[0], ylim[1])
    ax.set_ylabel(r'rel. count$_{event}$ / rel. time$_{event}$', fontsize=12)
    ax.set_xticks(np.arange(6))
    ax.set_xticklabels([r'chase$_{before}$', r'chasing', r'chase$_{end}$', r'chase$_{after}$', 'contact$_{before}$',
                        'contact$_{after}$'], rotation=45)
    ax.tick_params(labelsize=10)
    fig.suptitle(f'{event_name}: n={len(np.hstack(all_event_t))}')
    ###############################################
    # pool masks over trials for the pie chart; events in contact windows are
    # removed from the chase contexts so contact context takes precedence
    flat_pre_chase_event_mask = np.hstack(all_pre_chase_event_mask)
    flat_chase_event_mask = np.hstack(all_chase_event_mask)
    flat_end_chase_event_mask = np.hstack(all_end_chase_event_mask)
    flat_after_chase_event_mask = np.hstack(all_after_chase_event_mask)
    flat_before_countact_event_mask = np.hstack(all_before_contact_event_mask)
    flat_after_countact_event_mask = np.hstack(all_after_contact_event_mask)

    flat_pre_chase_event_mask[(flat_before_countact_event_mask == 1) | (flat_after_countact_event_mask == 1)] = 0
    flat_chase_event_mask[(flat_before_countact_event_mask == 1) | (flat_after_countact_event_mask == 1)] = 0
    flat_end_chase_event_mask[(flat_before_countact_event_mask == 1) | (flat_after_countact_event_mask == 1)] = 0
    flat_after_chase_event_mask[(flat_before_countact_event_mask == 1) | (flat_after_countact_event_mask == 1)] = 0

    # fraction of all events per context; last slice = events outside any context
    event_context_values = [np.sum(flat_pre_chase_event_mask) / len(flat_pre_chase_event_mask),
                            np.sum(flat_chase_event_mask) / len(flat_chase_event_mask),
                            np.sum(flat_end_chase_event_mask) / len(flat_end_chase_event_mask),
                            np.sum(flat_after_chase_event_mask) / len(flat_after_chase_event_mask),
                            np.sum(flat_before_countact_event_mask) / len(flat_before_countact_event_mask),
                            np.sum(flat_after_countact_event_mask) / len(flat_after_countact_event_mask)]
    event_context_values.append(1 - np.sum(event_context_values))

    # total time per context; last slice = remaining time of the assumed 3 h trials
    time_context_values = [np.sum(all_pre_chase_time), np.sum(all_chase_time), np.sum(all_end_chase_time),
                           np.sum(all_after_chase_time), np.sum(all_before_contact_time),
                           np.sum(all_after_contact_time)]
    time_context_values.append(len(all_pre_chase_time) * 3 * 60 * 60 - np.sum(time_context_values))
    time_context_values /= np.sum(time_context_values)

    # fig, ax = plt.subplots(figsize=(12/2.54,12/2.54))
    size = 0.3  # ring thickness of the nested pie
    outer_colors = ['tab:red', 'tab:orange', 'yellow', 'tab:green', 'k', 'tab:brown', 'tab:grey']
    # outer ring: event fractions; inner ring (alpha): time fractions
    ax_pie.pie(event_context_values, radius=1, colors=outer_colors,
               wedgeprops=dict(width=size, edgecolor='w'), startangle=90, center=(0, 1))
    ax_pie.pie(time_context_values, radius=1 - size, colors=outer_colors,
               wedgeprops=dict(width=size, edgecolor='w', alpha=.6), startangle=90, center=(0, 1))
    ax_pie.set_title(r'event context')

    legend_elements = [Patch(facecolor='tab:red', edgecolor='w', label='%.1f' % (event_context_values[0] * 100) + '%'),
                       Patch(facecolor='tab:orange', edgecolor='w',
                             label='%.1f' % (event_context_values[1] * 100) + '%'),
                       Patch(facecolor='yellow', edgecolor='w', label='%.1f' % (event_context_values[2] * 100) + '%'),
                       Patch(facecolor='tab:green', edgecolor='w',
                             label='%.1f' % (event_context_values[3] * 100) + '%'),
                       Patch(facecolor='k', edgecolor='w', label='%.1f' % (event_context_values[4] * 100) + '%'),
                       Patch(facecolor='tab:brown', edgecolor='w',
                             label='%.1f' % (event_context_values[5] * 100) + '%'),
                       Patch(facecolor='tab:red', alpha=0.6, edgecolor='w',
                             label='%.1f' % (time_context_values[0] * 100) + '%'),
                       Patch(facecolor='tab:orange', alpha=0.6, edgecolor='w',
                             label='%.1f' % (time_context_values[1] * 100) + '%'),
                       Patch(facecolor='yellow', alpha=0.6, edgecolor='w',
                             label='%.1f' % (time_context_values[2] * 100) + '%'),
                       Patch(facecolor='tab:green', alpha=0.6, edgecolor='w',
                             label='%.1f' % (time_context_values[3] * 100) + '%'),
                       Patch(facecolor='k', alpha=0.6, edgecolor='w',
                             label='%.1f' % (time_context_values[4] * 100) + '%'),
                       Patch(facecolor='tab:brown', alpha=0.6, edgecolor='w',
                             label='%.1f' % (time_context_values[5] * 100) + '%')]
    ax_pie.legend(handles=legend_elements, loc='lower right', ncol=2, bbox_to_anchor=(1.15, -0.25), frameon=False,
                  fontsize=9)
    plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr', f'{event_name}_categories.png'),
                dpi=300)
    plt.close()
    # plt.show()
def main(base_path):
    """Analyse communication events (chirps, rises) of winner and loser fish
    relative to aggressive behaviors (chasing, contact) across trials.

    Reads per-trial npy/csv files below *base_path*, computes inter-event
    intervals, chirp-burst statistics, event-rate progressions and
    chirp-vs-chase timing relations, and shows/saves the figures.

    Parameters
    ----------
    base_path : str
        Directory containing trial_summary.csv, chirp_notes.csv and one
        sub-directory per recording.
    """
    # create output folders for the generated figures
    if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta')):
        os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta'))
    if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr')):
        os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr'))
    trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
    chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
    # trials whose chirp detection was rated as good
    trial_mask = chirp_notes['good'] == 1
    ### data processing #######################
    # one entry per accepted trial in each of these lists
    all_rise_times_lose = []
    all_rise_times_win = []
    all_chirp_times_lose = []
    all_chirp_times_win = []
    win_sex = []
    lose_sex = []
    all_contact_t = []
    all_ag_on_t = []
    all_ag_off_t = []
    for index, trial in trial_summary.iterrows():
        print(index, len(trial_summary))
        got_boris = False
        trial_path = os.path.join(base_path, trial['recording'])
        # skip early groups and undecided (draw) trials
        if trial['group'] < 3:
            continue
        if trial['draw'] == 1:
            continue
        # NOTE(review): got_boris becomes True if EITHER file exists, but
        # load_and_converete_boris_events reads both -- confirm this is intended
        if os.path.exists(os.path.join(trial_path, 'led_idxs.csv')):
            got_boris = True
        if os.path.exists(os.path.join(trial_path, 'LED_frames.npy')):
            got_boris = True
        ids = np.load(os.path.join(trial_path, 'analysis', 'ids.npy'))
        times = np.load(os.path.join(trial_path, 'times.npy'))
        # slice direction that puts the winner at index 0 of per-fish arrays
        sorter = -1 if trial['win_ID'] != ids[0] else 1
        ### event times --> BORIS behavior
        if got_boris:
            contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames = \
                load_and_converete_boris_events(trial_path, trial['recording'], sr=20_000)
            all_contact_t.append(contact_t_GRID)
            all_ag_on_t.append(ag_on_off_t_GRID[:, 0])
            all_ag_off_t.append(ag_on_off_t_GRID[:, 1])
        else:
            all_contact_t.append(np.array([]))
            all_ag_on_t.append(np.array([]))
            all_ag_off_t.append(np.array([]))
        ### communication
        if not os.path.exists(os.path.join(trial_path, 'chirp_times_cnn.npy')):
            continue
        chirp_t = np.load(os.path.join(trial_path, 'chirp_times_cnn.npy'))
        chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
        # [winner chirps, loser chirps]
        chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]]
        rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
        rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]
        # [winner rises, loser rises] as absolute times
        rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]]
        all_rise_times_lose.append(rise_times[1])
        all_rise_times_win.append(rise_times[0])
        # chirps only enter the analysis for trials with good chirp detection
        if trial_mask[index]:
            all_chirp_times_lose.append(chirp_times[1])
            all_chirp_times_win.append(chirp_times[0])
        else:
            all_chirp_times_lose.append(np.array([]))
            all_chirp_times_win.append(np.array([]))
        win_sex.append(trial['sex_win'])
        lose_sex.append(trial['sex_lose'])
    win_sex = np.array(win_sex)
    lose_sex = np.array(lose_sex)
    ### inter event intervalls ###
    inter_chirp_interval_lose = iei_analysis(all_chirp_times_lose, win_sex, lose_sex, kernal_w=1, title=r'chirps$_{lose}$')
    _ = iei_analysis(all_chirp_times_win, win_sex, lose_sex, kernal_w=1, title=r'chirps$_{win}$')
    _ = iei_analysis(all_rise_times_lose, win_sex, lose_sex, kernal_w=5, title=r'rises$_{lose}$')
    _ = iei_analysis(all_rise_times_win, win_sex, lose_sex, kernal_w=50, title=r'rises$_{win}$')
    # histogram of loser inter-chirp intervals; the median is used below as
    # threshold separating within-burst from between-burst intervals
    fig, ax = plt.subplots()
    n, bin_edges = np.histogram(np.hstack(inter_chirp_interval_lose), bins = np.arange(0, 20, 0.05))
    ax.bar(bin_edges[:-1] + (bin_edges[1] - bin_edges[0])/2, n/np.sum(n)/(bin_edges[1] - bin_edges[0]), width=(bin_edges[1] - bin_edges[0]))
    ylim = ax.get_ylim()
    med_ici = np.nanmedian(np.hstack(inter_chirp_interval_lose))
    ax.plot([med_ici, med_ici], [ylim[0], ylim[1]], '-k', lw=2)
    plt.show()
    chirp_dt_burst_th = med_ici
    print(np.nanpercentile(np.hstack(inter_chirp_interval_lose), (50, 75, 95)))
    # chirp_dt_burst_th = bin_edges[np.argmax(n)] - (bin_edges[1] - bin_edges[0]) / 2
    # per-chirp label (presumably: 0 = isolated, 1 = within burst,
    # 2 = last chirp of a burst -- inferred from the start/end logic below)
    burst_chirp_mask = []
    for enu, ici in enumerate(inter_chirp_interval_lose):
        if len(ici) >= 1:
            trial_burst_chirp_mask = np.zeros_like(ici)
            trial_burst_chirp_mask[ici < chirp_dt_burst_th] = 1
            trial_burst_chirp_mask[1:][(ici[:-1] < chirp_dt_burst_th) & (ici[1:] >= chirp_dt_burst_th)] = 2
            # label for the final chirp, which has no following interval
            last = 2 if trial_burst_chirp_mask[-1] == 1 else 0
            trial_burst_chirp_mask = np.append(trial_burst_chirp_mask, np.array(last))
            burst_chirp_mask.append(trial_burst_chirp_mask)
        else:
            burst_chirp_mask.append(np.array([]))
    # raster of loser chirp times per trial with burst extents overlaid,
    # plus distribution of chirps per burst
    fig = plt.figure(figsize=(21/2.54, 19/2.54))
    gs = gridspec.GridSpec(1, 2, left=0.1, bottom=0.1, right=0.95, top=0.95)
    ax = []
    ax.append(fig.add_subplot(gs[0, 0]))
    ax.append(fig.add_subplot(gs[0, 1]))
    all_chirps_in_burst_distro = []
    for i in range(len(burst_chirp_mask)):
        ax[0].plot(all_chirp_times_lose[i], np.ones_like(all_chirp_times_lose[i]) * i, '|', markersize=12, color='grey')
        if len(burst_chirp_mask[i]) == 0:
            continue
        # burst starts where the mask switches to 1
        chirp_idx_burst_start = np.arange(len(all_chirp_times_lose[i])-1)[(burst_chirp_mask[i][:-1] != 1) & (burst_chirp_mask[i][1:] == 1)] + 1
        if burst_chirp_mask[i][0] == 1:
            chirp_idx_burst_start = np.append(0, chirp_idx_burst_start)
        chirp_idx_burst_end = np.arange(len(all_chirp_times_lose[i]))[(burst_chirp_mask[i] == 2)]
        chirp_idx_burst_start = np.array(chirp_idx_burst_start, dtype=int)
        chirp_idx_burst_end = np.array(chirp_idx_burst_end, dtype=int)
        chirps_in_burst = chirp_idx_burst_end - chirp_idx_burst_start + 1
        if len(chirps_in_burst) == 0:
            continue
        # index 0: count of isolated chirps; index j>0: bursts of j+1 chirps
        chirps_in_burst_distro = np.zeros(np.max(chirps_in_burst))
        for j in range(np.max(chirps_in_burst)):
            if j == 0:
                chirps_in_burst_distro[j] = len(burst_chirp_mask[i][burst_chirp_mask[i] == 0])
            else:
                chirps_in_burst_distro[j] = len(chirps_in_burst[chirps_in_burst == j + 1])
        for cbs, cbe in zip(all_chirp_times_lose[i][chirp_idx_burst_start], all_chirp_times_lose[i][chirp_idx_burst_end]):
            ax[0].plot([cbs, cbe], [i, i], '-k', lw=2)
        all_chirps_in_burst_distro.append(chirps_in_burst_distro)
    # pad the per-trial distributions to equal length and sum across trials
    max_chirps_in_burst = np.max(list(map(lambda x: len(x), all_chirps_in_burst_distro)))
    collective_chirps_in_burst = np.zeros((len(all_chirps_in_burst_distro), max_chirps_in_burst))
    for trial in range(len(all_chirps_in_burst_distro)):
        collective_chirps_in_burst[trial, :len(all_chirps_in_burst_distro[trial])] = all_chirps_in_burst_distro[trial]
    ax[1].bar(np.arange(collective_chirps_in_burst.shape[1])+1, collective_chirps_in_burst.sum(0))
    # total number of chirps contributed by each burst size
    ax[1].plot(np.arange(collective_chirps_in_burst.shape[1])+1,
               collective_chirps_in_burst.sum(0) * (np.arange(collective_chirps_in_burst.shape[1])+1), color='firebrick', lw=2)
    plt.show()
    ### event progressions ###
    print('')
    relative_rate_progression(all_chirp_times_lose, title=r'chirp$_{lose}$')
    relative_rate_progression(all_chirp_times_win, title=r'chirp$_{win}$')
    relative_rate_progression(all_rise_times_lose, title=r'rises$_{lose}$')
    relative_rate_progression(all_rise_times_win, title=r'rises$_{win}$')
    relative_rate_progression(all_contact_t, title=r'contact')
    relative_rate_progression(all_ag_on_t, title=r'chasing')
    chase_time_progression(all_ag_on_t, all_ag_off_t)
    ### event category signals ###
    for all_event_t, event_name in zip([all_chirp_times_lose, all_chirp_times_win, all_rise_times_lose, all_rise_times_win],
                                       [r'chirps$_{lose}$', r'chirps$_{win}$', r'rises$_{lose}$', r'rises$_{win}$']):
        event_category_signal(all_event_t, all_contact_t, all_ag_on_t, all_ag_off_t, win_sex, lose_sex, event_name)
    #################################
    # relate loser chirps to individual chase epochs
    chase_dur = []
    chase_chirp_count = []
    dt_start_first_chirp = []
    dt_end_first_chirp = []
    dt_start_all_chirp = []
    dt_end_all_chirp = []
    all_chirp_mask = []
    chase_dur_all_chirp = []
    for ag_on_t, ag_off_t, chirp_times_lose, trial_chirp_burst_mask in \
            zip(all_ag_on_t, all_ag_off_t, all_chirp_times_lose, burst_chirp_mask):
        if len(chirp_times_lose) == 0:
            continue
        for a_on, a_off in zip(ag_on_t, ag_off_t):
            chase_dur.append(a_off - a_on)
            # loser chirps emitted during this chase epoch
            chirp_t_oi = chirp_times_lose[(chirp_times_lose > a_on) & (chirp_times_lose <= a_off)]
            chirp_t_oi_mask = trial_chirp_burst_mask[(chirp_times_lose > a_on) & (chirp_times_lose <= a_off)]
            chase_chirp_count.append(len(chirp_t_oi))
            if len(chirp_t_oi) >= 1:
                dt_start_first_chirp.append(chirp_t_oi[0] - a_on)
                dt_end_first_chirp.append(a_off - chirp_t_oi[0])
                dt_start_all_chirp.extend(chirp_t_oi - a_on)
                dt_end_all_chirp.extend(a_off - chirp_t_oi)
                all_chirp_mask.extend(chirp_t_oi_mask)
                chase_dur_all_chirp.extend(np.ones_like(chirp_t_oi) * (a_off - a_on))
            else:
                dt_start_first_chirp.append(np.nan)
                dt_end_first_chirp.append(np.nan)
    dt_start_first_chirp = np.array(dt_start_first_chirp)
    dt_end_first_chirp = np.array(dt_end_first_chirp)
    dt_start_all_chirp = np.array(dt_start_all_chirp)
    dt_end_all_chirp = np.array(dt_end_all_chirp)
    all_chirp_mask = np.array(all_chirp_mask)
    chase_dur_all_chirp = np.array(chase_dur_all_chirp)
    chase_chirp_count = np.array(chase_chirp_count)
    chase_dur = np.array(chase_dur)
    chirp_rate = chase_chirp_count / chase_dur
    # chase durations grouped by their chirp count (for the boxplot below)
    chase_dur_per_chirp_count = []
    positions = np.arange(np.max(chase_chirp_count)+1)
    for cc in positions:
        chase_dur_per_chirp_count.append([])
        chase_dur_per_chirp_count[-1].extend(chase_dur[chase_chirp_count == cc])
    ################################################
    # number of chases at least as long as each duration bin edge
    # (used to normalize the timing histograms below)
    chase_dur_pct99 = np.percentile(chase_dur, 99)
    chase_dur_bins = np.arange(0, chase_dur_pct99+1, 2.5)
    chase_dur_count_above_th = np.zeros_like(chase_dur_bins[:-1])
    for enu, chase_dur_th in enumerate(chase_dur_bins[:-1]):
        chase_dur_count_above_th[enu] = len(chase_dur[chase_dur >= chase_dur_th])
    ################################################
    fig = plt.figure(figsize=(21/2.54, 28/2.54))
    gs = gridspec.GridSpec(3, 2, left=.15, bottom=0.1, right=0.95, top=0.95)
    ax = []
    ax.append(fig.add_subplot(gs[0, 0]))
    ax.append(fig.add_subplot(gs[0, 1], sharex=ax[0]))
    ax.append(fig.add_subplot(gs[1, 0], sharex=ax[0]))
    ax.append(fig.add_subplot(gs[1, 1], sharex=ax[0]))
    ax.append(fig.add_subplot(gs[2, 0], sharex=ax[0]))
    ax.append(fig.add_subplot(gs[2, 1], sharex=ax[0]))
    ax[0].plot(chase_dur, chase_chirp_count, '.')
    ax[0].boxplot(chase_dur_per_chirp_count, positions=positions, vert=False, sym='')
    ax[1].plot(chase_dur, np.array(chase_chirp_count) / np.array(chase_dur), '.')
    # timing of the FIRST chirp relative to chase onset (ax[2]) / offset (ax[3])
    ax[2].plot(chase_dur, dt_start_first_chirp, '.')
    ax[3].plot(chase_dur, dt_end_first_chirp, '.')
    ax[2].plot([0, 60], [0, 60], '-k', lw=1)
    ax[3].plot([0, 60], [0, 60], '-k', lw=1)
    n, _ = np.histogram(dt_start_first_chirp, bins=chase_dur_bins)
    n = n / np.sum(n) / (chase_dur_bins[1] - chase_dur_bins[0])
    n = n / chase_dur_count_above_th
    n = n / np.max(n) * chase_dur_pct99
    ax[2].barh(chase_dur_bins[:-1] + (chase_dur_bins[1] - chase_dur_bins[0]) / 2, n,
               height=(chase_dur_bins[1] - chase_dur_bins[0]) * 0.8, color='firebrick', alpha=0.5, zorder=2)
    n, _ = np.histogram(dt_end_first_chirp, bins=chase_dur_bins)
    n = n / np.sum(n) / (chase_dur_bins[1] - chase_dur_bins[0])
    n = n / chase_dur_count_above_th
    n = n / np.max(n) * chase_dur_pct99
    ax[3].barh(chase_dur_bins[:-1] + (chase_dur_bins[1] - chase_dur_bins[0]) / 2, n,
               height=(chase_dur_bins[1] - chase_dur_bins[0]) * 0.8, color='firebrick', alpha=0.5, zorder=2)
    ax[3].invert_yaxis()
    ax[2].set_xlim(right=chase_dur_pct99 + 2)
    ax[2].set_ylim(top=chase_dur_pct99 + 2)
    ax[3].set_xlim(right=chase_dur_pct99 + 2)
    ax[3].set_ylim(bottom=chase_dur_pct99 + 2)
    # timing of ALL chirps; chirps with non-zero burst label drawn in black
    ax[4].plot(chase_dur_all_chirp[all_chirp_mask == 0], dt_start_all_chirp[all_chirp_mask == 0], '.', color='cornflowerblue', alpha = 0.5)
    ax[4].plot(chase_dur_all_chirp[all_chirp_mask != 0], dt_start_all_chirp[all_chirp_mask != 0], '.', color='k', alpha = 0.5)
    ax[5].plot(chase_dur_all_chirp[all_chirp_mask == 0], dt_end_all_chirp[all_chirp_mask == 0], '.', color='cornflowerblue', alpha = 0.5)
    ax[5].plot(chase_dur_all_chirp[all_chirp_mask != 0], dt_end_all_chirp[all_chirp_mask != 0], '.', color='k', alpha = 0.5)
    ax[4].plot([0, 60], [0, 60], '-k', lw=1)
    ax[5].plot([0, 60], [0, 60], '-k', lw=1)
    n, _ = np.histogram(dt_start_all_chirp, bins=chase_dur_bins)
    n = n / np.sum(n) / (chase_dur_bins[1] - chase_dur_bins[0])
    n = n / chase_dur_count_above_th
    n = n / np.max(n) * chase_dur_pct99
    ax[4].barh(chase_dur_bins[:-1] + (chase_dur_bins[1] - chase_dur_bins[0])/4, n, height=(chase_dur_bins[1] - chase_dur_bins[0])*0.4, color='firebrick', alpha=0.52, zorder=2)
    n, _ = np.histogram(dt_start_all_chirp[all_chirp_mask != 0], bins=chase_dur_bins)
    n = n / np.sum(n) / (chase_dur_bins[1] - chase_dur_bins[0])
    n = n / chase_dur_count_above_th
    n = n / np.max(n) * chase_dur_pct99
    ax[4].barh(chase_dur_bins[:-1] + (chase_dur_bins[1] - chase_dur_bins[0])/4*3, n, height=(chase_dur_bins[1] - chase_dur_bins[0])*0.4, color='k', alpha=0.52, zorder=2)
    n, _ = np.histogram(dt_end_all_chirp, bins=chase_dur_bins)
    n = n / np.sum(n) / (chase_dur_bins[1] - chase_dur_bins[0])
    n = n / chase_dur_count_above_th
    n = n / np.max(n) * chase_dur_pct99
    ax[5].barh(chase_dur_bins[:-1] + (chase_dur_bins[1] - chase_dur_bins[0])/4, n, height=(chase_dur_bins[1] - chase_dur_bins[0])*0.4, color='firebrick', alpha=0.5, zorder=2)
    n, _ = np.histogram(dt_end_all_chirp[all_chirp_mask != 0], bins=chase_dur_bins)
    n = n / np.sum(n) / (chase_dur_bins[1] - chase_dur_bins[0])
    n = n / chase_dur_count_above_th
    n = n / np.max(n) * chase_dur_pct99
    ax[5].barh(chase_dur_bins[:-1] + (chase_dur_bins[1] - chase_dur_bins[0])/4*3, n, height=(chase_dur_bins[1] - chase_dur_bins[0])*0.4, color='k', alpha=0.5, zorder=2)
    ax[5].invert_yaxis()
    ax[4].set_xlim(right=chase_dur_pct99+2)
    ax[4].set_ylim(top=chase_dur_pct99+2)
    ax[5].set_xlim(right=chase_dur_pct99+2)
    ax[5].set_ylim(bottom=chase_dur_pct99+2)
    ax[4].set_xlabel(r'chase$_{duration}$ [s]', fontsize=12)
    ax[5].set_xlabel(r'chase$_{duration}$ [s]', fontsize=12)
    ax[0].set_ylabel('chirps [n]', fontsize=12)
    ax[1].set_ylabel('chirp rate [Hz]', fontsize=12)
    ax[2].set_ylabel(r'$\Delta$t chase$_{on}$ - chirp$_{0}$', fontsize=12)
    ax[3].set_ylabel(r'$\Delta$t chirp$_{0}$ - chase$_{off}$', fontsize=12)
    ax[4].set_ylabel(r'$\Delta$t chase$_{on}$ - chirps', fontsize=12)
    ax[5].set_ylabel(r'$\Delta$t chirps - chase$_{off}$', fontsize=12)
    plt.show()
    # NOTE(review): interactive debugging leftovers -- embed() drops into an
    # IPython shell and quit() terminates the interpreter here
    embed()
    quit()
if __name__ == '__main__':
    # base data directory (containing trial_summary.csv etc.) as the only CLI argument
    main(sys.argv[1])

View File

@@ -0,0 +1,560 @@
import os
import sys
import argparse
import time
import itertools
import numpy as np
try:
import cupy as cp
except ImportError:
import numpy as cp
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pandas as pd
from IPython import embed
from tqdm import tqdm
female_color, male_color = '#e74c3c', '#3498db'
def load_and_converete_boris_events(trial_path, recording, sr):
    """Map manually scored BORIS video events onto the grid-recording time axis.

    The LED blinks are known both as video frame numbers (LED_frames.npy)
    and as sample indices of the grid recording (led_idxs.csv); events are
    mapped linearly between the first and last blink and converted to
    seconds via the sampling rate *sr*.

    Returns (contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames).
    """
    def frames_to_grid_idx(frames, led_frames, led_idx):
        # linear interpolation anchored at the first and last LED blink
        frame_span = led_frames[-1] - led_frames[0]
        idx_span = led_idx[-1] - led_idx[0]
        return (frames - led_frames[0]) / frame_span * idx_span + led_idx[0]

    # LED blink positions as sample indices of the grid recording
    led_idx = pd.read_csv(os.path.join(trial_path, 'led_idxs.csv'), header=None).iloc[:, 0].to_numpy()
    # LED blink positions as video frame numbers
    led_frames = np.load(os.path.join(trial_path, 'LED_frames.npy'))

    times, behavior, t_ag_on_off, t_contact, video_FPS = load_boris(trial_path, recording)
    contact_frame = np.array(np.round(t_contact * video_FPS), dtype=int)
    ag_on_off_frame = np.array(np.round(t_ag_on_off * video_FPS), dtype=int)

    contact_t_GRID = frames_to_grid_idx(contact_frame, led_frames, led_idx) / sr
    ag_on_off_t_GRID = frames_to_grid_idx(ag_on_off_frame, led_frames, led_idx) / sr
    return contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames
def load_boris(trial_path, recording):
    """Read the manually scored BORIS event table for one trial.

    The csv file is named after the first three '-'-separated fields of the
    recording name.  Behavior codes: 0 and 1 presumably mark chasing
    on-/offsets and 2 physical contact (inferred from the variable names
    downstream -- confirm against the BORIS project).

    Returns (times, behavior, t_ag_on_off, t_contact, fps) where
    t_ag_on_off is an array of matched [onset, offset] time pairs.
    """
    csv_name = '-'.join(recording.split('-')[:3]) + '.csv'
    table = pd.read_csv(os.path.join(trial_path, csv_name))

    times = table['Start (s)']
    behavior = table['Behavior']

    onset_times = times[behavior == 0]
    offset_times = times[behavior == 1]
    # pair each onset with the first offset that follows it
    on_off_pairs = []
    for onset in onset_times:
        later_offsets = np.array(offset_times)[offset_times > onset]
        if len(later_offsets) >= 1:
            on_off_pairs.append(np.array([onset, later_offsets[0]]))

    contact_times = times[behavior == 2]
    return times, behavior, np.array(on_off_pairs), contact_times.to_numpy(), table['FPS'][0]
def gauss(t, shift, sigma, size, norm = False):
    """Gaussian kernel(s) of width *sigma* and peak height *size* on axis *t*.

    Parameters
    ----------
    t : array
        Time axis the kernel is evaluated on.
    shift : float or array
        Kernel center(s).  A scalar yields a 1-D array over t; an array of
        n centers yields an (n, len(t)) array, one row per center.
    sigma : float
        Standard deviation of the kernel.
    size : float
        Peak amplitude before normalization.
    norm : bool
        If True, scale each kernel to unit sum over t.

    Returns
    -------
    array
        Kernel values, 1-D for scalar shift, 2-D (one row per shift) otherwise.
    """
    if not hasattr(shift, '__len__'):
        g = np.exp(-((t - shift) / sigma) ** 2 / 2) * size
        if norm:
            g /= np.sum(g)
        return g
    else:
        t = np.array([t, ] * len(shift))
        res = np.exp(-((t.transpose() - shift).transpose() / sigma) ** 2 / 2) * size
        if norm:
            # bug fix: norm was silently ignored for array-valued shift;
            # normalize each kernel (row) to unit sum like the scalar branch
            res /= np.sum(res, axis=1).reshape(-1, 1)
        return res
def event_centered_times(centered_event_times, surrounding_event_times, max_dt = np.inf):
    """Signed offsets of surrounding events relative to each centering event.

    For every time in *centered_event_times* compute the time differences
    to all *surrounding_event_times* and keep those whose absolute value
    does not exceed *max_dt*.  Returns a flat 1-D array of offsets.
    """
    offsets_per_center = []
    for center_t in centered_event_times:
        dt = np.asarray(surrounding_event_times) - center_t
        offsets_per_center.append(dt[np.abs(dt) <= max_dt])
    if not offsets_per_center:
        return np.array([])
    return np.concatenate(offsets_per_center)
def kde(event_dt, conv_t, kernal_w = 1, kernal_h = 0.2):
    """Kernel density estimate of event offsets on the time axis *conv_t*.

    Sums one Gaussian (width *kernal_w*, height *kernal_h*) per entry of
    *event_dt*; returns an array of the same length as *conv_t*.
    """
    density = np.zeros(len(conv_t))
    for event_offset in event_dt:
        density = density + gauss(conv_t, event_offset, kernal_w, kernal_h)
    return density
def permutation_kde(event_dt, conv_t, repetitions = 2000, max_mem_use_GB = 4, kernal_w = 1, kernal_h = 0.2):
    """Null distribution of the event-offset KDE via random temporal jitter.

    Each repetition adds an independent uniform jitter (up to +-max_jitter)
    to every event offset and recomputes the KDE on *conv_t*.  The work is
    processed in memory-bounded chunks; cupy (cp) is used when available,
    with numpy as fallback.  Returns an array of shape
    (repetitions, len(conv_t)).
    """
    def chunk_permutation(select_event_dt, conv_tt, n_chuck, max_jitter, kernal_w, kernal_h):
        # 3-D broadcast; axes are (len(conv_t), n_chuck, len(select_event_dt))
        event_dt_perm = cp.tile(select_event_dt, (len(conv_tt), n_chuck, 1))
        # one random jitter per event and repetition, shared across the time axis
        jitter = cp.random.uniform(-max_jitter, max_jitter, size=(event_dt_perm.shape[1], event_dt_perm.shape[2]))
        jitter = cp.expand_dims(jitter, axis=0)
        event_dt_perm += jitter
        gauss_3d = cp.exp(-((conv_tt - event_dt_perm) / kernal_w) ** 2 / 2) * kernal_h
        kde_3d = cp.sum(gauss_3d, axis = 2).transpose()
        try:
            # cupy: move the result to host memory and free the device arrays
            kde_3d_numpy = cp.asnumpy(kde_3d)
            del event_dt_perm, gauss_3d, kde_3d
            return kde_3d_numpy
        except AttributeError:
            # numpy fallback has no asnumpy(); result is already a host array
            del event_dt_perm, gauss_3d
            return kde_3d
    t0 = time.time()
    max_jitter = float(2*cp.max(conv_t))
    # keep only events that can still fall inside the window after jittering
    select_event_dt = event_dt[np.abs(event_dt) <= float(cp.max(conv_t)) + max_jitter*2]
    conv_tt = cp.reshape(conv_t, (len(conv_t), 1, 1))
    # repetitions per chunk so that one 3-D array stays below max_mem_use_GB;
    # NOTE(review): chunk_size can become 0 for very large event arrays, which
    # would make the floor division below raise ZeroDivisionError -- confirm
    chunk_size = int(np.floor(max_mem_use_GB / (select_event_dt.nbytes * conv_t.size / 1e9)))
    chunk_collector =[]
    for _ in range(repetitions // chunk_size):
        chunk_boot_KDE = chunk_permutation(select_event_dt, conv_tt, chunk_size, max_jitter, kernal_w, kernal_h)
        chunk_collector.extend(chunk_boot_KDE)
    # remainder chunk (repetitions % chunk_size may be 0 and then adds nothing)
    chunk_boot_KDE = chunk_permutation(select_event_dt, conv_tt, repetitions % chunk_size, max_jitter, kernal_w, kernal_h)
    chunk_collector.extend(chunk_boot_KDE)
    chunk_collector = np.array(chunk_collector)
    # ToDo: this works but is incorrect i think
    # chunk_collector /= np.sum(chunk_collector, axis=1).reshape(chunk_collector.shape[0], 1)
    print(f'bootstrap with {repetitions:.0f} repetitions took {time.time() - t0:.2f}s.')
    return chunk_collector
def jackknife_kde(event_dt, conv_t, repetitions = 2000, max_mem_use_GB = 2, jack_pct = 0.9, kernal_w = 1, kernal_h = 0.2):
    """Jackknife (random subsampling) distribution of the event-offset KDE.

    Each repetition recomputes the KDE from a random subset containing
    *jack_pct* of the events near the correlation window.  Processing is
    chunked to stay below *max_mem_use_GB*; cupy (cp) is used when
    available, numpy otherwise.  Returns an array of shape
    (repetitions, len(conv_t)); all-zero if no event falls in the window.
    """
    def chunk_jackknife(select_event_dt, conv_tt, n_chuck, jack_pct, kernal_w, kernal_h):
        event_dt_rep = cp.tile(select_event_dt, (n_chuck, 1))
        # random permutation per row; keep the first jack_pct of the columns
        idx = cp.random.rand(*event_dt_rep.shape).argsort(1)[:, :int(event_dt_rep.shape[-1]*jack_pct)]
        event_dt_jk = event_dt_rep[cp.arange(event_dt_rep.shape[0])[:, None], idx]
        event_dt_jk_full = cp.tile(event_dt_jk, (len(conv_tt), 1, 1))
        gauss_3d = cp.exp(-((conv_tt - event_dt_jk_full) / kernal_w) ** 2 / 2) * kernal_h
        kde_3d = cp.sum(gauss_3d, axis = 2).transpose()
        try:
            # cupy: move the result to host memory and free the device arrays
            kde_3d_numpy = cp.asnumpy(kde_3d)
            del event_dt_rep, idx, event_dt_jk, event_dt_jk_full, gauss_3d, kde_3d
            return kde_3d_numpy
        except AttributeError:
            # numpy fallback has no asnumpy(); result is already a host array
            del event_dt_rep, idx, event_dt_jk, event_dt_jk_full, gauss_3d
            return kde_3d
    t0 = time.time()
    # only events within twice the correlation window contribute
    select_event_dt = event_dt[np.abs(event_dt) <= float(cp.max(conv_t)) * 2]
    if len(select_event_dt) == 0:
        return np.zeros((repetitions, len(conv_t)))
    conv_tt = cp.reshape(conv_t, (len(conv_t), 1, 1))
    # repetitions per chunk so that one 3-D array stays below max_mem_use_GB
    chunk_size = int(np.floor(max_mem_use_GB / (select_event_dt.nbytes * jack_pct * conv_t.size / 1e9)))
    chunk_collector =[]
    for _ in range(repetitions // chunk_size):
        chunk_jackknife_KDE = chunk_jackknife(select_event_dt, conv_tt, chunk_size, jack_pct, kernal_w, kernal_h)
        chunk_collector.extend(chunk_jackknife_KDE)
        del chunk_jackknife_KDE
    # remainder chunk (repetitions % chunk_size may be 0 and then adds nothing)
    chunk_jackknife_KDE = chunk_jackknife(select_event_dt, conv_tt, repetitions % chunk_size, jack_pct, kernal_w, kernal_h)
    chunk_collector.extend(chunk_jackknife_KDE)
    del chunk_jackknife_KDE
    chunk_collector = np.array(chunk_collector)
    print(f'jackknife with {repetitions:.0f} repetitions took {time.time() - t0:.2f}s.')
    return chunk_collector
def single_kde(event_dt, conv_t, kernal_w = 1, kernal_h = 0.2):
    """One KDE per trial of event offsets on the time axis *conv_t*.

    Parameters
    ----------
    event_dt : list of arrays
        Per-trial event offsets.
    conv_t : array
        Time axis the densities are evaluated on.
    kernal_w, kernal_h : float
        Width (std) and height of the Gaussian kernel.

    Returns
    -------
    numpy array of shape (len(event_dt), len(conv_t)).
    """
    single_kdes = cp.zeros((len(event_dt), len(conv_t)))
    for enu, e_dt in enumerate(event_dt):
        # only events within twice the correlation window contribute
        Ce_dt = e_dt[np.abs(e_dt) <= float(cp.max(conv_t)) * 2]
        conv_tt = cp.reshape(conv_t, (len(conv_t), 1))
        Ce_dt_tile = cp.tile(Ce_dt, (len(conv_tt), 1))
        gauss_3d = cp.exp(-((conv_tt - Ce_dt_tile) / kernal_w) ** 2 / 2) * kernal_h
        single_kdes[enu] = cp.sum(gauss_3d, axis=1)
    try:
        return cp.asnumpy(single_kdes)
    except AttributeError:
        # bug fix: when cupy is unavailable cp is numpy (see module imports),
        # which has no asnumpy(); the result is already a numpy array then.
        # Same guard as in permutation_kde/jackknife_kde.
        return single_kdes
def main(base_path):
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr')):
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr'))
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
# trial_summary = trial_summary[chirp_notes['good'] == 1]
trial_mask = chirp_notes['good'] == 1
# ToDo: do chirp on chirp and rise on rise
lose_chrips_centered_on_ag_off_t = []
lose_chrips_centered_on_ag_on_t = []
lose_chrips_centered_on_contact_t = []
lose_chrips_centered_on_win_rises = []
lose_chrips_centered_on_win_chirp = []
lose_chirps_centered_on_lose_rises = []
win_chrips_centered_on_ag_off_t = []
win_chrips_centered_on_ag_on_t = []
win_chrips_centered_on_contact_t = []
win_chrips_centered_on_lose_rises = []
win_chrips_centered_on_lose_chirp = []
win_chirps_centered_on_win_rises = []
lose_rises_centered_on_ag_off_t = []
lose_rises_centered_on_ag_on_t = []
lose_rises_centered_on_contact_t = []
lose_rises_centered_on_win_chirps = []
win_rises_centered_on_ag_off_t = []
win_rises_centered_on_ag_on_t = []
win_rises_centered_on_contact_t = []
win_rises_centered_on_lose_chirps = []
ag_off_centered_on_ag_on = []
lose_chirp_count = []
win_chirp_count = []
lose_rises_count = []
win_rises_count = []
chase_count = []
contact_count = []
sex_win = []
sex_lose = []
for index, trial in tqdm(trial_summary.iterrows()):
trial_path = os.path.join(base_path, trial['recording'])
if trial['group'] < 5:
continue
if not os.path.exists(os.path.join(trial_path, 'led_idxs.csv')):
continue
if not os.path.exists(os.path.join(trial_path, 'LED_frames.npy')):
continue
if trial['draw'] == 1:
continue
ids = np.load(os.path.join(trial_path, 'analysis', 'ids.npy'))
times = np.load(os.path.join(trial_path, 'times.npy'))
sorter = -1 if trial['win_ID'] != ids[0] else 1
### event times --> BORIS behavior
contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames = \
load_and_converete_boris_events(trial_path, trial['recording'], sr=20_000)
### communication
if not os.path.exists(os.path.join(trial_path, 'chirp_times_cnn.npy')):
continue
chirp_t = np.load(os.path.join(trial_path, 'chirp_times_cnn.npy'))
chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]]
rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]
rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]]
### collect for correlations ####
# chirps
if trial_mask[index]:
lose_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[1]))
lose_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[1]))
lose_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[1]))
lose_chrips_centered_on_win_rises.append(event_centered_times(rise_times[0], chirp_times[1]))
lose_chrips_centered_on_win_chirp.append(event_centered_times(chirp_times[0], chirp_times[1]))
lose_chirps_centered_on_lose_rises.append(event_centered_times(rise_times[1], chirp_times[1]))
lose_chirp_count.append(len(chirp_times[1]))
win_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[0]))
win_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[0]))
win_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[0]))
win_chrips_centered_on_lose_rises.append(event_centered_times(rise_times[1], chirp_times[0]))
win_chrips_centered_on_lose_chirp.append(event_centered_times(chirp_times[1], chirp_times[0]))
win_chirps_centered_on_win_rises.append(event_centered_times(rise_times[0], chirp_times[0]))
win_chirp_count.append(len(chirp_times[0]))
else:
lose_chrips_centered_on_ag_off_t.append(np.array([]))
lose_chrips_centered_on_ag_on_t.append(np.array([]))
lose_chrips_centered_on_contact_t.append(np.array([]))
lose_chrips_centered_on_win_rises.append(np.array([]))
lose_chrips_centered_on_win_chirp.append(np.array([]))
lose_chirps_centered_on_lose_rises.append(np.array([]))
lose_chirp_count.append(np.nan)
win_chrips_centered_on_ag_off_t.append(np.array([]))
win_chrips_centered_on_ag_on_t.append(np.array([]))
win_chrips_centered_on_contact_t.append(np.array([]))
win_chrips_centered_on_lose_rises.append(np.array([]))
win_chrips_centered_on_lose_chirp.append(np.array([]))
win_chirp_count.append(np.nan)
# rises
lose_rises_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], rise_times[1]))
lose_rises_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], rise_times[1]))
lose_rises_centered_on_contact_t.append(event_centered_times(contact_t_GRID, rise_times[1]))
lose_rises_centered_on_win_chirps.append(event_centered_times(chirp_times[0], rise_times[1]))
lose_rises_count.append(len(rise_times[1]))
win_rises_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], rise_times[0]))
win_rises_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], rise_times[0]))
win_rises_centered_on_contact_t.append(event_centered_times(contact_t_GRID, rise_times[0]))
win_rises_centered_on_lose_chirps.append(event_centered_times(chirp_times[1], rise_times[0]))
win_rises_count.append(len(rise_times[0]))
ag_off_centered_on_ag_on.append(event_centered_times(ag_on_off_t_GRID[:, 0], ag_on_off_t_GRID[:, 1]))
chase_count.append(len(ag_on_off_t_GRID))
contact_count.append(len(contact_t_GRID))
sex_win.append(trial['sex_win'])
sex_lose.append(trial['sex_lose'])
sex_win = np.array(sex_win)
sex_lose = np.array(sex_lose)
# embed()
# quit()
max_dt = 30
conv_t_dt = 0.5
jack_pct = 0.9
conv_t = cp.arange(-max_dt, max_dt+conv_t_dt, conv_t_dt)
conv_t_numpy = cp.asnumpy(conv_t)
# embed()
# quit()
# for centered_times, event_counts, title in \
# [[lose_chrips_centered_on_ag_off_t, lose_chirp_count, r'chirp$_{lose}$ on chase$_{off}$'],
# [lose_chrips_centered_on_ag_on_t, lose_chirp_count, r'chirp$_{lose}$ on chase$_{on}$'],
# [lose_chrips_centered_on_contact_t, lose_chirp_count, r'chirp$_{lose}$ on contact'],
# [lose_chrips_centered_on_win_rises, lose_chirp_count, r'chirp$_{lose}$ on rise$_{win}$'],
# [lose_chrips_centered_on_win_chirp, lose_chirp_count, r'chirp$_{lose}$ on chirp$_{win}$'],
# [lose_chirps_centered_on_lose_rises, lose_chirp_count, r'chirp$_{lose}$ on rises$_{lose}$'],
#
# [win_chrips_centered_on_ag_off_t, win_chirp_count, r'chirp$_{win}$ on chase$_{off}$'],
# [win_chrips_centered_on_ag_on_t, win_chirp_count, r'chirp$_{win}$ on chase$_{on}$'],
# [win_chrips_centered_on_contact_t, win_chirp_count, r'chirp$_{win}$ on contact'],
# [win_chrips_centered_on_lose_rises, win_chirp_count, r'chirp$_{win}$ on rise$_{lose}$'],
# [win_chrips_centered_on_lose_chirp, win_chirp_count, r'chirp$_{win}$ on chirp$_{lose}$'],
# [win_chirps_centered_on_win_rises, win_chirp_count, r'chirp$_{win}$ on rises$_{win}$'],
#
# [lose_rises_centered_on_ag_off_t, lose_rises_count, r'rise$_{lose}$ on chase$_{off}$'],
# [lose_rises_centered_on_ag_on_t, lose_rises_count, r'rise$_{lose}$ on chase$_{on}$'],
# [lose_rises_centered_on_contact_t, lose_rises_count, r'rise$_{lose}$ on contact'],
# [lose_rises_centered_on_win_chirps, lose_rises_count, r'rise$_{lose}$ on chirp$_{win}$'],
#
# [win_rises_centered_on_ag_off_t, win_rises_count, r'rise$_{win}$ on chase$_{off}$'],
# [win_rises_centered_on_ag_on_t, win_rises_count, r'rise$_{win}$ on chase$_{on}$'],
# [win_rises_centered_on_contact_t, win_rises_count, r'rise$_{win}$ on contact'],
# [win_rises_centered_on_lose_chirps, win_rises_count, r'rise$_{win}$ on chirp$_{lose}$'],
#
# [ag_off_centered_on_ag_on, chase_count, r'chase$_{off}$ on chase$_{on}$']]:
for centered_times, event_counts, title in \
[[lose_chrips_centered_on_ag_off_t, chase_count, r'chirp$_{lose}$ on chase$_{off}$'],
[lose_chrips_centered_on_ag_on_t, chase_count, r'chirp$_{lose}$ on chase$_{on}$'],
[lose_chrips_centered_on_contact_t, contact_count, r'chirp$_{lose}$ on contact'],
[lose_chrips_centered_on_win_rises, win_rises_count, r'chirp$_{lose}$ on rise$_{win}$'],
[lose_chrips_centered_on_win_chirp, win_chirp_count, r'chirp$_{lose}$ on chirp$_{win}$'],
[lose_chirps_centered_on_lose_rises, lose_rises_count, r'chirp$_{lose}$ on rises$_{lose}$'],
[win_chrips_centered_on_ag_off_t, chase_count, r'chirp$_{win}$ on chase$_{off}$'],
[win_chrips_centered_on_ag_on_t, chase_count, r'chirp$_{win}$ on chase$_{on}$'],
[win_chrips_centered_on_contact_t, contact_count, r'chirp$_{win}$ on contact'],
[win_chrips_centered_on_lose_rises, lose_rises_count, r'chirp$_{win}$ on rise$_{lose}$'],
[win_chrips_centered_on_lose_chirp, lose_chirp_count, r'chirp$_{win}$ on chirp$_{lose}$'],
[win_chirps_centered_on_win_rises, win_rises_count, r'chirp$_{win}$ on rises$_{win}$'],
[lose_rises_centered_on_ag_off_t, chase_count, r'rise$_{lose}$ on chase$_{off}$'],
[lose_rises_centered_on_ag_on_t, chase_count, r'rise$_{lose}$ on chase$_{on}$'],
[lose_rises_centered_on_contact_t, contact_count, r'rise$_{lose}$ on contact'],
[lose_rises_centered_on_win_chirps, win_chirp_count, r'rise$_{lose}$ on chirp$_{win}$'],
[win_rises_centered_on_ag_off_t, chase_count, r'rise$_{win}$ on chase$_{off}$'],
[win_rises_centered_on_ag_on_t, chase_count, r'rise$_{win}$ on chase$_{on}$'],
[win_rises_centered_on_contact_t, contact_count, r'rise$_{win}$ on contact'],
[win_rises_centered_on_lose_chirps, lose_chirp_count, r'rise$_{win}$ on chirp$_{lose}$'],
[ag_off_centered_on_ag_on, chase_count, r'chase$_{off}$ on chase$_{on}$']]:
save_str = title.replace('$', '').replace('{', '').replace('}', '').replace(' ', '_')
###########################################################################################################
### by pairing ###
centered_times_pairing = []
for sex_w, sex_l in itertools.product(['m', 'f'], repeat=2):
centered_times_pairing.append([])
for i in range(len(centered_times)):
if sex_w == sex_win[i] and sex_l == sex_lose[i]:
centered_times_pairing[-1].append(centered_times[i])
event_counts_pairings = [np.nansum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'm')]),
np.nansum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'f')]),
np.nansum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'm')]),
np.nansum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'f')])]
color = [male_color, female_color, male_color, female_color]
linestyle = ['-', '--', '--', '-']
perm_p_pairings = []
jk_p_pairings = []
fig = plt.figure(figsize=(20/2.54, 12/2.54))
gs = gridspec.GridSpec(2, 2, left=0.1, bottom=0.1, right=0.95, top=0.9)
ax = []
ax.append(fig.add_subplot(gs[0, 0]))
ax.append(fig.add_subplot(gs[0, 1], sharey=ax[0]))
ax.append(fig.add_subplot(gs[1, 0], sharex=ax[0]))
ax.append(fig.add_subplot(gs[1, 1], sharey=ax[2], sharex=ax[1]))
for enu, (centered_times_p, event_count_p) in enumerate(zip(centered_times_pairing, event_counts_pairings)):
boot_kde = permutation_kde(np.hstack(centered_times_p), conv_t, kernal_w=1, kernal_h=1)
jk_kde = jackknife_kde(np.hstack(centered_times_p), conv_t, jack_pct=jack_pct, kernal_w=1, kernal_h=1)
perm_p1, perm_p50, perm_p99 = np.percentile(boot_kde, (1, 50, 99), axis=0)
perm_p_pairings.append([perm_p1, perm_p50, perm_p99])
jk_p1, jk_p50, jk_p99 = np.percentile(jk_kde, (1, 50, 99), axis=0)
jk_p_pairings.append([jk_p1, jk_p50, jk_p99])
ax[enu].fill_between(conv_t_numpy, perm_p1 / event_count_p, perm_p99 / event_count_p, color='cornflowerblue', alpha=.8)
ax[enu].plot(conv_t_numpy, perm_p50 / event_count_p, color='dodgerblue', alpha=1, lw=3)
ax[enu].fill_between(conv_t_numpy, jk_p1 / event_count_p / jack_pct, jk_p99 / event_count_p / jack_pct, color=color[enu], alpha=.8)
ax[enu].plot(conv_t_numpy, jk_p50 / event_count_p / jack_pct, color=color[enu], alpha=1, lw=3, linestyle=linestyle[enu])
ax_m = ax[enu].twinx()
counter = 0
for enu2, centered_events in enumerate(centered_times_p):
Cevents = centered_events[np.abs(centered_events) <= max_dt]
if len(Cevents) != 0:
ax_m.plot(Cevents, np.ones(len(Cevents)) * counter, '|', markersize=8, color='k', alpha=.1)
counter += 1
ax_m.set_yticks([])
ax[enu].set_xlim(-max_dt, max_dt)
ax[enu].tick_params(labelsize=10)
plt.setp(ax[1].get_yticklabels(), visible=False)
plt.setp(ax[3].get_yticklabels(), visible=False)
plt.setp(ax[0].get_xticklabels(), visible=False)
plt.setp(ax[1].get_xticklabels(), visible=False)
ax[2].set_xlabel('time [s]', fontsize=12)
ax[3].set_xlabel('time [s]', fontsize=12)
ax[0].set_ylabel('event rate [Hz]', fontsize=12)
ax[2].set_ylabel('event rate [Hz]', fontsize=12)
fig.suptitle(title)
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr', f'{save_str}_by_sexes.png'), dpi=300)
plt.close()
###########################################################################################################
### all pairings ###
boot_kde = permutation_kde(np.hstack(centered_times), conv_t, kernal_w=1, kernal_h=1)
jk_kde = jackknife_kde(np.hstack(centered_times), conv_t, jack_pct=jack_pct, kernal_w=1, kernal_h=1)
perm_p1, perm_p50, perm_p99 = np.percentile(boot_kde, (1, 50, 99), axis=0)
jk_p1, jk_p50, jk_p99 = np.percentile(jk_kde, (1, 50, 99), axis=0)
fig = plt.figure(figsize=(20/2.54, 12/2.54))
gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.95, top=0.95)
ax = fig.add_subplot(gs[0, 0])
# ax.fill_between(conv_t_numpy, perm_p1/len(np.hstack(centered_times)), perm_p99/len(np.hstack(centered_times)), color='cornflowerblue', alpha=.8)
# ax.plot(conv_t_numpy, perm_p50/len(np.hstack(centered_times)), color='dodgerblue', alpha=1, lw=3)
#
# ax.fill_between(conv_t_numpy, jk_p1/len(np.hstack(centered_times))/jack_pct, jk_p99/len(np.hstack(centered_times))/jack_pct, color='tab:red', alpha=.8)
# ax.plot(conv_t_numpy, jk_p50/len(np.hstack(centered_times))/jack_pct, color='firebrick', alpha=1, lw=3)
ax.fill_between(conv_t_numpy, perm_p1/np.nansum(event_counts), perm_p99/np.nansum(event_counts), color='cornflowerblue', alpha=.8)
ax.plot(conv_t_numpy, perm_p50/np.nansum(event_counts), color='dodgerblue', alpha=1, lw=3)
ax.fill_between(conv_t_numpy, jk_p1/np.nansum(event_counts)/jack_pct, jk_p99/np.nansum(event_counts)/jack_pct, color='tab:red', alpha=.8)
ax.plot(conv_t_numpy, jk_p50/np.nansum(event_counts)/jack_pct, color='firebrick', alpha=1, lw=3)
ax_m = ax.twinx()
for enu, centered_events in enumerate(centered_times):
Cevents = centered_events[np.abs(centered_events) <= max_dt]
if len(Cevents) != 0:
ax_m.plot(Cevents, np.ones(len(Cevents)) * counter, '|', markersize=8, color='k', alpha=.1)
counter += 1
# ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1)
ax_m.set_yticks([])
ax.set_xlabel('time [s]', fontsize=12)
ax.set_ylabel('event rate [Hz]', fontsize=12)
ax.set_title(title)
ax.set_xlim(-max_dt, max_dt)
ax.tick_params(labelsize=10)
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr', f'{save_str}.png'), dpi=300)
plt.close()
if __name__ == '__main__':
    # CLI entry point: sys.argv[1] is the dataset folder to analyse.
    main(sys.argv[1])

165
code/event_videos.py Normal file
View File

@@ -0,0 +1,165 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import sys
import cv2
import glob
import argparse
from IPython import embed
from tqdm import tqdm
from thunderfish.powerspectrum import decibel
def main(folder: str, dt: float) -> None:
    """Render a short video clip around every detected rise event.

    For each rise of each of the two fish, the surrounding camera frames are
    combined with both fishes' EOD-frequency traces (overlaid on a fine
    spectrogram) into per-frame figures, which are then stitched into an
    .mp4 file via ffmpeg and written to '<folder>/rise_video'.

    Parameters
    ----------
    folder : str
        Dataset folder containing the '2022*.mp4' video, an 'analysis'
        subfolder (fish_freq.npy, rise_idx.npy, frame_times.npy) and the
        fine-spectrogram files (times.npy, fill_freqs.npy, fill_times.npy,
        fill_spec_shape.npy, fill_spec.npy).
    dt : float
        Half-width [s] of the video-frame window extracted around each event.
    """
    video_path = glob.glob(os.path.join(folder, '2022*.mp4'))[0]
    create_video_path = os.path.join(folder, 'rise_video')
    if not os.path.exists(create_video_path):
        os.mkdir(create_video_path)
    video = cv2.VideoCapture(video_path)  # was 'cap'
    # fish_freqs = np.load(os.path.join(folder, 'analysis', 'fish_freq_interp.npy'))
    fish_freqs = np.load(os.path.join(folder, 'analysis', 'fish_freq.npy'))
    rise_idx = np.load(os.path.join(folder, 'analysis', 'rise_idx.npy'))
    frame_times = np.load(os.path.join(folder, 'analysis', 'frame_times.npy'))
    times = np.load(os.path.join(folder, 'times.npy'))
    fill_freqs = np.load(os.path.join(folder, 'fill_freqs.npy'))
    fill_times = np.load(os.path.join(folder, 'fill_times.npy'))
    fill_spec_shape = np.load(os.path.join(folder, 'fill_spec_shape.npy'))
    # Memory-mapped fine spectrogram (frequency x time); too large to load at once.
    fill_spec = np.memmap(os.path.join(folder, 'fill_spec.npy'), dtype='float', mode='r',
                          shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F')

    #######################################
    # Iterate fish 1 first, then fish 0 (reversed order); presumably
    # fish 0 = winner, fish 1 = loser -- TODO confirm against the analysis code.
    for fish_nr in np.arange(2)[::-1]:
        # One video per rise event of this fish (NaN entries are skipped).
        for idx_oi in tqdm(np.array(rise_idx[fish_nr][~np.isnan(rise_idx[fish_nr])], dtype=int)):
            time_oi = times[idx_oi]
            # Event time split into HH:MM:SS for the output file name.
            HH = int((time_oi / 3600) // 1)
            MM = int((time_oi - HH * 3600) // 60)
            SS = int(time_oi - HH * 3600 - MM * 60)

            # Camera frames within +-dt and frequency samples within +-3*dt
            # around the event.
            frames_oi = np.arange(len(frame_times))[np.abs(frame_times - time_oi) <= dt]
            idxs_oi = np.arange(len(times))[np.abs(times - time_oi) <= dt*3]

            fig = plt.figure(figsize=(16*2/2.54, 9*2/2.54))
            gs = gridspec.GridSpec(6, 2, left=0.075, bottom=0.05, right=1, top=0.95, width_ratios=(1.5, 3), hspace=.3, wspace=0.05)
            ax = []
            ax.append(fig.add_subplot(gs[:, 1]))    # right column: video frame
            ax.append(fig.add_subplot(gs[1:3, 0]))  # left top: fish 0 frequency trace
            ax.append(fig.add_subplot(gs[3:5, 0]))  # left bottom: fish 1 frequency trace

            # Frequency range of each fish in the window; enforce a minimum
            # span of 20 Hz so flat traces remain visible.
            y00, y01 = np.nanmin(fish_freqs[0][idxs_oi]), np.nanmax(fish_freqs[0][idxs_oi])
            y10, y11 = np.nanmin(fish_freqs[1][idxs_oi]), np.nanmax(fish_freqs[1][idxs_oi])
            if y01 - y00 < 20:
                y01 = y00 + 20
            if y11 - y10 < 20:
                y11 = y10 + 20
            freq_span1 = (y01) - (y00)
            freq_span2 = (y11) - (y10)
            # Use the larger span for both axes so they are scaled alike.
            yspan = freq_span1 if freq_span1 > freq_span2 else freq_span2

            ax[1].plot(times[idxs_oi] - time_oi, fish_freqs[0][idxs_oi], marker='.', markersize=4, color='darkorange', lw=2, alpha=0.4)
            ax[2].plot(times[idxs_oi] - time_oi, fish_freqs[1][idxs_oi], marker='.', markersize=4,color='forestgreen', lw=2, alpha=0.4)
            # Dashed vertical line marks the event time (t = 0).
            ax[1].plot([0, 0], [y00 - yspan * 0.2, y00 + yspan * 1.3], '--', color='k')
            ax[2].plot([0, 0], [y10 - yspan * 0.2, y10 + yspan * 1.3], '--', color='k')
            ax[1].set_xticks([-30, -15, 0, 15, 30])
            ax[2].set_xticks([-30, -15, 0, 15, 30])
            plt.setp(ax[1].get_xticklabels(), visible=False)

            # spectrograms: crop the memmapped spectrogram to each fish's
            # frequency band and the surrounding time window.
            f_mask1 = np.arange(len(fill_freqs))[(fill_freqs >= y00 - yspan * 0.2) & (fill_freqs <= y00 + yspan * 1.3)]
            f_mask2 = np.arange(len(fill_freqs))[(fill_freqs >= y10 - yspan * 0.2) & (fill_freqs <= y10 + yspan * 1.3)]
            t_mask = np.arange(len(fill_times))[(fill_times >= time_oi-dt*4) & (fill_times <= time_oi+dt*4)]

            ax[1].imshow(decibel(fill_spec[f_mask1[0]:f_mask1[-1], t_mask[0]:t_mask[-1]][::-1]),
                         extent=[-dt*4, dt*4, y00 - yspan * 0.2, y00 + yspan * 1.3],
                         aspect='auto',vmin = -100, vmax = -50, alpha=0.7, cmap='jet', interpolation='gaussian')
            ax[2].imshow(decibel(fill_spec[f_mask2[0]:f_mask2[-1], t_mask[0]:t_mask[-1]][::-1]),
                         extent=[-dt*4, dt*4, y10 - yspan * 0.2, y10 + yspan * 1.3],
                         aspect='auto',vmin = -100, vmax = -50, alpha=0.7, cmap='jet', interpolation='gaussian')

            ax[1].set_ylim(y00 - yspan * 0.1, y00 + yspan * 1.2)
            ax[1].set_xlim(-dt*3, dt*3)
            ax[2].set_ylim(y10 - yspan * 0.1, y10 + yspan * 1.2)
            ax[2].set_xlim(-dt*3, dt*3)

            ax[0].set_xticks([])
            ax[0].set_yticks([])
            ax[1].tick_params(labelsize=12)
            ax[2].tick_params(labelsize=12)
            ax[2].set_xlabel('time [s]', fontsize=14)
            fig.text(0.02, 0.5, 'frequency [Hz]', fontsize=14, va='center', rotation='vertical')

            # plt.ion()
            # Render one figure per video frame; a moving cursor line and the
            # current camera frame are updated in place.
            for i in tqdm(np.arange(len(frames_oi))):
                # NOTE(review): debug leftover -- this 'break' skips the whole
                # frame-rendering loop, so no frame images (and hence no
                # videos) are produced. Remove it to re-enable rendering.
                break
                video.set(cv2.CAP_PROP_POS_FRAMES, int(frames_oi[i]))
                ret, frame = video.read()
                # Red dot shown for frames 250-279; presumably a visual marker
                # around the event onset -- TODO confirm intent.
                if i == 250:
                    dot, = ax[0].plot(0.05, 0.95, 'o', color='firebrick', transform = ax[0].transAxes, markersize=20)
                if i == 280:
                    dot.remove()
                if i == 0:
                    # First frame: create the image and cursor-line artists.
                    img = ax[0].imshow(frame)
                    line1, = ax[1].plot([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
                                        [y00 - yspan * 0.15, y00 + yspan * 1.3],
                                        color='k', lw=1)
                    line2, = ax[2].plot([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
                                        [y10 - yspan * 0.15, y10 + yspan * 1.3],
                                        color='k', lw=1)
                else:
                    # Subsequent frames: update the existing artists in place.
                    img.set_data(frame)
                    line1.set_data([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
                                   [y00 - yspan * 0.15, y00 + yspan * 1.3])
                    line2.set_data([frame_times[frames_oi[i]] - time_oi, frame_times[frames_oi[i]] - time_oi],
                                   [y10 - yspan * 0.15, y10 + yspan * 1.3])
                # Sequentially numbered frame image; the replace(' ', '0')
                # turns '% 4.f' space padding into zero padding.
                label = (os.path.join(create_video_path, 'frame%4.f.jpg' % len(glob.glob(os.path.join(create_video_path, '*.jpg'))))).replace(' ', '0')
                plt.savefig(label, dpi=300)
                # plt.pause(0.001)
            # quit()

            win_lose_str = 'lose' if fish_nr == 1 else 'win'
            # video_name = ("./rise_video/%s_%2.f:%2.f:%2.f.mp4" % (win_lose_str, HH, MM, SS)).replace(' ', '0')
            # command = "ffmpeg -r 25 -i './rise_video/frame%4d.jpg' -vf 'pad=ceil(iw/2)*2:ceil(ih/2)*2' -vcodec libx264 -y -an"
            video_name = os.path.join(create_video_path, ("%s_%2.f:%2.f:%2.f.mp4" % (win_lose_str, HH, MM, SS)).replace(' ', '0'))
            # Stitch the frame images into an mp4 at 25 fps (the pad filter
            # makes width/height even, as required by libx264), then delete
            # the frame images again.
            command1 = "ffmpeg -r 25 -i"
            frames_path = '"%s"' % os.path.join(create_video_path, "frame%4d.jpg")
            command2 = "-vf 'pad=ceil(iw/2)*2:ceil(ih/2)*2' -vcodec libx264 -y -an"
            os.system(' '.join([command1, frames_path, command2, video_name]))
            os.system(' '.join(['rm', os.path.join(create_video_path, '*.jpg')]))
            # os.system(' '.join([command, video_name]))
            # os.system('rm ./rise_video/*.jpg')
            plt.close()

    # NOTE(review): debug leftover -- drops into an IPython shell and then
    # exits; everything below is unreachable.
    embed()
    quit()

    ###############################
    # Unreachable overview plot: full frequency traces of both fish with
    # their rise events marked.
    fig, ax = plt.subplots()
    for i, c in enumerate(['firebrick', 'cornflowerblue']):
        ax.plot(times, fish_freqs[i], marker='.', color=c)
        r_idx = np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int)
        ax.plot(times[r_idx], fish_freqs[i][r_idx], 'o', color='k')
        pass
    ##############################
if __name__ == '__main__':
    # Command-line interface: <file> is the dataset folder; -t sets the
    # half-width (in seconds) of the video window around each event.
    cli = argparse.ArgumentParser(description='Generate videos around events.')
    cli.add_argument('file', type=str, help='folder/dataset to generate videos from.')
    cli.add_argument('-t', type=float, default=10, help='video duration before and after event.')
    cli_args = cli.parse_args()
    main(cli_args.file, cli_args.t)

Binary file not shown.

After

Width:  |  Height:  |  Size: 300 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 299 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 317 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 262 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 301 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 280 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 319 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 278 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 280 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 312 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 284 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 282 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 338 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 279 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 329 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 271 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 300 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 293 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 317 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 296 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 282 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 270 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 281 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 264 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 279 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 257 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 256 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 331 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 296 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 263 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 321 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 309 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 276 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 355 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 261 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 344 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 318 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 326 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 316 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 250 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 287 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 271 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 293 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 72 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 272 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 302 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 296 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 324 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 355 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 433 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 376 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 379 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 317 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 322 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 308 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 332 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 338 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 314 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 361 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 375 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 236 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 175 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 222 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 90 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 211 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 224 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 65 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 207 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 205 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 206 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 184 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 184 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 226 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 224 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 173 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 227 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 180 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 186 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 175 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 222 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 216 KiB

Some files were not shown because too many files have changed in this diff Show More