import os import sys import argparse import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import numpy as np import pandas as pd from thunderfish.eventdetection import detect_peaks from IPython import embed class Trial(object): def __init__(self, folder, base_path, meta, fish_count): self._isValid = False self.base_path = base_path self.folder = folder self.meta = meta self.fish_count = fish_count self.light_sec = 3 * 60 * 60 self.ids = None self.fish_freq = None self.fish_freq_interp = None self.fish_freq_val = None self.baseline_freq_times = None self.baseline_freqs = None self.rise_idxs = [] self.rise_size = [] self.fish_sign = None self.fish_sign_interp = None self.winner = None self.loser = None if os.path.exists(os.path.join(self.base_path, self.folder, 'fund_v.npy')): self.load() def __repr__(self): return f'Trial(Date={self.folder}, winner={self.winner})' # return self.folder def load(self): self.fund_v = np.load(os.path.join(self.base_path, self.folder, 'fund_v.npy')) self.idx_v = np.load(os.path.join(self.base_path, self.folder, 'idx_v.npy')) self.times = np.load(os.path.join(self.base_path, self.folder, 'times.npy')) self.ident_v = np.load(os.path.join(self.base_path, self.folder, 'ident_v.npy')) self.sign_v = np.load(os.path.join(self.base_path, self.folder, 'sign_v.npy')) self.ids = np.unique(self.ident_v[~np.isnan(self.ident_v)]) if len(self.ids) == self.fish_count: self.isValid = True def reshape_and_interpolate(self): self.fish_freq = np.full((self.fish_count, len(self.times)), np.nan) self.fish_sign = np.full((self.fish_count, len(self.times), self.sign_v.shape[1]), np.nan) for enu, id in enumerate(self.ids): self.fish_freq[enu][self.idx_v[self.ident_v == id]] = self.fund_v[self.ident_v == id] self.fish_sign[enu][self.idx_v[self.ident_v == id]] = self.sign_v[self.ident_v == id] self.fish_freq_interp = np.full(self.fish_freq.shape, np.nan) self.fish_sign_interp = np.full(self.fish_sign.shape, np.nan) for enu, id in enumerate(self.ids): i0, i1 = self.idx_v[self.ident_v == id][0], self.idx_v[self.ident_v == id][-1] # self.fish_freq_interp[enu, i0:i1+1] = np.interp(self.times[i0:i1+1], # self.times[self.idx_v[self.ident_v == id]], # self.fish_freq[enu][~np.isnan(self.fish_freq[enu])]) self.fish_freq_interp[enu, i0:i1+1] = np.interp(self.times[i0:i1+1], self.times[self.idx_v[self.ident_v == id]], self.fund_v[self.ident_v == id]) # help_sign_v = list(map(lambda x: np.interp(self.times[i0:i1+1], self.times[self.idx_v[self.ident_v == id]], x), # self.fish_sign[enu][~np.isnan(self.fish_freq[enu])].T)) help_sign_v = list(map(lambda x: np.interp(self.times[i0:i1+1], self.times[self.idx_v[self.ident_v == id]], x), self.sign_v[self.ident_v == id].T)) self.fish_sign_interp[enu, i0:i1+1] = np.array(help_sign_v).T def baseline_freq(self, bw = 300): bins = np.arange(-bw / 2, self.times[-1] + bw / 2, bw) self.baseline_freq_times = np.array(bins[:-1] + (bins[1] - bins[0])/2) self.baseline_freqs = np.full((2, len(self.baseline_freq_times)), np.nan) self.pct95_freqs = np.full((2, len(self.baseline_freq_times)), np.nan) for enu, id in enumerate(self.ids): for i in range(len(bins) - 1): Cf = self.fish_freq[enu][(self.times > bins[i]) & (self.times <= bins[i + 1])] if len(Cf) == 0: continue else: self.baseline_freqs[enu][i] = np.nanpercentile(Cf, 5) self.pct95_freqs[enu][i] = np.nanpercentile(Cf, 75) self.fish_freq_val = [np.nanmean(x[self.baseline_freq_times > self.light_sec]) for x in self.baseline_freqs] def winner_detection(self): day_mask = self.times > self.light_sec day_idxs = np.arange(len(self.times))[day_mask] shelter_power = np.empty((2, len(day_idxs))) for enu, id in enumerate(self.ids): shelter_power[enu] = self.fish_sign_interp[enu][day_idxs, -1] mean_shelter_power = np.nanmean(shelter_power, axis=1) self.winner = 1 if mean_shelter_power[1] > mean_shelter_power[0] else 0 self.loser = 0 if self.winner == 1 else 1 def rise_detection(self, rise_th): def check_rises_size(peak): peak_f = self.fish_freq[i][peak] peak_t = self.times[peak] closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(self.baseline_freq_times - x)), peak_t)) closest_baseline_freq = self.baseline_freqs[i][closest_baseline_idx] rise_size = peak_f - closest_baseline_freq return rise_size def correct_rise_idx(rise_peak_idx): rise_dt = np.diff(self.times[rise_peak_idx]) rise_dt[rise_dt >= 10] = 10 rise_dt[rise_dt < 10] = rise_dt[rise_dt < 10] - 1 rise_dt = np.append(np.array([10]), rise_dt) freq_slope = np.full(np.shape(self.fish_freq)[1], np.nan) non_nan_idx = np.arange(len(freq_slope))[~np.isnan(self.fish_freq[i])] freq_slope[non_nan_idx[1:]] = np.diff(self.fish_freq[i][~np.isnan(self.fish_freq[i])]) corrected_rise_idxs = [] for enu, r_idx in enumerate(rise_peak_idx): mask = np.arange(len(freq_slope))[(self.times <= self.times[r_idx]) & (self.times > self.times[r_idx] - rise_dt[enu]) & (~np.isnan(freq_slope))] if len(mask) == 0: corrected_rise_idxs.append(np.nan) else: corrected_rise_idxs.append(mask[np.argmax(freq_slope[mask])]) corrected_rise_idxs = np.array(corrected_rise_idxs) return corrected_rise_idxs for i in range(len(self.fish_freq)): rise_peak_idx, trough = detect_peaks(self.fish_freq[i][~np.isnan(self.fish_freq[i])], rise_th) non_nan_idx = np.arange(len(self.fish_freq[i]))[~np.isnan(self.fish_freq[i])] rise_peak_idx, trough = non_nan_idx[rise_peak_idx], non_nan_idx[trough] rise_size = check_rises_size(rise_peak_idx) rise_idx = correct_rise_idx(rise_peak_idx) # print(np.min(np.diff(self.times[rise_peak_idx]))) self.rise_idxs.append(np.array(rise_idx[(rise_size >= rise_th) & (~np.isnan(rise_idx))], dtype=int)) self.rise_size.append(rise_size[(rise_size >= rise_th) & (~np.isnan(rise_idx))]) def update_meta(self): entries = self.meta.index.tolist() if self. folder not in entries: self.meta.loc[self.folder] = ['' for _ in self.meta.columns] self.meta.loc[self.folder, 'Win_ID'] = self.ids[self.winner] self.meta.loc[self.folder, 'Lose_ID'] = self.ids[self.loser] self.meta.loc[self.folder, 'Win_EODf'] = self.fish_freq_val[self.winner] self.meta.loc[self.folder, 'Lose_EODf'] = self.fish_freq_val[self.loser] self.meta.loc[self.folder, 'Win_rise_c'] = len(self.rise_idxs[self.winner]) self.meta.loc[self.folder, 'Lose_rise_c'] = len(self.rise_idxs[self.loser]) self.meta.loc[self.folder, 'light_sec'] = self.light_sec self.meta.to_csv(os.path.join(self.base_path, 'meta.csv'), sep =',') def ilustrate(self): fig = plt.figure(figsize=(20/2.54, 12/2.54)) gs = gridspec.GridSpec(1, 1, left = 0.1, bottom = 0.1, right = 0.95, top = 0.95) ax = fig.add_subplot(gs[0, 0]) for enu, id in enumerate(self.ids): c = 'firebrick' if self.winner == enu else 'forestgreen' ax.plot(self.times/3600, self.fish_freq[enu], marker='.', color=c, zorder=1) ax.plot(self.times[np.isnan(self.fish_freq[enu])]/3600, self.fish_freq_interp[enu][np.isnan(self.fish_freq[enu])], '.', zorder=1, color=c, alpha=0.25) ax.plot(self.baseline_freq_times/3600, self.baseline_freqs[enu], '--', color='k', zorder=2) ax.plot(self.baseline_freq_times/3600, self.pct95_freqs[enu], '--', color='k', zorder=2) ax.plot(self.times[self.rise_idxs[enu]]/3600, self.fish_freq_interp[enu][self.rise_idxs[enu]], 'o', color='k') win_str = '(W)' if self.winner == enu else '' ax.text(self.times[-1]/3600, self.fish_freq_val[enu]-10, '%.0f' % id + win_str, va ='center', ha='right') ax.set_xlim(0, self.times[-1]/3600) freq_range = (np.nanmin(self.fish_freq), np.nanmax(self.fish_freq)) ax.set_ylim(freq_range[0] - 20, freq_range[1] + 10) ax.set_title(self.folder) plt.show() def save(self): saveorder = -1 if self.winner == 1 else 1 if not os.path.exists(os.path.join(self.base_path, self.folder, 'analysis')): os.mkdir(os.path.join(self.base_path, self.folder, 'analysis')) np.save(os.path.join(self.base_path, self.folder, 'analysis', 'ids.npy'), self.ids[::saveorder]) np.save(os.path.join(self.base_path, self.folder, 'analysis', 'fish_freq.npy'), self.fish_freq[::saveorder]) np.save(os.path.join(self.base_path, self.folder, 'analysis', 'fish_freq_interp.npy'), self.fish_freq_interp[::saveorder]) np.save(os.path.join(self.base_path, self.folder, 'analysis', 'baseline_freqs.npy'), self.baseline_freqs[::saveorder]) np.save(os.path.join(self.base_path, self.folder, 'analysis', 'baseline_freq_times.npy'), self.baseline_freq_times[::saveorder]) help_lens = [len(x) for x in self.rise_idxs] rise_idxs_s = np.full((self.fish_count, np.max(help_lens)), np.nan) rise_size_s = np.full((self.fish_count, np.max(help_lens)), np.nan) for i in range(self.fish_count): rise_idxs_s[i][:len(self.rise_idxs[i])] = self.rise_idxs[i] rise_size_s[i][:len(self.rise_size[i])] = self.rise_size[i] np.save(os.path.join(self.base_path, self.folder, 'analysis', 'rise_idx.npy'), rise_idxs_s[::saveorder]) np.save(os.path.join(self.base_path, self.folder, 'analysis', 'rise_size.npy'), rise_size_s[::saveorder]) @property def isValid(self): return self._isValid @isValid.setter def isValid(self, value): print('Trial (%s) is valid' % (self.folder)) self._isValid = value def frame_to_idx(self, event_frames): self.sr = 20000 LED_idx = pd.read_csv(os.path.join(self.folder, 'led_idxs.csv'), sep=',', encoding = "utf-7") led_idx = np.array(LED_idx).T[0] led_frame = np.load(os.path.join(self.folder, 'LED_frames.npy')) led_idx_span = led_idx[-1] - led_idx[0] led_frame_span = led_frame[-1] - led_frame[0] frames_to_idx = ((event_frames - led_frame[0]) / led_frame_span) * led_idx_span + led_idx[0] event_times = frames_to_idx / self.sr return event_times def main(): parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') parser.add_argument('file', type=str, help='single recording analysis', default='') parser.add_argument('-d', "--dev", action="store_true", help="developer mode; no data saved") # parser.add_argument('-x', type=int, nargs=2, default=[1272, 1282], help='x-borders of LED detect area (in pixels)') # parser.add_argument('-y', type=int, nargs=2, default=[1500, 1516], help='y-borders of LED area (in pixels)') args = parser.parse_args() base_path = None folders = [] for root, dirs, files in os.walk(args.file): for file in files: if file.endswith('.raw'): root = os.path.normpath(root) print(root, file) print(os.path.join(root, file)) folders.append(os.path.split(root)[-1]) if not base_path: base_path = os.path.split(root)[0] folders = sorted(folders) if os.path.exists(os.path.join(base_path, 'meta.csv')) and not args.dev: meta = pd.read_csv(os.path.join(base_path, 'meta.csv'), sep=',', index_col=0, encoding = "utf-7") else: meta = None # embed() # if args.f == '': # folders = os.listdir(args.f) # folders = [x for x in folders if not '.' in x] # else: # folders= [os.path.split(os.path.normpath(args.f))[-1]] # folders = sorted(folders) trials = [] for folder in folders: trial = Trial(folder, base_path, meta, fish_count=2) if not trial.isValid: continue trial.reshape_and_interpolate() trial.winner_detection() trial.baseline_freq(bw=300) # ToDo: q10 corrected EODfs trial.rise_detection(rise_th=5) if meta is not None: if not args.dev: trial.update_meta() if not args.dev: trial.save() trial.ilustrate() trials.append(trial) # meta.loc[folder, 'Fish1_ID'] = 1 # meta.to_csv('') if __name__ == '__main__': main()