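# Transit analysis: finds fish that swim through the electrode array and plots
# their timing and trajectories (original listing: 268 lines, 12 KiB, Python).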
import datetime
import itertools
import os

import matplotlib.colors as mcolors
import matplotlib.dates as mdates
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from IPython import embed
from IPython import embed
from scipy import stats

import helper_functions as hf
from params import *
from statisitic_functions import *


###################################################################################################################
# helpers


def is_transit(early, late):
    """Return True when a trajectory enters and leaves at opposite ends of an electrode array.

    Electrodes 0-7 and 8-15 are drawn as two arrays separated by a 'gap' tick at
    7.5 (see ``plot_transits``). A fish counts as a transit when the electrodes it
    occupies at the start (``early``) and at the end (``late``) of its recording
    match one of the entry/exit patterns below.

    Parameters
    ----------
    early, late : array_like
        Smoothed electrode indices during the first / last minutes of one fish.

    Returns
    -------
    bool
    """
    early = np.asarray(early)
    late = np.asarray(late)
    return bool(
        # across both arrays: high end -> low end, low end -> high end
        (np.any(early >= 13) and np.any(late <= 2))
        or (np.any(early <= 2) and np.any(late >= 13))
        # NOTE(review): in the patterns below, e.g. ``any(late >= 6) and any(late <= 7)``
        # tests two possibly different samples, not one sample inside [6, 7].
        # Kept exactly as in the original -- confirm this is intended.
        # first array: outer end <-> gap side
        or (np.any(early <= 1) and np.any(late >= 6) and np.any(late <= 7))
        or (np.any(early >= 6) and np.any(early <= 7) and np.any(late <= 1))
        # second array: outer end <-> gap side
        or (np.any(early >= 14) and np.any(late >= 8) and np.any(late <= 9))
        or (np.any(early >= 8) and np.any(early <= 9) and np.any(late >= 14))
    )


def smooth_trajectory(x_tickses, max_ch_mean, flat_x, kernel):
    """Smooth one fish's max-channel trace and resample it on the day's shared time axis.

    Parameters
    ----------
    x_tickses : ndarray
        Time stamps of the fish's samples (matplotlib date numbers).
    max_ch_mean : ndarray
        Electrode index per sample, same length as ``x_tickses``.
    flat_x : ndarray
        Sorted, unique time stamps of all fish of the day. Must contain the
        first and last entry of ``x_tickses``.
    kernel : ndarray
        Moving-average kernel applied with ``np.convolve(..., 'valid')``.

    Returns
    -------
    tuple of (ndarray, ndarray) or None
        ``(fish_x, trajectory)``. ``fish_x`` is the part of ``flat_x`` from the
        fish's first to last time stamp. ``trajectory`` is the smoothed trace
        interpolated onto ``fish_x`` and rounded to whole electrodes.
        Returns ``None`` when the recording is shorter than the kernel, or when
        the smoothed values do not line up with their time stamps.
    """
    if len(max_ch_mean) < len(kernel):
        # a 'valid' convolution needs at least one complete kernel window
        return None
    smooth_mcm = np.convolve(max_ch_mean, kernel, 'valid')

    # Align each smoothed value with (roughly) the centre of its window.
    # Slicing by length instead of by a negative end index stays correct for
    # every kernel size (a negative end of -0 would give an empty slice).
    offset = int(np.ceil(len(kernel) / 2))
    smooth_x = x_tickses[offset:offset + len(smooth_mcm)]

    first = np.where(flat_x == x_tickses[0])[0][0]
    last = np.where(flat_x == x_tickses[-1])[0][0]
    fish_x = flat_x[first:last + 1]

    try:
        trajectory = np.round(np.interp(fish_x, smooth_x, smooth_mcm))
    except ValueError:
        # np.interp rejects empty or length-mismatched sample points; skip such fish
        return None
    return fish_x, trajectory


def collect_transits(data_dir, day_indices, kernel, min_power=-90.0, edge_minutes=2):
    """Find fish that swim through the electrode array on the selected recording days.

    Parameters
    ----------
    data_dir : str
        Directory containing one sub-directory per recording day.
    day_indices : iterable of int
        Positions of the days to analyse in the sorted list of those sub-directories.
    kernel : ndarray
        Moving-average kernel passed to ``smooth_trajectory``.
    min_power : float
        Fish whose ``power_means.npy`` entry is below this value (or NaN) are skipped.
    edge_minutes : float
        Length of the start/end windows compared by ``is_transit``.

    Returns
    -------
    transit_data : ndarray
        One row per transit: (t[0], t[-1], duration in minutes, day index).
        Here ``t = all_hms / 3600``. Shape (0,) when nothing is found.
    xy_transit : list of ndarray
        Per transit: ``np.array([t, trajectory])``.
    transit_meta : list of list
        Per transit: [species name, column 2 of ``fish_freq_q10.npy``].
    speed : list of float
        For every fish with a usable trajectory (not only transits): summed
        absolute electrode changes divided by the duration in minutes.
    """
    # only directories are recording days; files (e.g. all_changes.npy) are ignored
    day_dirs = [entry for entry in sorted(os.listdir(data_dir))
                if os.path.isdir(os.path.join(data_dir, entry))]
    edge = datetime.timedelta(minutes=edge_minutes)

    transit_data, xy_transit, transit_meta, speed = [], [], [], []
    for day_idx in day_indices:
        day_path = os.path.join(data_dir, day_dirs[day_idx])
        all_xticks = np.load(os.path.join(day_path, 'all_xtickses.npy'), allow_pickle=True)
        all_max_ch_means = np.load(os.path.join(day_path, 'all_max_ch.npy'), allow_pickle=True)
        power_means = np.load(os.path.join(day_path, 'power_means.npy'), allow_pickle=True)
        all_hms = np.load(os.path.join(day_path, 'all_hms.npy'), allow_pickle=True)
        freq = np.load(os.path.join(day_path, 'fish_freq_q10.npy'), allow_pickle=True)
        names = np.load(os.path.join(day_path, 'fish_species.npy'), allow_pickle=True)

        # shared time axis of all fish of this day
        flat_x = np.unique(np.hstack(all_xticks))

        for fish_number, power in enumerate(power_means):
            # written as "not >=" so that NaN powers are skipped as well
            if not power >= min_power:
                continue

            smoothed = smooth_trajectory(all_xticks[fish_number], all_max_ch_means[fish_number],
                                         flat_x, kernel)
            if smoothed is None:
                continue
            fish_x, trajectory = smoothed

            # trial duration in minutes (matplotlib date numbers count days)
            trial_dur = datetime.timedelta(days=fish_x[-1] - fish_x[0]).total_seconds() / 60

            # average speed: electrode changes per minute
            speed.append(np.sum(np.abs(np.diff(trajectory))) / trial_dur)

            # time of day in hours -- assumes all_hms holds seconds, TODO confirm
            t = np.array(all_hms[fish_number] / 60 / 60)

            # swim through: compare where the fish is at the start and at the end
            early_end = mdates.date2num(mdates.num2date(fish_x[0]) + edge)
            late_start = mdates.date2num(mdates.num2date(fish_x[-1]) - edge)
            if is_transit(trajectory[fish_x < early_end], trajectory[fish_x > late_start]):
                xy_transit.append(np.array([t, trajectory]))
                transit_data.append(np.array([t[0], t[-1], trial_dur, day_idx]))
                transit_meta.append([names[fish_number], freq[fish_number, 2]])

    return np.array(transit_data), xy_transit, transit_meta, speed


def plot_transits(transit_data, xy_transit, transit_meta, inch, save_path, max_dur=3 * 60):
    """Plot when transit fish were present and their trajectories, and save the figure.

    Only transits shorter than ``max_dur`` minutes are drawn.
    Panel A (top) shows one horizontal bar per transit on a noon-to-noon time
    axis, with dashed lines between recording days. The four lower panels show
    the trajectories that fall entirely inside their time window.
    The figure is saved as ``save_path + 'transit.pdf'`` and returned.

    ``inch`` is the number of cm per inch; the figure size is 15 x 12 cm.
    """
    fig = plt.figure(constrained_layout=True, figsize=[15 / inch, 12 / inch])
    gs = gridspec.GridSpec(ncols=2, nrows=3, figure=fig, hspace=0.05, wspace=0.0,
                           height_ratios=[1, 1, 1], left=0.1, bottom=0.15, right=0.95, top=0.95)
    ax0 = fig.add_subplot(gs[0, :])
    ax1 = fig.add_subplot(gs[1, 0])
    ax2 = fig.add_subplot(gs[1, 1])
    ax3 = fig.add_subplot(gs[2, 0])
    ax4 = fig.add_subplot(gs[2, 1])

    # Zoom panels: (axis, start after [h], end at or before [h], x-limits, tick hours).
    # A transit is drawn in the first panel whose window contains it.
    zoom_panels = [
        (ax2, 1, 3.5, [1, 3.1], [1, 2, 3]),
        (ax4, 2.5, 5, [3, 5], [3, 4, 5]),
        (ax1, 18, 20, [18, 20], [18, 19, 20]),
        (ax3, 20, 22, [20, 22], [20, 21, 22]),
    ]

    row = 0
    prev_day = None
    for (start, stop, dur, day), (t, y), (species, fish_freq) in zip(transit_data, xy_transit,
                                                                       transit_meta):
        # written as "not <" so that NaN durations are skipped as well
        if not dur < max_dur:
            continue

        if species == 'Eigenmannia':
            colori = 0
        elif fish_freq < 750:
            colori = 1
        else:
            colori = 2

        # Dashed separator whenever the day changes between plotted transits.
        # Comparing with the previous *plotted* transit avoids reading past the
        # end of transit_data and does not miss a day change hidden by the filter.
        if prev_day is not None and day != prev_day:
            ax0.plot([0, 24], [row, row], color=color_diffdays[0], lw=1, linestyle='dashed')
            row += 1
        prev_day = day

        # plot time when fish is there; shift by 12 h so the axis runs noon to noon
        roll_factor = 12 if start < 12 else -12
        ax0.plot(np.array([start, stop]) + roll_factor, [row, row], color=color_efm[colori], lw=3,
                 label=labels[colori])
        row += 1

        # plot trajectory in the first zoom panel whose window contains it
        for axis, after, until, _, _ in zoom_panels:
            if start > after and stop <= until:
                axis.plot(t, y, color=color_efm[colori])
                break

    for axis, _, _, xlim, hours in zoom_panels:
        axis.set_xlim(xlim)
        axis.set_xticks(hours)
        axis.set_xticklabels([f'{hour:02d}:00' for hour in hours])

    # nice axis
    tag_x = [-0.07, -0.17, -0.07, -0.17, -0.07]
    tag_y = [0.95, 1.05, 1.05, 1.05, 1.05]
    electrode_ticks = [0, 1, 2, 3, 4, 5, 6, 7, 7.5, 8, 9, 10, 11, 12, 13, 14, 15]
    electrode_ticklabels = ['1', '', '', '', '5', '', '', '', 'gap', '', '', '', '12', '', '', '', '16']
    for ax_idx, axis in enumerate([ax0, ax1, ax2, ax3, ax4]):
        axis.text(tag_x[ax_idx], tag_y[ax_idx], chr(ord('A') + ax_idx), transform=axis.transAxes,
                  fontsize='large')
        # NOTE(review): make_nice_ax is not a stock matplotlib Axes method; presumably
        # it is attached by one of the project imports -- confirm
        axis.make_nice_ax()
        axis.invert_yaxis()
        if axis is not ax0:
            # white line at the 'gap' tick between electrodes 0-7 and 8-15
            # (xmin/xmax are axes fractions, so 0..1 spans the full width)
            axis.axhline(7.5, xmin=0, xmax=1, color='white', lw=2)
            axis.set_yticks(electrode_ticks)
            axis.set_yticklabels([])
            if axis is ax1 or axis is ax3:
                axis.set_yticklabels(electrode_ticklabels)
                axis.set_ylabel('Electrode', fontsize=11)
        if axis is ax3 or axis is ax4:
            axis.set_xlabel('Time', fontsize=11)

    # NOTE(review): day labels sit at fixed heights for three days and do not follow
    # the rows actually drawn -- confirm they match the data
    for day_i, y_frac in enumerate([0.1, 0.5, 0.9]):
        ax0.text(0.02, y_frac, 'Day ' + str(day_i), transform=ax0.transAxes, fontsize='small',
                 va='center', ha='left')

    ax0.set_ylabel('Fish')
    ax0.set_yticks([])
    ax0.set_xlim([-0.1, 24.1])
    ax0.set_xticks(np.arange(0, 25, 3))
    ax0.set_xticklabels(['12:00', '15:00', '18:00', '21:00', '00:00', '03:00', '06:00', '09:00', '12:00'])

    fig.align_ylabels()
    fig.savefig(save_path + 'transit.pdf')
    return fig


def transit_speeds(transit_data, xy_transit, max_dur=3 * 60):
    """Return the electrode span per minute for every transit shorter than ``max_dur`` minutes.

    The span is max(trajectory) - min(trajectory). The duration is column 2 of
    ``transit_data``, in minutes.
    """
    return [(np.max(y) - np.min(y)) / dur
            for (_, _, dur, _), (_, y) in zip(transit_data, xy_transit)
            if dur < max_dur]


###################################################################################################################
# script

if __name__ == '__main__':
    # plot params: figure sizes are given in cm and divided by this to get inches
    inch = 2.54  # cm per inch (was 2.45, which made the figure ~3.7 % too large)
    save_path = '../../thesis/Figures/Results/'

    # moving-average kernel for the max-channel traces
    kernel_size = 120
    kernel = np.ones(kernel_size) / kernel_size

    # load data and detect transits
    transit_data, xy_transit, transit_meta, speed = collect_transits('../data/', [0, 1, 2, 3], kernel)

    # figure 1: time points of transit fish and trajectories
    plot_transits(transit_data, xy_transit, transit_meta, inch, save_path)
    plt.show()

    ###############################################################################################################
    # speed
    # all_changes / stl are not used below; they are loaded for the interactive
    # session opened by embed()
    all_changes = np.load('../data/all_changes.npy', allow_pickle=True)
    stl = np.load('../data/stl.npy', allow_pickle=True)

    speeds = transit_speeds(transit_data, xy_transit)

    # mean -/+ one standard deviation, scaled for reporting
    # NOTE(review): only the upper bound is additionally divided by 1000, so the two
    # bounds are in different units; kept as in the original -- confirm the scaling
    speed_low = (np.mean(speeds) - np.std(speeds)) * 60 * 12
    speed_high = (np.mean(speeds) + np.std(speeds)) * 60 * 12 / 1000

    embed()
|