"""Detect fish that swim through the electrode array and plot their transits.

For every recording day found in ``../data/`` this script:

1. loads the per-fish arrays (max-channel trace, time ticks, power, freq, species),
2. keeps fish whose mean power is >= -90, smooths their max-channel trace with a
   boxcar kernel and interpolates it onto the day's common time axis,
3. flags a fish as a "transit" when its trajectory touches one electrode region
   near the start of its presence and another region near the end
   (see ``TRANSIT_ROUTES``),
4. plots when each transit happened plus zoomed trajectory panels, saving the
   figure to the thesis figure folder,
5. computes a transit-speed summary and drops into an IPython shell.
"""
import datetime
import itertools
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from IPython import embed
from scipy import stats

import helper_functions as hf
from params import *
from statisitic_functions import *


# Accepted transit routes as (start_lo, start_hi, end_lo, end_hi).
# A fish counts as a transit if, for at least one route, its trajectory just
# after it appears satisfies the "start" bounds and its trajectory just before
# it disappears satisfies the "end" bounds. ``None`` means "no bound".
TRANSIT_ROUTES = (
    (13, None, None, 2),
    (None, 2, 13, None),
    (None, 1, 6, 7),
    (6, 7, None, 1),
    (14, None, 8, 9),
    (8, 9, 14, None),
)


def _touches(segment, lo, hi):
    """Return True if ``segment`` has some value >= ``lo`` and some value <= ``hi``.

    A bound of ``None`` is not checked. An empty segment fails any bound.
    The two bounds are tested independently, exactly as the original analysis did.
    NOTE(review): this is NOT "some value lies in [lo, hi]". For example, a segment
    [5, 8] passes lo=6, hi=7. Confirm whether a true range test was intended.
    """
    if lo is not None and not np.any(segment >= lo):
        return False
    if hi is not None and not np.any(segment <= hi):
        return False
    return True


def _is_transit(trajectory, fish_x, start_cut, stop_cut):
    """Return True if ``trajectory`` matches any route in ``TRANSIT_ROUTES``.

    ``head`` is the part of the trajectory before ``start_cut``; ``tail`` is the
    part after ``stop_cut``. Both cut-offs are in the same units as ``fish_x``.
    """
    head = trajectory[fish_x < start_cut]
    tail = trajectory[fish_x > stop_cut]
    return any(_touches(head, s_lo, s_hi) and _touches(tail, e_lo, e_hi)
               for s_lo, s_hi, e_lo, e_hi in TRANSIT_ROUTES)


