line_tracking_of_fish_movement/plot_transit_fish.py

268 lines
12 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from IPython import embed
from scipy import stats
import os
from IPython import embed
import helper_functions as hf
from params import *
import itertools
from statisitic_functions import *
import datetime  # explicit: timedelta arithmetic below must not depend on a star import leaking `datetime`

# Rules that classify a smoothed channel trajectory as a transit through the array.
# Each rule is (head_min, head_max, tail_min, tail_max), where ``None`` means "no bound".
# A bound is satisfied when ANY sample of the segment reaches it, and each bound is
# checked independently (this mirrors the original chain of `np.any(...) and np.any(...)`).
# NOTE(review): e.g. (None, 1, 6, 7) accepts a tail containing one value >= 6 and another
# value <= 7; if "some sample within [6, 7]" was intended, the test should be combined - confirm.
_TRANSIT_RULES = (
    (13, None, None, 2),   # starts >= 13, ends <= 2
    (None, 2, 13, None),   # starts <= 2, ends >= 13
    (None, 1, 6, 7),       # starts <= 1, ends reaching >= 6 and <= 7
    (6, 7, None, 1),       # starts reaching >= 6 and <= 7, ends <= 1
    (14, None, 8, 9),      # starts >= 14, ends reaching >= 8 and <= 9
    (8, 9, 14, None),      # starts reaching >= 8 and <= 9, ends >= 14
)


def _segment_meets(values, lower, upper):
    """Return True if some sample is >= ``lower`` and some sample is <= ``upper``.

    A bound of ``None`` is not checked.
    """
    if lower is not None and not np.any(values >= lower):
        return False
    if upper is not None and not np.any(values <= upper):
        return False
    return True


def is_transit_trajectory(trajectory, fish_x, start_cut, stop_cut):
    """Decide whether a trajectory matches one of the transit rules.

    Parameters
    ----------
    trajectory : array_like
        Channel value per sample; same length as ``fish_x``.
    fish_x : array_like
        Sample times, in the same units as ``start_cut`` and ``stop_cut``.
    start_cut, stop_cut : float
        Samples with time < ``start_cut`` form the head segment; samples with
        time > ``stop_cut`` form the tail segment.

    Returns
    -------
    bool
        True if any rule in ``_TRANSIT_RULES`` is met by the head/tail pair.
        An empty head or tail never matches.
    """
    trajectory = np.asarray(trajectory)
    fish_x = np.asarray(fish_x)
    head = trajectory[fish_x < start_cut]
    tail = trajectory[fish_x > stop_cut]
    return any(
        _segment_meets(head, head_lo, head_hi) and _segment_meets(tail, tail_lo, tail_hi)
        for head_lo, head_hi, tail_lo, tail_hi in _TRANSIT_RULES
    )


def hour_tick_labels(hours):
    """Format clock hours as 'HH:00' tick labels, wrapping at midnight (24 -> '00:00')."""
    return ['{:02d}:00'.format(int(hour) % 24) for hour in hours]


def _load_npy(day_dir, filename):
    """Load ``filename`` from the folder ``day_dir``."""
    # allow_pickle=True can execute pickled code on load; only use on trusted, locally produced data.
    return np.load(os.path.join(day_dir, filename), allow_pickle=True)


