126 lines
4.3 KiB
Python
126 lines
4.3 KiB
Python
import sys
|
|
import os
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import matplotlib.pyplot as plt
|
|
from thunderfish.dataloader import open_data
|
|
from thunderfish.eodanalysis import eod_waveform
|
|
from IPython import embed
|
|
import matplotlib.gridspec as gridspec
|
|
from params import *
|
|
|
|
|
|
def unfilter(data, samplerate, cutoff):
    r"""
    Apply inverse high-pass filter on data.

    Assumes high-pass filter \[ \tau \dot y = -y + \tau \dot x \] has
    been applied on the original data \(x\), where \(\tau=(2\pi
    f_{cutoff})^{-1}\) is the time constant of the filter. To recover \(x\)
    the ODE \[ \tau \dot x = y + \tau \dot y \] is applied on the
    filtered data \(y\).

    The ODE is integrated with a forward-Euler step. After the mean of
    \(y\) has been removed, this has the closed form
    \(x_k = y_k + (y_0 + \sum_{j<k} y_j) / (\tau f_s)\), which is evaluated
    with a cumulative sum instead of a per-sample Python loop.

    Parameters:
    -----------
    data: ndarray
        High-pass filtered original data. Floating-point (or complex)
        arrays are modified in place (mean removed, then overwritten with
        the recovered signal) and returned. Any other input is first
        converted to a new float array.
    samplerate: float
        Sampling rate of `data` in Hertz.
    cutoff: float
        Cutoff frequency \(f_{cutoff}\) of the high-pass filter in Hertz.

    Returns:
    --------
    data: ndarray
        Recovered original data. Empty input is returned unchanged.
    """
    tau = 0.5 / np.pi / cutoff
    fac = tau * samplerate
    # In-place arithmetic needs an inexact dtype: integer arrays and plain
    # lists previously failed on the `-=` below.
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.inexact):
        data = data.astype(float)
    if data.size == 0:
        return data
    data -= np.mean(data)
    # Sample preceding each step; the very first step uses y_0 itself.
    prev = np.concatenate((data[:1], data[:-1]))
    # Accumulate along the time axis only (axis 0), as the original loop did.
    data += np.cumsum(prev, axis=0) / fac
    return data
|
|
|
|
|
|
def calc_mean_eod(t0, f, data, dt=10, unfilter=0):
    """
    Average the EOD waveform on every channel and return the strongest one.

    A window of `dt` seconds starting at `t0` is cut from each channel of
    `data`. Thunderfish's `eod_waveform` averages snippets around the
    equally spaced times 0, 1/f, 2/f, ... (seconds, relative to the window
    start). The channel whose mean waveform (column 1) has the largest
    peak-to-peak amplitude is selected.

    Parameters:
    -----------
    t0: float
        Start of the analysis window in seconds.
    f: float
        Frequency in Hertz; sets the spacing 1/f of the averaging times.
    data: data loader
        Multi-channel recording. Must expose `channels`, `samplerate` and
        `data[start:stop, channel]` slicing.
    dt: float
        Window length in seconds.
    unfilter: float
        Passed as `unfilter_cutoff` to `eod_waveform`. Inside this function
        it shadows the module-level `unfilter()` function.

    Returns:
    --------
    EOD: ndarray
        Mean waveform of the selected channel as returned by `eod_waveform`.
        `main` plots column 0 as time and column 1 as the waveform, with a
        +/- band of column 2.
    samplerate: float
        Sampling rate of `data`.

    Raises:
    -------
    ValueError
        If `data` has no channels.
    """
    samplerate = data.samplerate

    # Slice bounds must be integers; t0 * samplerate is generally a float.
    start_i = int(round(t0 * samplerate))
    end_i = int(round(t0 * samplerate + dt * samplerate)) + 1
    # Averaging times in seconds, relative to start_i.
    t = np.arange(0, dt, 1 / f)

    mean_EODs = []
    for c in range(data.channels):
        mean_eod, _ = eod_waveform(data[start_i:end_i, c], samplerate, t, unfilter_cutoff=unfilter)
        mean_EODs.append(mean_eod)
    if not mean_EODs:
        raise ValueError('data has no channels to compute a mean EOD from')

    # Peak-to-peak amplitude of each channel's mean waveform (column 1).
    max_size = [np.ptp(eod[:, 1]) for eod in mean_EODs]
    EOD = mean_EODs[int(np.argmax(max_size))]

    return EOD, samplerate
