# efishSignalDetector/extract_from_bbox.py

import itertools
import pathlib
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image
from tqdm.auto import tqdm
from IPython import embed
def extract_time_freq_range_from_filename(img_path):
    """Recover the recording name and spectrogram extent encoded in a filename.

    Filenames follow the pattern ``<name>__<t0>-<t1>s__<f0>-<f1>Hz.<ext>``.
    Returns (name, t0, t1, f0, f1); the numeric spans come back as np.float64.
    """
    name, time_part, freq_part = img_path.stem.split('__')
    t0, t1 = np.asarray(time_part.replace('s', '').split('-'), dtype=float)
    f0, f1 = np.asarray(freq_part.replace('Hz', '').split('-'), dtype=float)
    return name, t0, t1, f0, f1
def bbox_to_data(img_path, t_min, t_max, f_min, f_max):
    """Convert the YOLO-style label file of one image into data coordinates.

    Looks up ``<root>/labels/<image>.txt`` (rows: class, cx, cy, w, h, score,
    all normalized to [0, 1]), turns every row into absolute corners
    [t0, f0, t1, f1] — image y runs downward, hence the ``1 - ...`` flip —
    and returns (boxes, scores). Returns ([], []) when the file holds no boxes.
    """
    label_path = img_path.parent.parent / 'labels' / img_path.with_suffix('.txt').name
    annotations = np.atleast_2d(np.loadtxt(label_path, delimiter=' '))
    if annotations.shape[1] == 0:
        print('no rises detected in this window')
        return [], []

    cx, cy = annotations[:, 1], annotations[:, 2]
    w, h = annotations[:, 3], annotations[:, 4]
    # corners in normalized image coordinates: x0, y0, x1, y1
    boxes = np.stack([cx - w / 2, 1 - (cy + h / 2), cx + w / 2, 1 - (cy - h / 2)], axis=1)
    # rescale normalized coordinates to the time/frequency extent of the window
    boxes[:, [0, 2]] = boxes[:, [0, 2]] * (t_max - t_min) + t_min
    boxes[:, [1, 3]] = boxes[:, [1, 3]] * (f_max - f_min) + f_min
    return boxes, annotations[:, 5]
def load_wavetracker_data(raw_path):
    """Load the wavetracker arrays stored next to *raw_path*.

    Returns (fund_v, ident_v, idx_v, times): per-detection fundamental
    frequencies, identity labels, time-axis indices, and the time axis itself.
    """
    folder = raw_path.parent
    return tuple(np.load(folder / f'{stem}.npy')
                 for stem in ('fund_v', 'ident_v', 'idx_v', 'times'))
