efishSignalDetector/data/generate_dataset.py

243 lines
10 KiB
Python

import time
import numpy as np
import argparse
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
import itertools
import sys
import os
from IPython import embed
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
def load_spec_data(folder):
fill_freqs, fill_times, fill_spec = [], [], []
if os.path.exists(os.path.join(folder, 'fill_spec.npy')):
fill_freqs = np.load(os.path.join(folder, 'fill_freqs.npy'))
fill_times = np.load(os.path.join(folder, 'fill_times.npy'))
fill_spec_shape = np.load(os.path.join(folder, 'fill_spec_shape.npy'))
fill_spec = np.memmap(os.path.join(folder, 'fill_spec.npy'), dtype='float', mode='r',
shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F')
elif os.path.exists(os.path.join(folder, 'fine_spec.npy')):
fill_freqs = np.load(os.path.join(folder, 'fine_freqs.npy'))
fill_times = np.load(os.path.join(folder, 'fine_times.npy'))
fill_spec_shape = np.load(os.path.join(folder, 'fine_spec_shape.npy'))
fill_spec = np.memmap(os.path.join(folder, 'fine_spec.npy'), dtype='float', mode='r',
shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F')
return fill_freqs, fill_times, fill_spec
def load_tracking_data(folder):
base_path = Path(folder)
EODf_v = np.load(base_path / 'fund_v.npy')
ident_v = np.load(base_path / 'ident_v.npy')
idx_v = np.load(base_path / 'idx_v.npy')
times_v = np.load(base_path / 'times.npy')
return EODf_v, ident_v, idx_v, times_v
def load_trial_data(folder):
base_path = Path(folder)
fish_freq = np.load(base_path / 'analysis' / 'fish_freq.npy')
rise_idx = np.load(base_path / 'analysis' / 'rise_idx.npy')
rise_size = np.load(base_path / 'analysis' / 'rise_size.npy')
fish_baseline_freq = np.load(base_path / 'analysis' / 'baseline_freqs.npy')
fish_baseline_freq_time = np.load(base_path / 'analysis' / 'baseline_freq_times.npy')
return fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time
def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, dataset_folder):
size = (7, 7)
dpi = 256
f_res, t_res = freq[1] - freq[0], times[1] - times[0]
fig_title = (f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0')
fig = plt.figure(figsize=(7, 7), num=fig_title)
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
ax = fig.add_subplot(gs[0, 0])
ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res))
ax.axis(False)
plt.savefig(os.path.join(dataset_folder, fig_title), dpi=256)
plt.close()
return fig_title, (size[0]*dpi, size[1]*dpi)
def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,
bbox_df, cols, width, height, t0, t1, f0, f1):
times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
for id_idx in range(len(fish_freq)):
rise_idx_oi = np.array(rise_idx[id_idx][
(rise_idx[id_idx] >= times_v_idx0) &
(rise_idx[id_idx] <= times_v_idx1) &
(rise_size[id_idx] >= 10)], dtype=int)
rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
(rise_idx[id_idx] <= times_v_idx1) &
(rise_size[id_idx] >= 10)]
if len(rise_idx_oi) == 0:
continue
closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]
upper_freq_bound = closest_baseline_freq + rise_size_oi
lower_freq_bound = closest_baseline_freq
left_time_bound = times_v[rise_idx_oi]
right_time_bound = np.zeros_like(left_time_bound)
for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
Crise_size = rise_size_oi[enu]
Cblf = closest_baseline_freq[enu]
rise_end_t = times_v[(times_v > Ct_oi) &
(fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
if len(rise_end_t) == 0:
right_time_bound[enu] = np.nan
else:
right_time_bound[enu] = rise_end_t[0]
mask = (~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > 1.))
left_time_bound = left_time_bound[mask]
right_time_bound = right_time_bound[mask]
lower_freq_bound = lower_freq_bound[mask]
upper_freq_bound = upper_freq_bound[mask]
# dt_bbox = right_time_bound - left_time_bound
# df_bbox = upper_freq_bound - lower_freq_bound
left_time_bound -= 0.01 * (t1 - t0)
right_time_bound += 0.05 * (t1 - t0)
lower_freq_bound -= 0.01 * (f1 - f0)
upper_freq_bound += 0.05 * (f1 - f0)
mask2 = ((left_time_bound >= t0) &
(right_time_bound <= t1) &
(lower_freq_bound >= f0) &
(upper_freq_bound <= f1)
)
left_time_bound = left_time_bound[mask2]
right_time_bound = right_time_bound[mask2]
lower_freq_bound = lower_freq_bound[mask2]
upper_freq_bound = upper_freq_bound[mask2]
x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int)
x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int)
y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int)
y1 = np.array((1 - (lower_freq_bound - f0) / (f1 - f0)) * height, dtype=int)
bbox = np.array([[pic_save_str for i in range(len(left_time_bound))],
left_time_bound,
right_time_bound,
lower_freq_bound,
upper_freq_bound,
x0, y0, x1, y1])
tmp_df = pd.DataFrame(
data=bbox.T,
columns=cols
)
bbox_df = pd.concat([bbox_df, tmp_df], ignore_index=True)
return bbox_df
def main(args):
# Hyperparameter
min_freq = 200
max_freq = 1500
d_freq = 200
freq_overlap = 25
d_time = 60*10
time_overlap = 60*1
folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
if not args.inference:
print('generate training dataset only for files with detected rises')
folders = [folder for folder in folders if (folder / 'analysis' / 'rise_idx.npy').exists()]
cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
bbox_df = pd.DataFrame(columns=cols)
else:
print('generate inference dataset ... only image output')
bbox_df = {}
for enu, folder in enumerate(folders):
print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
# load different categories of data
freq, times, spec = (
load_spec_data(folder))
EODf_v, ident_v, idx_v, times_v = (
load_tracking_data(folder))
if not args.inference:
fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
load_trial_data(folder))
# generate iterator for analysis window loop
pic_base = tqdm(itertools.product(
np.arange(0, times[-1], d_time),
np.arange(min_freq, max_freq, d_freq)
),
total=int((((max_freq-min_freq)//d_freq)+1) * ((times[-1] // d_time)+1))
)
for t0, f0 in pic_base:
t1 = t0 + d_time + time_overlap
f1 = f0 + d_freq + freq_overlap
present_freqs = EODf_v[(~np.isnan(ident_v)) &
(t0 <= times_v[idx_v]) &
(times_v[idx_v] <= t1) &
(EODf_v >= f0) &
(EODf_v <= f1)]
if len(present_freqs) == 0:
continue
# get spec_idx for current spec snippet
f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))
# get spec snippet and create torch.tensfor from it
s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
log_s = torch.log10(s)
transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
s_trans = transformed(log_s.unsqueeze(0))
pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, args.dataset_folder)
if not args.inference:
bbox_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
fish_baseline_freq_time, fish_baseline_freq,
pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1)
if not args.inference:
print('save bboxes')
bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
parser.add_argument('folder', type=str, help='single recording analysis', default='')
parser.add_argument('-d', "--dataset_folder", type=str, help='designated datasef folder', default='dataset')
parser.add_argument('-i', "--inference", action="store_true", help="generate inference dataset. Img only")
args = parser.parse_args()
if not Path(args.dataset_folder).exists():
Path(args.dataset_folder).mkdir(parents=True, exist_ok=True)
main(args)