|
|
|
|
|
|
def main(folder, filename):
    """
    Plot the mean EOD waveform of two example fish side by side and save it.

    Parameters:
    -----------
    folder: str
        Directory containing the raw grid recording 'traces-grid1.raw'.
    filename: str
        Recording name. Analysis results are read from '../data/<filename>/'
        relative to the current working directory.

    Uses `inch`, `color_efm` and `save_path`, which are not defined here.
    They are presumably provided by `from params import *`; confirm there.
    The figure is written to save_path + 'eod_waves.pdf' and then shown.
    """
    data = open_data(os.path.join(folder, 'traces-grid1.raw'), -1, 60.0, 10.0)

    # NOTE(review): allow_pickle=True executes pickled code; only load
    # these result files from trusted, locally produced analyses.
    result_dir = os.path.join('..', 'data', filename)
    all_q10 = np.load(os.path.join(result_dir, 'fish_freq_q10.npy'), allow_pickle=True)
    all_t = np.load(os.path.join(result_dir, 'eod_times_new_new.npy'), allow_pickle=True)
    all_f = np.load(os.path.join(result_dir, 'eod_freq_new_new.npy'), allow_pickle=True)

    # One entry per panel:
    # (fish index, index into that fish's time/frequency entries,
    #  unfilter cutoff in Hz passed to calc_mean_eod).
    panels = [(15, 16, 200), (22, 0, 0)]

    ##################################################################################################################
    # figure
    fig = plt.figure(constrained_layout=True, figsize=[15 / inch, 6 / inch])
    gs = gridspec.GridSpec(ncols=2, nrows=1, figure=fig, hspace=0.05, wspace=0.0,
                           left=0.1, bottom=0.15, right=0.95, top=0.98)

    ax2 = fig.add_subplot(gs[0, 1])
    ax1 = fig.add_subplot(gs[0, 0], sharey=ax2)

    for fn_idx, ((fish_number, entry_idx, cutoff), ax) in enumerate(zip(panels, (ax1, ax2))):
        print(all_q10[fish_number, 2], fish_number)

        t = all_t[fish_number][entry_idx]
        f = all_f[fish_number][entry_idx]
        EOD, _ = calc_mean_eod(t, f, data, unfilter=cutoff)

        ##############################################################################################################
        # plot
        ax.plot(EOD.T[0], EOD.T[1], color=color_efm[fn_idx], lw=2)
        ax.fill_between(EOD.T[0], EOD.T[1] + EOD.T[2], EOD.T[1] - EOD.T[2],
                        color=color_efm[fn_idx], alpha=0.7)
        # NOTE(review): make_nice_ax is not a stock matplotlib Axes method;
        # presumably attached by `params`. Confirm it is defined there.
        ax.make_nice_ax()

        # Panel letter (A, B, ...) and the fish's frequency label.
        ax.text(-0.12, 0.95, chr(ord('A') + fn_idx), transform=ax.transAxes, fontsize='large')
        ax.text(0.8, 0.95, str(np.round(all_q10[fish_number, 2], 1)) + ' Hz', transform=ax.transAxes, fontsize=10)

        ax.set_xlabel('Time')
        ax.set_yticks([0])
        ax.set_xticks([])

    ax1.set_ylabel('Amplitude')
    fig.savefig(save_path + 'eod_waves.pdf')

    plt.show()
|
|
|
|
|
|
if __name__ == '__main__':
    # Root directory holding one sub-directory per grid recording.
    base_folder = '../../../data/mount_data/sanmartin/softgrid_1x16/'
    # List once, outside the loop. Sorting makes an index refer to the same
    # recording on every run.
    recordings = sorted(os.listdir(base_folder))

    # Indices (into the sorted listing) of the recordings to process.
    for filename_idx in [2]:
        filename = recordings[filename_idx]
        folder = os.path.join(base_folder, filename)
        print('new file: ' + filename)
        main(folder, filename)
|