def assign_rises_to_ids(raw_path, time_frequency_bboxes, bbox_groups):
    """Assign each rise bbox of one recording to a wavetracker identity.

    For every bbox belonging to the recording ``raw_path.parent.name``,
    collect the tracked identities whose frequency trace passes through the
    bbox, pick the most likely one, and write it into the 'id' column;
    additionally estimate the rise onset and write it into 'event_time'.
    A debug figure is built and closed (not shown).

    Parameters:
        raw_path: path to a recording's .raw file; wavetracker arrays are
            loaded from its parent directory.
        time_frequency_bboxes: DataFrame with columns
            [file_name, t0, f0, t1, f1, score, id, event_time]; modified
            in place and also returned.
        bbox_groups: per-bbox group codes from delete_double_boxes
            (0 = no overlap group, >0 = kept member, <0 = flagged redundant);
            used here only to color the debug rectangles.
    """
    def identify_most_likely_rise_id(possible_ids, t0, t1, f0, f1, fund_v, ident_v, times, idx_v):
        # Pick the id whose mean frequency inside the bbox lies lowest
        # relative to the bbox frequency span (argsort[0] = smallest).
        mean_id_box_f_rel_to_bbox = []
        for id in possible_ids:
            id_box_f = fund_v[(ident_v == id) & (times[idx_v] >= t0) & (times[idx_v] <= t1)]
            id_box_f_rel_to_bbox = (id_box_f - f0) / (f1 - f0)
            mean_id_box_f_rel_to_bbox.append(np.mean(id_box_f_rel_to_bbox))
        most_likely_id = possible_ids[np.argsort(mean_id_box_f_rel_to_bbox)[0]]
        return most_likely_id

    fund_v, ident_v, idx_v, times = load_wavetracker_data(raw_path)

    # debug plot: all tracked frequency traces of this recording
    fig, ax = plt.subplots()
    ax.plot(times[idx_v[~np.isnan(ident_v)]], fund_v[~np.isnan(ident_v)], '.')

    mask = time_frequency_bboxes['file_name'] == raw_path.parent.name
    for index, bbox in time_frequency_bboxes[mask].iterrows():
        # row layout: file_name, then t0, f0, t1, f1, score (id/event_time excluded)
        name, t0, f0, t1, f1, score = (bbox[0], *bbox[1:-2].astype(float))
        if bbox_groups[index] == 0:
            color = 'tab:green'
        elif bbox_groups[index] > 0:
            color = 'tab:red'
        else:
            color = 'k'
        ax.add_patch(
            Rectangle((t0, f0),
                      (t1 - t0),
                      (f1 - f0),
                      fill=False, color=color, linestyle='--', linewidth=2, zorder=10)
        )
        ax.text(t1, f1, f'{score:.1%}', ha='right', va='bottom')

        # identities whose trace enters the bbox in both time and frequency
        possible_ids = np.unique(
            ident_v[~np.isnan(ident_v) &
                    (t0 <= times[idx_v]) &
                    (t1 >= times[idx_v]) &
                    (f0 <= fund_v) &
                    (f1 >= fund_v)]
        )

        if len(possible_ids) == 1:
            assigned_id = possible_ids[0]
            time_frequency_bboxes.at[index, 'id'] = assigned_id
        elif len(possible_ids) > 1:
            assigned_id = identify_most_likely_rise_id(possible_ids, t0, t1, f0, f1, fund_v, ident_v, times, idx_v)
            time_frequency_bboxes.at[index, 'id'] = assigned_id
        else:
            # no candidate identity: leave 'id' and 'event_time' as NaN
            continue

        # rise onset = first sample where the assigned trace exceeds 37% of
        # the bbox frequency span ...
        rise_start_freq_th = f0 + (f1 - f0) * 0.37
        wavetracker_mask = np.arange(len(fund_v))[
            (times[idx_v] >= t0) &
            (times[idx_v] <= t1) &
            (ident_v == assigned_id)
        ]
        if np.sum(fund_v[wavetracker_mask] > rise_start_freq_th) > 0:
            rise_start_idx = wavetracker_mask[fund_v[wavetracker_mask] > rise_start_freq_th][0]
            rise_time = times[idx_v[rise_start_idx]]
        else:
            # ... otherwise fall back to the sample with the largest slope
            rise_start_idx = wavetracker_mask[np.argmax(np.diff(fund_v[wavetracker_mask]))]
            rise_time = times[idx_v[rise_start_idx]]
        time_frequency_bboxes.at[index, 'event_time'] = rise_time
        ax.plot(rise_time, fund_v[rise_start_idx], 'ok')

    plt.close()
    return time_frequency_bboxes
def find_overlapping_bboxes(df_collect):
    """Find all index pairs of bboxes (within the same file) that overlap
    in both time and frequency.

    df_collect rows: [file_name, t0, f0, t1, f1, score]. Returns an (n, 2)
    array of row-index pairs.
    """
    data = np.array(df_collect)
    file_names = data[:, 0]
    coords = data[:, 1:].astype(float)

    def _intervals_overlap(lo_a, hi_a, lo_b, hi_b):
        # Sort the four interval edges; if the two smallest both belong to
        # the same box, the intervals are disjoint on this axis.
        owner = np.array([0, 0, 1, 1])[np.argsort([lo_a, hi_a, lo_b, hi_b])]
        return owner[0] != owner[1]

    pairs = []
    for fname in tqdm(np.unique(file_names)):
        idxs = np.flatnonzero(file_names == fname)
        for i, j in itertools.combinations(idxs, r=2):
            ta0, fa0, ta1, fa1 = coords[i][:-1]
            tb0, fb0, tb1, fb1 = coords[j][:-1]
            if not _intervals_overlap(ta0, ta1, tb0, tb1):
                continue  # no temporal overlap
            if not _intervals_overlap(fa0, fa1, fb0, fb1):
                continue  # no frequency overlap
            pairs.append((i, j))
    return np.asarray(pairs)
