transfere bboxes to time-frequency points
This commit is contained in:
parent
0f0128439e
commit
93269a96a1
103
extract_from_bbox.py
Normal file
103
extract_from_bbox.py
Normal file
@ -0,0 +1,103 @@
|
||||
import pathlib
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Rectangle
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from IPython import embed
|
||||
|
||||
def extract_time_freq_range_from_filename(img_path):
|
||||
file_name_str, time_span_str, freq_span_str = str(img_path.with_suffix('').name).split('__')
|
||||
time_span_str = time_span_str.replace('s', '')
|
||||
freq_span_str = freq_span_str.replace('Hz', '')
|
||||
|
||||
t0, t1 = np.array(time_span_str.split('-'), dtype=float)
|
||||
f0, f1 = np.array(freq_span_str.split('-'), dtype=float)
|
||||
return file_name_str, t0, t1, f0, f1
|
||||
|
||||
def bbox_to_data(img_path, t_min, t_max, f_min, f_max):
|
||||
label_path = img_path.parent.parent / 'labels' / img_path.with_suffix('.txt').name
|
||||
|
||||
annotations = np.loadtxt(label_path, delimiter=' ')
|
||||
if len(annotations.shape) == 1:
|
||||
annotations = np.array([annotations])
|
||||
if annotations.shape[1] == 0:
|
||||
print('no rises detected in this window')
|
||||
return []
|
||||
|
||||
|
||||
boxes = np.array([[x[1] - x[3] / 2, 1 - (x[2] + x[4] / 2), x[1] + x[3] / 2, 1 - (x[2] - x[4] / 2)] for x in annotations]) # x0, y0, x1, y1
|
||||
boxes[:, 0] = boxes[:, 0] * (t_max - t_min) + t_min
|
||||
boxes[:, 2] = boxes[:, 2] * (t_max - t_min) + t_min
|
||||
|
||||
boxes[:, 1] = boxes[:, 1] * (f_max - f_min) + f_min
|
||||
boxes[:, 3] = boxes[:, 3] * (f_max - f_min) + f_min
|
||||
|
||||
return boxes
|
||||
|
||||
def load_wavetracker_data(raw_path):
|
||||
fund_v = np.load(raw_path.parent / 'fund_v.npy')
|
||||
ident_v = np.load(raw_path.parent / 'ident_v.npy')
|
||||
idx_v = np.load(raw_path.parent / 'idx_v.npy')
|
||||
times = np.load(raw_path.parent / 'times.npy')
|
||||
|
||||
return fund_v, ident_v, idx_v, times
|
||||
|
||||
|
||||
def assign_rises_to_ids(raw_path, time_frequency_bboxes):
|
||||
fund_v, ident_v, idx_v, times = load_wavetracker_data(raw_path)
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(times[idx_v[~np.isnan(ident_v)]], fund_v[~np.isnan(ident_v)], '.')
|
||||
|
||||
mask = time_frequency_bboxes['file_name'] == raw_path.parent.name
|
||||
for index, bbox in time_frequency_bboxes[mask].iterrows():
|
||||
name, t0, f0, t1, f1 = (bbox[0], *bbox[1:].astype(float))
|
||||
ax.add_patch(
|
||||
Rectangle((t0, f0),
|
||||
(t1 - t0),
|
||||
(f1 - f0),
|
||||
fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10)
|
||||
)
|
||||
plt.show()
|
||||
|
||||
# ToDo: eliminate double rises -- overlap
|
||||
# ToDo: double detections -- non overlap --> the one with higher probability ?!
|
||||
# ToDo: assign rises to traces --> who is at lower right corner
|
||||
|
||||
def main(args):
|
||||
img_paths = sorted(list(pathlib.Path(args.annotations).absolute().rglob('*.png')))
|
||||
df_collect = []
|
||||
|
||||
for img_path in img_paths:
|
||||
# convert to time_frequency
|
||||
file_name_str, t_min, t_max, f_min, f_max = extract_time_freq_range_from_filename(img_path)
|
||||
|
||||
boxes = bbox_to_data(img_path, t_min, t_max, f_min, f_max ) # t0, t1, f0, f1
|
||||
|
||||
# store values in df
|
||||
if not len(boxes) == 0:
|
||||
for x0, y0, x1, y1 in boxes:
|
||||
df_collect.append([file_name_str, x0, y0, x1, y1])
|
||||
|
||||
time_frequency_bboxes = pd.DataFrame(data= np.array(df_collect), columns=['file_name', 't0', 'f0', 't1', 'f1'])
|
||||
|
||||
if args.tracking_data_path:
|
||||
file_paths = sorted(list(pathlib.Path(args.tracking_data_path).absolute().rglob('*.raw')))
|
||||
for raw_path in file_paths:
|
||||
if not raw_path.parent.name in time_frequency_bboxes['file_name'].to_list():
|
||||
continue
|
||||
|
||||
assign_rises_to_ids(raw_path, time_frequency_bboxes)
|
||||
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Extract time, frequency and identity association of bboxes')
|
||||
parser.add_argument('annotations', nargs='?', type=str, help='path to annotations')
|
||||
parser.add_argument('-t', '--tracking_data_path', type=str, help='path to tracking dataa')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
Loading…
Reference in New Issue
Block a user