diff --git a/data/generate_dataset.py b/data/generate_dataset.py index e72e0b6..69c6db7 100644 --- a/data/generate_dataset.py +++ b/data/generate_dataset.py @@ -1,28 +1,40 @@ -import time - -import numpy as np +import itertools +import sys +import os import argparse + import torch from torch import nn import torch.nn.functional as F import torchvision.transforms as T + +import numpy as np import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec -from pathlib import Path import pandas as pd - +from pathlib import Path from tqdm.auto import tqdm -import itertools -import sys -import os - -from IPython import embed -from matplotlib.patches import Rectangle -from matplotlib.collections import PatchCollection - -def load_spec_data(folder): +def load_spec_data(folder: str): + """ + Load spectrogram of a given electrode-grid recording generated with the wavetracker package. The spectrograms may + be to large to load in total, thats why memmory mapping is used (numpy.memmap). + + Parameters + ---------- + folder: str + Folder where fine spec numpy files generated for grid recordings with the wavetracker package can be found. + + Returns + ------- + fill_freqs: ndarray + Freuqencies corresponding to 1st dimension of the spectrogram. + fill_times: ndarray + Times corresponding to the 2nd dimenstion if the spectrigram. + fill_spec: ndarray + Spectrigram of the recording refered to in the input folder. + """ fill_freqs, fill_times, fill_spec = [], [], [] if os.path.exists(os.path.join(folder, 'fill_spec.npy')): @@ -41,6 +53,7 @@ def load_spec_data(folder): return fill_freqs, fill_times, fill_spec + def load_tracking_data(folder): base_path = Path(folder) EODf_v = np.load(base_path / 'fund_v.npy') @@ -50,6 +63,7 @@ def load_tracking_data(folder): return EODf_v, ident_v, idx_v, times_v + def load_trial_data(folder): base_path = Path(folder) fish_freq = np.load(base_path / 'analysis' / 'fish_freq.npy') @@ -79,6 +93,7 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, return fig_title, (size[0]*dpi, size[1]*dpi) + def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1): @@ -119,9 +134,6 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq lower_freq_bound = lower_freq_bound[mask] upper_freq_bound = upper_freq_bound[mask] - # dt_bbox = right_time_bound - left_time_bound - # df_bbox = upper_freq_bound - lower_freq_bound - left_time_bound -= 0.01 * (t1 - t0) right_time_bound += 0.05 * (t1 - t0) lower_freq_bound -= 0.01 * (f1 - f0)