def main(args):
    """Collect bbox detections from annotation images and, optionally,
    assign them to wavetracker identities and save per-recording CSVs.

    args.annotations: root directory searched recursively for *.png images
        (with a sibling ``labels`` directory holding YOLO-style .txt files).
    args.tracking_data_path: optional root searched for *.raw recordings;
        when given, bboxes are matched to tracked identities and written to
        ``risedetector_bboxes.csv`` next to each recording.
    """
    img_paths = sorted(pathlib.Path(args.annotations).absolute().rglob('*.png'))

    df_collect = []
    for img_path in img_paths:
        # the filename encodes the recording name and the spectrogram extent
        file_name_str, t_min, t_max, f_min, f_max = extract_time_freq_range_from_filename(img_path)
        boxes, scores = bbox_to_data(img_path, t_min, t_max, f_min, f_max)  # t0, f0, t1, f1
        for (t0, f0, t1, f1), s in zip(boxes, scores):
            df_collect.append([file_name_str, t0, f0, t1, f1, s])
    df_collect = np.array(df_collect)

    overlap_bbox_idxs = find_overlapping_bboxes(df_collect)
    bbox_groups = delete_double_boxes(overlap_bbox_idxs, df_collect)

    # df_collect is already an ndarray; no need to re-wrap it here
    time_frequency_bboxes = pd.DataFrame(data=df_collect, columns=['file_name', 't0', 'f0', 't1', 'f1', 'score'])
    time_frequency_bboxes['id'] = np.full(len(time_frequency_bboxes), np.nan)
    time_frequency_bboxes['event_time'] = np.full(len(time_frequency_bboxes), np.nan)

    if args.tracking_data_path:
        file_paths = sorted(pathlib.Path(args.tracking_data_path).absolute().rglob('*.raw'))
        for raw_path in file_paths:
            if raw_path.parent.name not in time_frequency_bboxes['file_name'].to_list():
                continue
            time_frequency_bboxes = assign_rises_to_ids(raw_path, time_frequency_bboxes, bbox_groups)

        for raw_path in file_paths:
            # keep only the bboxes of this recording that received an identity
            mask = ((time_frequency_bboxes['file_name'] == raw_path.parent.name) &
                    (~np.isnan(time_frequency_bboxes['id'])))
            save_df = pd.DataFrame(
                time_frequency_bboxes[mask][['t0', 't1', 'f0', 'f1', 'score', 'id', 'event_time']].values,
                columns=['t0', 't1', 'f0', 'f1', 'score', 'id', 'event_time'])
            save_df['label'] = np.ones(len(save_df), dtype=int)
            save_df.to_csv(raw_path.parent / 'risedetector_bboxes.csv', sep=',', index=False)
def delete_double_boxes(overlap_bbox_idxs, df_collect, overlap_th = 0.2):
    """Group overlapping bboxes and flag the redundant ones.

    Parameters:
        overlap_bbox_idxs: (n, 2) array of row-index pairs of overlapping
            bboxes (from find_overlapping_bboxes).
        df_collect: array-like rows [file_name, t0, f0, t1, f1, score];
            the last column (score) drives which boxes are removed first.
        overlap_th: a pruned group is accepted once every remaining pairwise
            overlap (min of time/freq fraction) is below this threshold.

    Returns:
        bbox_groups: float array, one entry per bbox. 0 = box belongs to no
        overlap group; boxes of group g get +g when kept and -g when flagged
        for removal.
    """
    def get_connected(non_regarded_bbox_idx, overlap_bbox_idxs):
        # all boxes that share an overlap pair with the given box
        mask = np.array((np.array(overlap_bbox_idxs) == non_regarded_bbox_idx).sum(1), dtype=bool)
        affected_bbox_idxs = np.unique(overlap_bbox_idxs[mask])
        return affected_bbox_idxs

    handled_bbox_idxs = []
    bbox_groups = np.zeros(len(df_collect))

    for Coverlapping_bbox_idx in tqdm(np.unique(overlap_bbox_idxs)):
        if Coverlapping_bbox_idx in handled_bbox_idxs:
            continue

        # flood-fill the connected component of mutually overlapping boxes
        regarded_bbox_idxs = [Coverlapping_bbox_idx]
        mask = np.array((np.array(overlap_bbox_idxs) == Coverlapping_bbox_idx).sum(1), dtype=bool)
        affected_bbox_idxs = np.unique(overlap_bbox_idxs[mask])
        non_regarded_bbox_idxs = list(set(affected_bbox_idxs) - set(regarded_bbox_idxs))
        while len(non_regarded_bbox_idxs) > 0:
            non_regarded_bbox_idxs_cp = np.copy(non_regarded_bbox_idxs)
            for non_regarded_bbox_idx in non_regarded_bbox_idxs_cp:
                Caffected_bbox_idxs = get_connected(non_regarded_bbox_idx, overlap_bbox_idxs)
                affected_bbox_idxs = np.unique(np.append(affected_bbox_idxs, Caffected_bbox_idxs))
                regarded_bbox_idxs.append(non_regarded_bbox_idx)
            non_regarded_bbox_idxs = list(set(affected_bbox_idxs) - set(regarded_bbox_idxs))

        bbox_idx_group = np.array(regarded_bbox_idxs)
        bbox_scores = df_collect[bbox_idx_group][:, -1].astype(float)
        bbox_groups[bbox_idx_group] = np.max(bbox_groups) + 1

        # enumerate candidate removal subsets; combinations() yields idx and
        # score tuples in the same order, so they stay aligned
        remove_idx_combinations = [()]
        remove_idx_combinations_scores = [0]
        for r in range(1, len(bbox_idx_group)):
            remove_idx_combinations.extend(list(itertools.combinations(bbox_idx_group, r=r)))
            remove_idx_combinations_scores.extend(list(itertools.combinations(bbox_scores, r=r)))
        for enu, combi_score in enumerate(remove_idx_combinations_scores):
            remove_idx_combinations_scores[enu] = np.sum(combi_score)

        if len(bbox_idx_group) > 1:
            # try subsets in order of increasing total removed score; stop at
            # the first whose remaining boxes overlap less than overlap_th
            remove_idx_combinations = [remove_idx_combinations[ind] for ind in np.argsort(remove_idx_combinations_scores)]
            remove_idx_combinations_scores = [remove_idx_combinations_scores[ind] for ind in np.argsort(remove_idx_combinations_scores)]
            for remove_idx in remove_idx_combinations:
                select_bbox_idx_group = list(set(bbox_idx_group) - set(remove_idx))
                time_overlap_pct, freq_overlap_pct = (
                    compute_time_frequency_overlap_for_bbox_group(select_bbox_idx_group, df_collect))
                if np.all(np.min([time_overlap_pct, freq_overlap_pct], axis=0) < overlap_th):
                    break
            if len(remove_idx) > 0:
                # mark removed boxes by negating their group id
                bbox_groups[np.array(remove_idx)] *= -1
        handled_bbox_idxs.extend(bbox_idx_group)
    return bbox_groups