if __name__ == '__main__':
    ###################################################################################################################
    # parameters and variables
    inch = 2.54  # centimetres per inch (was 2.45): figure sizes below are given in cm
    save_path = '../../thesis/Figures/Results/'
    data_dir = '../data/'
    day_indices = [0, 1, 2, 3]
    min_power = -90.0  # fish whose power_means value is below this are ignored
    # boxcar kernel used to smooth the max-channel trace
    kernel_size = 120
    kernel = np.ones(kernel_size) / kernel_size
    # slice offsets aligning the time stamps with the 'valid' convolution output
    smooth_lo = int(np.ceil(kernel_size / 2))
    smooth_hi = int(np.floor(kernel_size / 2) - 1)
    # head/tail window at both ends of a fish's presence used for the transit test
    edge_window = datetime.timedelta(minutes=2)
    # only transits shorter than this (minutes) are plotted
    max_transit_minutes = 3 * 60

    transit_data = []  # per transit: [start hour, stop hour, duration in minutes, day index]
    xy_transit = []    # per transit: np.array([t, trajectory])
    transit_meta = []  # per transit: [species name, freq[fish, 2]]
    speed = []         # per analysed fish: summed |channel change| per minute of presence

    ###################################################################################################################
    # load data and detect transits
    ###################################################################################################################
    # only sub-directories are day recordings; ../data/ also contains plain .npy files
    day_dirs = sorted(entry for entry in os.listdir(data_dir)
                      if os.path.isdir(os.path.join(data_dir, entry)))
    for day_idx in day_indices:
        day_dir = os.path.join(data_dir, day_dirs[day_idx])
        all_xticks = _load_npy(day_dir, 'all_xtickses.npy')
        all_max_ch_means = _load_npy(day_dir, 'all_max_ch.npy')
        power_means = _load_npy(day_dir, 'power_means.npy')
        all_hms = _load_npy(day_dir, 'all_hms.npy')
        freq = _load_npy(day_dir, 'fish_freq_q10.npy')
        names = _load_npy(day_dir, 'fish_species.npy')
        # sorted union of all sample times of the day (np.unique sorts)
        flat_x = np.unique(np.hstack(all_xticks))

        for fish_number in range(len(power_means)):
            # written as "not >=" so that NaN power values are skipped, as before
            if not (power_means[fish_number] >= min_power):
                continue
            x_tickses = all_xticks[fish_number]
            max_ch_mean = all_max_ch_means[fish_number]
            # smooth the max-channel trace and take the matching time stamps
            smooth_mcm = np.convolve(max_ch_mean, kernel, 'valid')
            smooth_x = x_tickses[smooth_lo:-smooth_hi]
            # resample onto the day's common time grid between this fish's first and last sample
            i_first = np.where(flat_x == x_tickses[0])[0][0]
            i_last = np.where(flat_x == x_tickses[-1])[0][0]
            fish_x = flat_x[i_first:i_last + 1]
            try:
                trajectory = np.round(np.interp(fish_x, smooth_x, smooth_mcm))
            except (ValueError, TypeError):
                # e.g. a trace shorter than the kernel leaves no (or mismatched) sample points
                continue
            # presence duration in minutes; x values are matplotlib date numbers (days)
            trial_dur = datetime.timedelta(days=float(fish_x[-1] - fish_x[0])).total_seconds() / 60
            speed.append(np.sum(np.abs(np.diff(trajectory))) / trial_dur)
            # time axis in hours (all_hms presumably in seconds - confirm)
            t = np.array(all_hms[fish_number] / 60 / 60)
            # head = samples before start_cut, tail = samples after stop_cut
            start_cut = mdates.date2num(mdates.num2date(fish_x[0]) + edge_window)
            stop_cut = mdates.date2num(mdates.num2date(fish_x[-1]) - edge_window)
            if not is_transit_trajectory(trajectory, fish_x, start_cut, stop_cut):
                continue
            xy_transit.append(np.array([t, trajectory]))
            transit_data.append(np.array([t[0], t[-1], trial_dur, day_idx]))
            transit_meta.append([names[fish_number], freq[fish_number, 2]])
    transit_data = np.array(transit_data)

    ###################################################################################################################
    # figure 1: time points of transit fish and trajectories
    # NOTE(review): with constrained_layout=True matplotlib's layout engine likely overrides the
    # GridSpec left/right/bottom/top/hspace/wspace values - confirm they have an effect.
    fig = plt.figure(constrained_layout=True, figsize=[15 / inch, 12 / inch])
    gs = gridspec.GridSpec(ncols=2, nrows=3, figure=fig, hspace=0.05, wspace=0.0,
                           height_ratios=[1, 1, 1], left=0.1, bottom=0.15, right=0.95, top=0.95)
    ax0 = fig.add_subplot(gs[0, :])  # A: presence interval of every plotted transit
    ax1 = fig.add_subplot(gs[1, 0])  # B: trajectories 18:00-20:00
    ax2 = fig.add_subplot(gs[1, 1])  # C: trajectories 01:00-03:00
    ax3 = fig.add_subplot(gs[2, 0])  # D: trajectories 20:00-22:00
    ax4 = fig.add_subplot(gs[2, 1])  # E: trajectories 03:00-05:00
    # trajectory panels, tested in this order (first match wins):
    # (axis, start must be later than, stop must not be later than, x-limits, tick hours)
    trajectory_panels = [
        (ax2, 1, 3.5, [1, 3.1], [1, 2, 3]),
        (ax4, 2.5, 5, [3, 5], [3, 4, 5]),
        (ax1, 18, 20, [18, 20], [18, 19, 20]),
        (ax3, 20, 22, [20, 22], [20, 21, 22]),
    ]

    ###################################################################################################################
    # plotting (color_efm, labels, color_diffdays presumably come from `params` - confirm)
    ax0_counter = 0
    last_day = None
    for ploti, (t_start, t_stop, duration, day) in enumerate(transit_data):
        if not (duration < max_transit_minutes):
            continue
        species, fish_freq = transit_meta[ploti]
        if species == 'Eigenmannia':
            colori = 0
        elif fish_freq < 750:
            colori = 1
        else:
            colori = 2
        # dashed separator whenever the day changes between consecutive *plotted* fish
        # (the old look-ahead to row ploti + 1 overran the array and ignored the duration filter)
        if last_day is not None and day != last_day:
            ax0.plot([0, 24], [ax0_counter, ax0_counter], color=color_diffdays[0], lw=1, linestyle='dashed')
            ax0_counter += 1
        last_day = day
        # shift by 12 h so that the overview x axis runs from noon to noon
        roll_factor = 12 if t_start < 12 else -12
        ax0.plot([t_start + roll_factor, t_stop + roll_factor], [ax0_counter, ax0_counter],
                 color=color_efm[colori], lw=3, label=labels[colori])
        ax0_counter += 1
        # trajectory in the first time window that fully contains the transit
        for axis, after, until, _, _ in trajectory_panels:
            if t_start > after and t_stop <= until:
                axis.plot(xy_transit[ploti][0], xy_transit[ploti][1], color=color_efm[colori])
                break
    for axis, _, _, xlim, tick_hours in trajectory_panels:
        axis.set_xlim(xlim)
        axis.set_xticks(tick_hours)
        axis.set_xticklabels(hour_tick_labels(tick_hours))

    ###################################################################################################################
    # nice axis
    tagx = [-0.07, -0.17, -0.07, -0.17, -0.07]
    tagy = [0.95, 1.05, 1.05, 1.05, 1.05]
    # 7.5 marks the gap between the two electrode groups
    electrode_ticks = [0, 1, 2, 3, 4, 5, 6, 7, 7.5, 8, 9, 10, 11, 12, 13, 14, 15]
    electrode_labels = ['1', '', '', '', '5', '', '', '', 'gap', '', '', '', '12', '', '', '', '16']
    for ax_idx, axis in enumerate([ax0, ax1, ax2, ax3, ax4]):
        axis.text(tagx[ax_idx], tagy[ax_idx], chr(ord('A') + ax_idx), transform=axis.transAxes, fontsize='large')
        # NOTE(review): make_nice_ax is not a stock matplotlib Axes method; presumably patched in by a project import
        axis.make_nice_ax()
        axis.invert_yaxis()
        if axis is not ax0:
            # xmin/xmax of axhline are axes fractions; the defaults (0, 1) span the full width
            axis.axhline(7.5, color='white', lw=2)
            axis.set_yticks(electrode_ticks)
            axis.set_yticklabels([])
        if axis is ax1 or axis is ax3:
            axis.set_yticklabels(electrode_labels)
            axis.set_ylabel('Electrode', fontsize=11)
        if axis is ax3 or axis is ax4:
            axis.set_xlabel('Time', fontsize=11)
    # NOTE(review): four days are loaded (day_indices) but only three day labels are drawn, and the
    # y axis is inverted, so day 0 ends up at the top while 'Day 0' is written near the bottom - confirm.
    daytagy = [0.1, 0.5, 0.9]
    for dayi, tag_y in enumerate(daytagy):
        ax0.text(0.02, tag_y, 'Day ' + str(dayi), transform=ax0.transAxes, fontsize='small', va='center', ha='left')
    ax0.set_ylabel('Fish')
    ax0.set_yticks([])
    ax0.set_xlim([-0.1, 24.1])
    overview_ticks = np.arange(0, 25, 3)
    ax0.set_xticks(overview_ticks)
    # x = 0 corresponds to 12:00 because of the 12 h roll above
    ax0.set_xticklabels(hour_tick_labels(overview_ticks + 12))
    fig.align_ylabels()
    fig.savefig(os.path.join(save_path, 'transit.pdf'))
    plt.show()

    ###################################################################################################################
    # speed
    # presumably loaded for interactive inspection in the embed() shell below; not used otherwise
    c = np.load(os.path.join(data_dir, 'all_changes.npy'), allow_pickle=True)
    stl = np.load(os.path.join(data_dir, 'stl.npy'), allow_pickle=True)
    # channel range covered per minute, for every plotted (short) transit
    speeds = np.array([
        (np.max(xy[1]) - np.min(xy[1])) / row[2]
        for xy, row in zip(xy_transit, transit_data)
        if row[2] < max_transit_minutes
    ])
    # mean -/+ one standard deviation. Both bounds now use the same scale; the original divided
    # only the upper bound by an extra 1000 (mixed units) and shadowed the builtins min/max.
    # NOTE(review): the meaning of the factor 12 is not visible here - confirm the intended unit.
    speed_scale = 60 * 12
    speed_low = (np.mean(speeds) - np.std(speeds)) * speed_scale
    speed_high = (np.mean(speeds) + np.std(speeds)) * speed_scale
    embed()