if __name__ == '__main__':
    ###############################################################################################################
    # parameter and variables

    # plot params: figure sizes below are given in cm, matplotlib wants inches
    inch = 2.54  # cm per inch (was mistyped as 2.45)
    save_path = '../../thesis/Figures/Results/'
    data_dir = '../data/'

    # boxcar kernel used to smooth the max-channel trace
    kernel_size = 120
    kernel = np.ones(kernel_size) / kernel_size

    # row counter for the transit-time panel (ax0)
    ax0_counter = 0

    # transit: start, stop, dur, day_idx
    transit_data = []
    # transit: time, trajectory
    xy_transit = []
    # transit: metadata: name, freq
    transit_meta = []
    # average speed of every analysed fish (electrode changes per minute)
    speed = []

    ###############################################################################################################
    # load data and detect transits
    day_dirs = sorted(os.listdir(data_dir))
    for day_idx in [0, 1, 2, 3]:
        filename = day_dirs[day_idx]
        day_path = os.path.join(data_dir, filename)

        aifl = np.load(os.path.join(day_path, 'aifl2.npy'), allow_pickle=True)
        all_xticks = np.load(os.path.join(day_path, 'all_xtickses.npy'), allow_pickle=True)
        all_max_ch_means = np.load(os.path.join(day_path, 'all_max_ch.npy'), allow_pickle=True)
        power_means = np.load(os.path.join(day_path, 'power_means.npy'), allow_pickle=True)
        all_hms = np.load(os.path.join(day_path, 'all_hms.npy'), allow_pickle=True)
        freq = np.load(os.path.join(day_path, 'fish_freq_q10.npy'), allow_pickle=True)
        names = np.load(os.path.join(day_path, 'fish_species.npy'), allow_pickle=True)

        fish_in_aifl = list(np.unique(np.where(~np.isnan(aifl[:, :, 0]))[1]))
        # common time axis of the day; every fish trajectory is interpolated onto it
        flat_x = np.unique(np.hstack(all_xticks))

        for fish_number in range(len(power_means)):
            # `not >=` (instead of `<`) so that NaN power is skipped, as before
            if not power_means[fish_number] >= -90.0:
                continue

            x_tickses = all_xticks[fish_number]
            max_ch_mean = all_max_ch_means[fish_number]

            # smoothing of max channel mean; the slice trims x so that its length
            # matches the 'valid' convolution output (len - kernel_size + 1)
            smooth_mcm = np.convolve(max_ch_mean, kernel, 'valid')
            smooth_x = x_tickses[int(np.ceil(kernel_size / 2)):-int(np.floor(kernel_size / 2) - 1)]

            # interpolate onto the part of the common axis where this fish exists
            first_idx = np.where(flat_x == x_tickses[0])[0][0]
            last_idx = np.where(flat_x == x_tickses[-1])[0][0]
            fish_x = flat_x[first_idx:last_idx + 1]
            try:
                trajectory = np.round(np.interp(fish_x, smooth_x, smooth_mcm))
            except (ValueError, TypeError):
                # e.g. too few samples for the kernel (empty/mismatched smooth_x)
                continue

            # trial duration in minutes; the axis values are used as day numbers
            fish_dur = datetime.timedelta(days=float(fish_x[-1] - fish_x[0]))
            trial_dur = fish_dur.total_seconds() / 60

            # average speed: summed electrode changes per minute
            speed.append(np.sum(np.abs(np.diff(trajectory))) / trial_dur)

            # presumably seconds since midnight -> hours (TODO confirm all_hms units)
            t = np.array(all_hms[fish_number] / 60 / 60)

            # swim through: look at the first / last 2 minutes of the fish's presence
            first = fish_x[0]
            last = fish_x[-1]
            start15 = mdates.date2num(mdates.num2date(first) + datetime.timedelta(seconds=2 * 60))
            stop15 = mdates.date2num(mdates.num2date(last) - datetime.timedelta(seconds=2 * 60))
            y = trajectory

            if _is_transit(trajectory, fish_x, start15, stop15):
                xy_transit.append(np.array([t, y]))
                transit_data.append(np.array([t[0], t[-1], trial_dur, day_idx]))
                transit_meta.append([names[fish_number], freq[fish_number, 2]])

    transit_data = np.array(transit_data)

    ###############################################################################################################
    # figure 1: time points of transit fish and trajectories
    fig = plt.figure(constrained_layout=True, figsize=[15 / inch, 12 / inch])
    gs = gridspec.GridSpec(ncols=2, nrows=3, figure=fig, hspace=0.05, wspace=0.0,
                           height_ratios=[1, 1, 1], left=0.1, bottom=0.15, right=0.95, top=0.95)

    ax0 = fig.add_subplot(gs[0, :])
    ax1 = fig.add_subplot(gs[1, 0])
    ax2 = fig.add_subplot(gs[1, 1])
    ax3 = fig.add_subplot(gs[2, 0])
    ax4 = fig.add_subplot(gs[2, 1])

    # zoomed trajectory panels: (axis, start after [h], stop at or before [h], xlim, xticks, xticklabels).
    # A transit goes into the FIRST panel whose window it fits.
    trajectory_panels = [
        (ax2, 1, 3.5, [1, 3.1], [1, 2, 3], ['01:00', '02:00', '03:00']),
        (ax4, 2.5, 5, [3, 5], [3, 4, 5], ['03:00', '04:00', '05:00']),
        (ax1, 18, 20, [18, 20], [18, 19, 20], ['18:00', '19:00', '20:00']),
        (ax3, 20, 22, [20, 22], [20, 21, 22], ['20:00', '21:00', '22:00']),
    ]

    ###############################################################################################################
    # plotting
    n_transits = len(transit_data)
    for ploti in range(n_transits):
        # only transits shorter than 3 h; `not <` also skips NaN durations, as before
        if not transit_data[ploti, 2] < 3 * 60:
            continue
        t_start, t_stop, _, day = transit_data[ploti]

        species, fish_freq = transit_meta[ploti]
        if species == 'Eigenmannia':
            colori = 0
        elif fish_freq < 750:
            colori = 1
        else:
            colori = 2

        # shift the clock so the x-axis runs from 12:00 to 12:00 the next day
        roll_factor = 12 if t_start < 12 else -12

        # plot time when fish is there
        ax0.plot(transit_data[ploti, :2] + roll_factor, [ax0_counter, ax0_counter],
                 color=color_efm[colori], lw=3, label=labels[colori])
        ax0_counter += 1

        # dashed separator between days; guard against indexing past the last transit
        if ploti + 1 < n_transits and transit_data[ploti + 1, 3] != day:
            ax0.plot([0, 24], [ax0_counter, ax0_counter], color=color_diffdays[0], lw=1, linestyle='dashed')
            print(ax0_counter, day, transit_data[ploti + 1, 3])
            ax0_counter += 1

        # plot trajectories
        for axis, t_after, t_before, xlim, xticks, xticklabels in trajectory_panels:
            if t_start > t_after and t_stop <= t_before:
                axis.plot(xy_transit[ploti][0], xy_transit[ploti][1], color=color_efm[colori])
                axis.set_xlim(xlim)
                axis.set_xticks(xticks)
                axis.set_xticklabels(xticklabels)
                break

    ###############################################################################################################
    # nice axis
    tagx = [-0.07, -0.17, -0.07, -0.17, -0.07]
    tagy = [0.95, 1.05, 1.05, 1.05, 1.05]
    for ax_idx, axis in enumerate([ax0, ax1, ax2, ax3, ax4]):
        axis.text(tagx[ax_idx], tagy[ax_idx], chr(ord('A') + ax_idx), transform=axis.transAxes, fontsize='large')
        # make_nice_ax is not a stock matplotlib method; presumably provided by a project import
        axis.make_nice_ax()
        axis.invert_yaxis()

        if axis != ax0:
            # white line marking the gap between electrode 8 and 9 (y = 7.5);
            # xmin/xmax are axes fractions, so 0..1 spans the full width
            axis.axhline(7.5, xmin=0, xmax=1, color='white', lw=2)
            axis.set_yticks([0, 1, 2, 3, 4, 5, 6, 7, 7.5, 8, 9, 10, 11, 12, 13, 14, 15])
            axis.set_yticklabels([])
            if axis == ax1 or axis == ax3:
                axis.set_yticklabels(['1', '', '', '', '5', '', '', '', 'gap', '', '', '', '12', '', '', '', '16'])
                axis.set_ylabel('Electrode', fontsize=11)
            if axis == ax3 or axis == ax4:
                axis.set_xlabel('Time', fontsize=11)

    # NOTE(review): only 3 day labels although 4 days are loaded above, and ax0's
    # y-axis is inverted (day 0 rows at the top) while 'Day 0' sits at y=0.1 (bottom).
    # Confirm the intended mapping.
    daytagy = [0.1, 0.5, 0.9]
    for dayi in [0, 1, 2]:
        ax0.text(0.02, daytagy[dayi], 'Day ' + str(dayi), transform=ax0.transAxes, fontsize='small',
                 va='center', ha='left')

    ax0.set_ylabel('Fish')
    ax0.set_yticks([])
    ax0.set_xlim([-0.1, 24.1])
    ax0.set_xticks(np.arange(0, 25, 3))
    ax0.set_xticklabels(['12:00', '15:00', '18:00', '21:00', '00:00', '03:00', '06:00', '09:00', '12:00'])
    fig.align_ylabels()

    fig.savefig(os.path.join(save_path, 'transit.pdf'))
    plt.show()

    ###############################################################################################################
    # speed
    c = np.load(os.path.join(data_dir, 'all_changes.npy'), allow_pickle=True)
    stl = np.load(os.path.join(data_dir, 'stl.npy'), allow_pickle=True)

    # covered electrode span per minute for every transit shorter than 3 h
    speeds = [
        (np.max(xy_transit[i][1]) - np.min(xy_transit[i][1])) / transit_data[i, 2]
        for i in range(n_transits)
        if transit_data[i, 2] < 3 * 60
    ]

    # mean -/+ one std, scaled by 60 * 12 (presumably per hour times an electrode
    # spacing - TODO confirm units). Renamed from `min`/`max` to stop shadowing builtins.
    # NOTE(review): only the upper bound is divided by 1000; the two values are in
    # different units unless this asymmetry is intended.
    speed_lo = (np.mean(speeds) - np.std(speeds)) * 60 * 12
    speed_hi = (np.mean(speeds) + np.std(speeds)) * 60 * 12 / 1000

    embed()