def compute_time_frequency_overlap_for_bbox_group(bbox_idx_group, df_collect):
    """Pairwise time/frequency overlap fractions within a group of bboxes.

    Entry [i, j] is the fraction of box i's extent along that axis that is
    covered by box j; the diagonal stays 0. df_collect rows are
    [file_name, t0, f0, t1, f1, score]. Returns
    (time_overlap_pct, freq_overlap_pct).
    """
    n = len(bbox_idx_group)
    time_overlap_pct = np.zeros((n, n))
    freq_overlap_pct = np.zeros((n, n))

    def _axis_overlap_pct(a0, a1, b0, b1):
        # Tag each sorted interval edge with its owning box (0 = reference).
        edges = np.array([a0, a1, b0, b1])
        owner = np.array([0, 0, 1, 1])[edges.argsort()]
        if owner[0] == owner[1]:
            return 0.0                              # disjoint intervals
        if owner[1] == owner[2] == 0:
            return 1.0                              # box 0 lies inside box 1
        if owner[1] == owner[2] == 1:
            return (b1 - b0) / (a1 - a0)            # box 1 lies inside box 0
        return np.diff(sorted(edges)[1:3])[0] / (a1 - a0)  # partial overlap

    for i, j in itertools.product(range(n), repeat=2):
        if i == j:
            continue
        row_i = df_collect[bbox_idx_group[i]]
        row_j = df_collect[bbox_idx_group[j]]
        ti0, fi0, ti1, fi1 = (row_i[k].astype(float) for k in (1, 2, 3, 4))
        tj0, fj0, tj1, fj1 = (row_j[k].astype(float) for k in (1, 2, 3, 4))
        time_overlap_pct[i, j] = _axis_overlap_pct(ti0, ti1, tj0, tj1)
        freq_overlap_pct[i, j] = _axis_overlap_pct(fi0, fi1, fj0, fj1)
    return time_overlap_pct, freq_overlap_pct
if __name__ == '__main__':
    # CLI entry point: positional annotations directory, optional tracking-data root.
    parser = argparse.ArgumentParser(description='Extract time, frequency and identity association of bboxes')
    parser.add_argument('annotations', nargs='?', type=str, help='path to annotations')
    # fixed typo in help text: 'dataa' -> 'data'
    parser.add_argument('-t', '--tracking_data_path', type=str, help='path to tracking data')
    args = parser.parse_args()
    main(args)