added miussing dataset generator

This commit is contained in:
Till Raab 2023-11-07 16:00:14 +01:00
parent 9dcf54c139
commit 92ab342a65

324
generate_dataset.py Normal file
View File

@ -0,0 +1,324 @@
import itertools
import sys
import os
import argparse
import torch
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from IPython import embed
from confic import (MIN_FREQ, MAX_FREQ, DELTA_FREQ, FREQ_OVERLAP, DELTA_TIME, TIME_OVERLAP, IMG_SIZE, IMG_DPI, DATA_DIR, LABEL_DIR)
def load_spec_data(folder: str):
"""
Load spectrogram of a given electrode-grid recording generated with the wavetracker package. The spectrograms may
be to large to load in total, thats why memmory mapping is used (numpy.memmap).
Parameters
----------
folder: str
Folder where fine spec numpy files generated for grid recordings with the wavetracker package can be found.
Returns
-------
fill_freqs: ndarray
Freuqencies corresponding to 1st dimension of the spectrogram.
fill_times: ndarray
Times corresponding to the 2nd dimenstion if the spectrigram.
fill_spec: ndarray
Spectrigram of the recording refered to in the input folder.
"""
fill_freqs, fill_times, fill_spec = [], [], []
if os.path.exists(os.path.join(folder, 'fill_spec.npy')):
fill_freqs = np.load(os.path.join(folder, 'fill_freqs.npy'))
fill_times = np.load(os.path.join(folder, 'fill_times.npy'))
fill_spec_shape = np.load(os.path.join(folder, 'fill_spec_shape.npy'))
fill_spec = np.memmap(os.path.join(folder, 'fill_spec.npy'), dtype='float', mode='r',
shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F')
elif os.path.exists(os.path.join(folder, 'fine_spec.npy')):
fill_freqs = np.load(os.path.join(folder, 'fine_freqs.npy'))
fill_times = np.load(os.path.join(folder, 'fine_times.npy'))
fill_spec_shape = np.load(os.path.join(folder, 'fine_spec_shape.npy'))
fill_spec = np.memmap(os.path.join(folder, 'fine_spec.npy'), dtype='float', mode='r',
shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F')
return fill_freqs, fill_times, fill_spec
def load_tracking_data(folder):
base_path = Path(folder)
EODf_v = np.load(base_path / 'fund_v.npy')
ident_v = np.load(base_path / 'ident_v.npy')
idx_v = np.load(base_path / 'idx_v.npy')
times_v = np.load(base_path / 'times.npy')
return EODf_v, ident_v, idx_v, times_v
def load_trial_data(folder):
base_path = Path(folder)
fish_freq = np.load(base_path / 'analysis' / 'fish_freq.npy')
rise_idx = np.load(base_path / 'analysis' / 'rise_idx.npy')
rise_size = np.load(base_path / 'analysis' / 'rise_size.npy')
fish_baseline_freq = np.load(base_path / 'analysis' / 'baseline_freqs.npy')
fish_baseline_freq_time = np.load(base_path / 'analysis' / 'baseline_freq_times.npy')
return fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time
def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, dataset_folder):
f_res, t_res = freq[1] - freq[0], times[1] - times[0]
fig_title = (f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0')
fig = plt.figure(figsize=IMG_SIZE, num=fig_title)
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
ax = fig.add_subplot(gs[0, 0])
ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res))
ax.axis(False)
# plt.savefig(os.path.join(dataset_folder, fig_title), dpi=IMG_DPI)
plt.savefig(Path(DATA_DIR)/fig_title, dpi=IMG_DPI)
plt.close()
return fig_title, (IMG_SIZE[0]*IMG_DPI, IMG_SIZE[1]*IMG_DPI)
def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,
bbox_df, cols, width, height, t0, t1, f0, f1):
times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
all_x_center = []
all_y_center = []
all_width = []
all_height= []
for id_idx in range(len(fish_freq)):
rise_idx_oi = np.array(rise_idx[id_idx][
(rise_idx[id_idx] >= times_v_idx0) &
(rise_idx[id_idx] <= times_v_idx1) &
(rise_size[id_idx] >= 10)], dtype=int)
rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
(rise_idx[id_idx] <= times_v_idx1) &
(rise_size[id_idx] >= 10)]
if len(rise_idx_oi) == 0:
# np.savetxt(LABEL_DIR / Path(pic_save_str).with_suffix('.txt'), np.array([]))
continue
closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]
upper_freq_bound = closest_baseline_freq + rise_size_oi
lower_freq_bound = closest_baseline_freq
left_time_bound = times_v[rise_idx_oi]
right_time_bound = np.zeros_like(left_time_bound)
for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
Crise_size = rise_size_oi[enu]
Cblf = closest_baseline_freq[enu]
rise_end_t = times_v[(times_v > Ct_oi) &
(fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
if len(rise_end_t) == 0:
right_time_bound[enu] = np.nan
else:
right_time_bound[enu] = rise_end_t[0]
mask = (~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > 1.))
left_time_bound = left_time_bound[mask]
right_time_bound = right_time_bound[mask]
lower_freq_bound = lower_freq_bound[mask]
upper_freq_bound = upper_freq_bound[mask]
left_time_bound -= 0.01 * (t1 - t0)
right_time_bound += 0.05 * (t1 - t0)
lower_freq_bound -= 0.01 * (f1 - f0)
upper_freq_bound += 0.05 * (f1 - f0)
mask2 = ((left_time_bound >= t0) &
(right_time_bound <= t1) &
(lower_freq_bound >= f0) &
(upper_freq_bound <= f1)
)
left_time_bound = left_time_bound[mask2]
right_time_bound = right_time_bound[mask2]
lower_freq_bound = lower_freq_bound[mask2]
upper_freq_bound = upper_freq_bound[mask2]
if len(left_time_bound) == 0:
continue
# x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int)
# x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int)
#
# y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int)
# y1 = np.array((1 - (lower_freq_bound - f0) / (f1 - f0)) * height, dtype=int)
rel_x0 = np.array((left_time_bound - t0) / (t1 - t0), dtype=float)
rel_x1 = np.array((right_time_bound - t0) / (t1 - t0), dtype=float)
rel_y0 = np.array(1 - (upper_freq_bound - f0) / (f1 - f0), dtype=float)
rel_y1 = np.array(1 - (lower_freq_bound - f0) / (f1 - f0), dtype=float)
rel_x_center = rel_x1 - (rel_x1 - rel_x0) / 2
rel_y_center = rel_y1 - (rel_y1 - rel_y0) / 2
rel_width = rel_x1 - rel_x0
rel_height = rel_y1 - rel_y0
all_x_center.extend(rel_x_center)
all_y_center.extend(rel_y_center)
all_width.extend(rel_width)
all_height.extend(rel_height)
# bbox = np.array([[pic_save_str for i in range(len(left_time_bound))],
# left_time_bound,
# right_time_bound,
# lower_freq_bound,
# upper_freq_bound,
# x0, y0, x1, y1])
bbox_yolo_style = np.array([
np.ones(len(all_x_center)),
all_x_center,
all_y_center,
all_width,
all_height
]).T
np.savetxt(LABEL_DIR/ Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
return bbox_yolo_style
def main(args):
folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
if not args.inference:
print('generate training dataset only for files with detected rises')
folders = [folder for folder in folders if (folder / 'analysis' / 'rise_idx.npy').exists()]
cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
bbox_df = pd.DataFrame(columns=cols)
else:
print('generate inference dataset ... only image output')
bbox_df = {}
for enu, folder in enumerate(folders):
print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
# load different categories of data
freq, times, spec = (
load_spec_data(folder))
EODf_v, ident_v, idx_v, times_v = (
load_tracking_data(folder))
if not args.inference:
fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
load_trial_data(folder))
# generate iterator for analysis window loop
pic_base = tqdm(itertools.product(
np.arange(0, times[-1], DELTA_TIME),
np.arange(MIN_FREQ, MAX_FREQ, DELTA_FREQ)
),
total=int((((MAX_FREQ-MIN_FREQ)//DELTA_FREQ)+1) * ((times[-1] // DELTA_TIME)+1))
)
for t0, f0 in pic_base:
t1 = t0 + DELTA_TIME + TIME_OVERLAP
f1 = f0 + DELTA_FREQ + FREQ_OVERLAP
present_freqs = EODf_v[(~np.isnan(ident_v)) &
(t0 <= times_v[idx_v]) &
(times_v[idx_v] <= t1) &
(EODf_v >= f0) &
(EODf_v <= f1)]
if len(present_freqs) == 0:
continue
# get spec_idx for current spec snippet
f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))
# get spec snippet and create torch.tensfor from it
s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
log_s = torch.log10(s)
transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
s_trans = transformed(log_s.unsqueeze(0))
pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, args.dataset_folder)
if not args.inference:
bbox_yolo_style = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
fish_baseline_freq_time, fish_baseline_freq,
pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1)
#######################################################################
if False:
if bbox_yolo_style.shape[0] >= 1:
f_res, t_res = freq[1] - freq[0], times[1] - times[0]
fig_title = (
f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(
' ', '0')
fig = plt.figure(figsize=IMG_SIZE, num=fig_title)
gs = gridspec.GridSpec(1, 1, bottom=0.1, left=0.1, right=0.95, top=0.95) #
ax = fig.add_subplot(gs[0, 0])
ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res))
# ax.invert_yaxis()
# ax.axis(False)
for i in range(len(bbox_df)):
# Cbbox = np.array(bbox_df.loc[i, ['x0', 'y0', 'x1', 'y1']].values, dtype=np.float32)
Cbbox = bbox_df.loc[i, ['t0', 'f0', 't1', 'f1']]
ax.add_patch(
Rectangle((float(Cbbox['t0']) / 3600, float(Cbbox['f0'])),
float(Cbbox['t1']) / 3600 - float(Cbbox['t0']) / 3600,
float(Cbbox['f1']) - float(Cbbox['f0']),
fill=False, color="white", linestyle='-', linewidth=2, zorder=10)
)
# print(bbox_yolo_style.T)
for bbox in bbox_yolo_style:
x0 = bbox[1] - bbox[3]/2 # x_center - width/2
y0 = 1 - (bbox[2] + bbox[4]/2) # x_center - width/2
w = bbox[3]
h = bbox[4]
ax.add_patch(
Rectangle((x0, y0), w, h,
fill=False, color="k", linestyle='--', linewidth=2, zorder=10,
transform=ax.transAxes)
)
plt.show()
#######################################################################
# if not args.inference:
# print('save bboxes')
# bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
parser.add_argument('folder', type=str, help='single recording analysis', default='')
parser.add_argument('-d', "--dataset_folder", type=str, help='designated datasef folder', default=DATA_DIR)
parser.add_argument('-i', "--inference", action="store_true", help="generate inference dataset. Img only")
args = parser.parse_args()
if not Path(args.dataset_folder).exists():
Path(args.dataset_folder).mkdir(parents=True, exist_ok=True)
main(args)