competition_experiments/code/trail_analysis.py

321 lines
14 KiB
Python

import os
import sys
import argparse
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
from thunderfish.eventdetection import detect_peaks
from IPython import embed
class Trial(object):
def __init__(self, folder, base_path, meta, fish_count):
self._isValid = False
self.base_path = base_path
self.folder = folder
self.meta = meta
self.fish_count = fish_count
self.light_sec = 3 * 60 * 60
self.ids = None
self.fish_freq = None
self.fish_freq_interp = None
self.fish_freq_val = None
self.baseline_freq_times = None
self.baseline_freqs = None
self.rise_idxs = []
self.rise_size = []
self.fish_sign = None
self.fish_sign_interp = None
self.winner = None
self.loser = None
self.mean_shelter_power = None
if os.path.exists(os.path.join(self.base_path, self.folder, 'fund_v.npy')):
self.load()
def __repr__(self):
return f'Trial(Date={self.folder}, winner={self.winner})'
# return self.folder
def load(self):
self.fund_v = np.load(os.path.join(self.base_path, self.folder, 'fund_v.npy'))
self.idx_v = np.load(os.path.join(self.base_path, self.folder, 'idx_v.npy'))
self.times = np.load(os.path.join(self.base_path, self.folder, 'times.npy'))
self.ident_v = np.load(os.path.join(self.base_path, self.folder, 'ident_v.npy'))
self.sign_v = np.load(os.path.join(self.base_path, self.folder, 'sign_v.npy'))
self.ids = np.unique(self.ident_v[~np.isnan(self.ident_v)])
if len(self.ids) == self.fish_count:
self.isValid = True
def reshape_and_interpolate(self):
self.fish_freq = np.full((self.fish_count, len(self.times)), np.nan)
self.fish_sign = np.full((self.fish_count, len(self.times), self.sign_v.shape[1]), np.nan)
for enu, id in enumerate(self.ids):
self.fish_freq[enu][self.idx_v[self.ident_v == id]] = self.fund_v[self.ident_v == id]
self.fish_sign[enu][self.idx_v[self.ident_v == id]] = self.sign_v[self.ident_v == id]
self.fish_freq_interp = np.full(self.fish_freq.shape, np.nan)
self.fish_sign_interp = np.full(self.fish_sign.shape, np.nan)
for enu, id in enumerate(self.ids):
i0, i1 = self.idx_v[self.ident_v == id][0], self.idx_v[self.ident_v == id][-1]
# self.fish_freq_interp[enu, i0:i1+1] = np.interp(self.times[i0:i1+1],
# self.times[self.idx_v[self.ident_v == id]],
# self.fish_freq[enu][~np.isnan(self.fish_freq[enu])])
self.fish_freq_interp[enu, i0:i1+1] = np.interp(self.times[i0:i1+1],
self.times[self.idx_v[self.ident_v == id]],
self.fund_v[self.ident_v == id])
# help_sign_v = list(map(lambda x: np.interp(self.times[i0:i1+1], self.times[self.idx_v[self.ident_v == id]], x),
# self.fish_sign[enu][~np.isnan(self.fish_freq[enu])].T))
help_sign_v = list(map(lambda x: np.interp(self.times[i0:i1+1], self.times[self.idx_v[self.ident_v == id]], x),
self.sign_v[self.ident_v == id].T))
self.fish_sign_interp[enu, i0:i1+1] = np.array(help_sign_v).T
def baseline_freq(self, bw = 300):
bins = np.arange(-bw / 2, self.times[-1] + bw / 2, bw)
self.baseline_freq_times = np.array(bins[:-1] + (bins[1] - bins[0])/2)
self.baseline_freqs = np.full((2, len(self.baseline_freq_times)), np.nan)
self.pct95_freqs = np.full((2, len(self.baseline_freq_times)), np.nan)
for enu, id in enumerate(self.ids):
for i in range(len(bins) - 1):
Cf = self.fish_freq[enu][(self.times > bins[i]) & (self.times <= bins[i + 1])]
if len(Cf) == 0:
continue
else:
self.baseline_freqs[enu][i] = np.nanpercentile(Cf, 5)
self.pct95_freqs[enu][i] = np.nanpercentile(Cf, 75)
self.fish_freq_val = [np.nanmean(x[self.baseline_freq_times > self.light_sec]) for x in self.baseline_freqs]
def winner_detection(self):
day_mask = self.times > self.light_sec
day_idxs = np.arange(len(self.times))[day_mask]
shelter_power = np.empty((2, len(day_idxs)))
for enu, id in enumerate(self.ids):
shelter_power[enu] = self.fish_sign_interp[enu][day_idxs, -1]
self.mean_shelter_power = np.nanmean(shelter_power, axis=1)
self.winner = 1 if self.mean_shelter_power[1] > self.mean_shelter_power[0] else 0
self.loser = 0 if self.winner == 1 else 1
def rise_detection(self, rise_th):
def check_rises_size(peak):
peak_f = self.fish_freq[i][peak]
peak_t = self.times[peak]
closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(self.baseline_freq_times - x)), peak_t))
closest_baseline_freq = self.baseline_freqs[i][closest_baseline_idx]
rise_size = peak_f - closest_baseline_freq
return rise_size
def correct_rise_idx(rise_peak_idx):
rise_dt = np.diff(self.times[rise_peak_idx])
rise_dt[rise_dt >= 10] = 10
rise_dt[rise_dt < 10] = rise_dt[rise_dt < 10] - 1
rise_dt = np.append(np.array([10]), rise_dt)
freq_slope = np.full(np.shape(self.fish_freq)[1], np.nan)
non_nan_idx = np.arange(len(freq_slope))[~np.isnan(self.fish_freq[i])]
freq_slope[non_nan_idx[1:]] = np.diff(self.fish_freq[i][~np.isnan(self.fish_freq[i])])
corrected_rise_idxs = []
for enu, r_idx in enumerate(rise_peak_idx):
mask = np.arange(len(freq_slope))[(self.times <= self.times[r_idx]) &
(self.times > self.times[r_idx] - rise_dt[enu]) &
(~np.isnan(freq_slope))]
if len(mask) == 0:
corrected_rise_idxs.append(np.nan)
else:
corrected_rise_idxs.append(mask[np.argmax(freq_slope[mask])])
corrected_rise_idxs = np.array(corrected_rise_idxs)
return corrected_rise_idxs
for i in range(len(self.fish_freq)):
rise_peak_idx, trough = detect_peaks(self.fish_freq[i][~np.isnan(self.fish_freq[i])], rise_th)
non_nan_idx = np.arange(len(self.fish_freq[i]))[~np.isnan(self.fish_freq[i])]
rise_peak_idx, trough = non_nan_idx[rise_peak_idx], non_nan_idx[trough]
rise_size = check_rises_size(rise_peak_idx)
rise_idx = correct_rise_idx(rise_peak_idx)
# print(np.min(np.diff(self.times[rise_peak_idx])))
self.rise_idxs.append(np.array(rise_idx[(rise_size >= rise_th) & (~np.isnan(rise_idx))], dtype=int))
self.rise_size.append(rise_size[(rise_size >= rise_th) & (~np.isnan(rise_idx))])
def update_meta(self):
entries = self.meta.index.tolist()
if self. folder not in entries:
self.meta.loc[self.folder] = ['' for _ in self.meta.columns]
self.meta.loc[self.folder, 'Win_ID'] = self.ids[self.winner]
self.meta.loc[self.folder, 'Lose_ID'] = self.ids[self.loser]
self.meta.loc[self.folder, 'Win_EODf'] = self.fish_freq_val[self.winner]
self.meta.loc[self.folder, 'Lose_EODf'] = self.fish_freq_val[self.loser]
self.meta.loc[self.folder, 'Win_rise_c'] = len(self.rise_idxs[self.winner])
self.meta.loc[self.folder, 'Lose_rise_c'] = len(self.rise_idxs[self.loser])
self.meta.loc[self.folder, 'light_sec'] = self.light_sec
self.meta.to_csv(os.path.join(self.base_path, 'meta.csv'), sep =',')
def ilustrate(self):
fig = plt.figure(figsize=(20/2.54, 12/2.54))
gs = gridspec.GridSpec(1, 1, left = 0.1, bottom = 0.1, right = 0.95, top = 0.95)
ax = fig.add_subplot(gs[0, 0])
for enu, id in enumerate(self.ids):
c = 'firebrick' if self.winner == enu else 'forestgreen'
ax.plot(self.times/3600, self.fish_freq[enu], marker='.', color=c, zorder=1, label=f'{self.mean_shelter_power[enu]:.2f}dB')
ax.plot(self.times[np.isnan(self.fish_freq[enu])]/3600, self.fish_freq_interp[enu][np.isnan(self.fish_freq[enu])], '.', zorder=1, color=c, alpha=0.25)
ax.plot(self.baseline_freq_times/3600, self.baseline_freqs[enu], '--', color='k', zorder=2)
ax.plot(self.baseline_freq_times/3600, self.pct95_freqs[enu], '--', color='k', zorder=2)
ax.plot(self.times[self.rise_idxs[enu]]/3600, self.fish_freq_interp[enu][self.rise_idxs[enu]], 'o', color='k')
win_str = '(W)' if self.winner == enu else ''
ax.text(self.times[-1]/3600, self.fish_freq_val[enu]-10, '%.0f' % id + win_str, va ='center', ha='right')
ax.set_xlim(0, self.times[-1]/3600)
freq_range = (np.nanmin(self.fish_freq), np.nanmax(self.fish_freq))
ax.set_ylim(freq_range[0] - 20, freq_range[1] + 10)
ax.legend(loc = 'upper right', bbox_to_anchor=(1, 1))
ax.set_title(self.folder)
plt.show()
def save(self):
saveorder = -1 if self.winner == 1 else 1
if not os.path.exists(os.path.join(self.base_path, self.folder, 'analysis')):
os.mkdir(os.path.join(self.base_path, self.folder, 'analysis'))
np.save(os.path.join(self.base_path, self.folder, 'analysis', 'ids.npy'), self.ids[::saveorder])
np.save(os.path.join(self.base_path, self.folder, 'analysis', 'fish_freq.npy'), self.fish_freq[::saveorder])
np.save(os.path.join(self.base_path, self.folder, 'analysis', 'fish_freq_interp.npy'), self.fish_freq_interp[::saveorder])
np.save(os.path.join(self.base_path, self.folder, 'analysis', 'baseline_freqs.npy'), self.baseline_freqs[::saveorder])
np.save(os.path.join(self.base_path, self.folder, 'analysis', 'baseline_freq_times.npy'), self.baseline_freq_times[::saveorder])
help_lens = [len(x) for x in self.rise_idxs]
rise_idxs_s = np.full((self.fish_count, np.max(help_lens)), np.nan)
rise_size_s = np.full((self.fish_count, np.max(help_lens)), np.nan)
for i in range(self.fish_count):
rise_idxs_s[i][:len(self.rise_idxs[i])] = self.rise_idxs[i]
rise_size_s[i][:len(self.rise_size[i])] = self.rise_size[i]
np.save(os.path.join(self.base_path, self.folder, 'analysis', 'rise_idx.npy'), rise_idxs_s[::saveorder])
np.save(os.path.join(self.base_path, self.folder, 'analysis', 'rise_size.npy'), rise_size_s[::saveorder])
@property
def isValid(self):
return self._isValid
@isValid.setter
def isValid(self, value):
print('Trial (%s) is valid' % (self.folder))
self._isValid = value
def frame_to_idx(self, event_frames):
self.sr = 20000
LED_idx = pd.read_csv(os.path.join(self.folder, 'led_idxs.csv'), sep=',', encoding = "utf-7")
led_idx = np.array(LED_idx).T[0]
led_frame = np.load(os.path.join(self.folder, 'LED_frames.npy'))
led_idx_span = led_idx[-1] - led_idx[0]
led_frame_span = led_frame[-1] - led_frame[0]
frames_to_idx = ((event_frames - led_frame[0]) / led_frame_span) * led_idx_span + led_idx[0]
event_times = frames_to_idx / self.sr
return event_times
def main():
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
parser.add_argument('file', type=str, help='single recording analysis', default='')
parser.add_argument('-d', "--dev", action="store_true", help="developer mode; no data saved")
# parser.add_argument('-x', type=int, nargs=2, default=[1272, 1282], help='x-borders of LED detect area (in pixels)')
# parser.add_argument('-y', type=int, nargs=2, default=[1500, 1516], help='y-borders of LED area (in pixels)')
args = parser.parse_args()
base_path = None
folders = []
for root, dirs, files in os.walk(args.file):
for file in files:
if file.endswith('.raw'):
root = os.path.normpath(root)
print(root, file)
print(os.path.join(root, file))
folders.append(os.path.split(root)[-1])
if not base_path:
base_path = os.path.split(root)[0]
folders = sorted(folders)
if os.path.exists(os.path.join(base_path, 'meta.csv')) and not args.dev:
meta = pd.read_csv(os.path.join(base_path, 'meta.csv'), sep=',', index_col=0, encoding = "utf-7")
else:
meta = None
# embed()
# if args.f == '':
# folders = os.listdir(args.f)
# folders = [x for x in folders if not '.' in x]
# else:
# folders= [os.path.split(os.path.normpath(args.f))[-1]]
# folders = sorted(folders)
trials = []
for folder in folders:
trial = Trial(folder, base_path, meta, fish_count=2)
if not trial.isValid:
continue
trial.reshape_and_interpolate()
trial.winner_detection()
trial.baseline_freq(bw=300)
# ToDo: q10 corrected EODfs
trial.rise_detection(rise_th=5)
if meta is not None:
if not args.dev:
trial.update_meta()
if not args.dev:
trial.save()
trial.ilustrate()
trials.append(trial)
# meta.loc[folder, 'Fish1_ID'] = 1
# meta.to_csv('')
if __name__ == '__main__':
main()