From f70e74f5e198d1ff864203285aa6f172cc98d19d Mon Sep 17 00:00:00 2001 From: Till Raab Date: Tue, 28 Nov 2023 14:53:37 +0100 Subject: [PATCH] inference.py now also writes bbox scores; extract_from_bbox.py detects overlapping bbox groups and eliminates them based on score untill all bboxes have overlap below th... detection of groups is done using time as parameter only. implement also frequency. then sort them as stated previously --- extract_from_bbox.py | 157 +++++++++++++++++++++++++++++++++++++++---- inference.py | 2 +- 2 files changed, 146 insertions(+), 13 deletions(-) diff --git a/extract_from_bbox.py b/extract_from_bbox.py index eae2c2b..92715ae 100644 --- a/extract_from_bbox.py +++ b/extract_from_bbox.py @@ -1,3 +1,4 @@ +import itertools import pathlib import argparse import numpy as np @@ -26,7 +27,7 @@ def bbox_to_data(img_path, t_min, t_max, f_min, f_max): annotations = np.array([annotations]) if annotations.shape[1] == 0: print('no rises detected in this window') - return [] + return [], [] boxes = np.array([[x[1] - x[3] / 2, 1 - (x[2] + x[4] / 2), x[1] + x[3] / 2, 1 - (x[2] - x[4] / 2)] for x in annotations]) # x0, y0, x1, y1 @@ -36,7 +37,9 @@ def bbox_to_data(img_path, t_min, t_max, f_min, f_max): boxes[:, 1] = boxes[:, 1] * (f_max - f_min) + f_min boxes[:, 3] = boxes[:, 3] * (f_max - f_min) + f_min - return boxes + scores = annotations[:, -1] + + return boxes, scores def load_wavetracker_data(raw_path): fund_v = np.load(raw_path.parent / 'fund_v.npy') @@ -47,7 +50,7 @@ def load_wavetracker_data(raw_path): return fund_v, ident_v, idx_v, times -def assign_rises_to_ids(raw_path, time_frequency_bboxes): +def assign_rises_to_ids(raw_path, time_frequency_bboxes, overlapping_boxes, bbox_groups): fund_v, ident_v, idx_v, times = load_wavetracker_data(raw_path) fig, ax = plt.subplots() @@ -55,19 +58,75 @@ def assign_rises_to_ids(raw_path, time_frequency_bboxes): mask = time_frequency_bboxes['file_name'] == raw_path.parent.name for index, bbox in time_frequency_bboxes[mask].iterrows(): - name, t0, f0, t1, f1 = (bbox[0], *bbox[1:].astype(float)) + name, t0, f0, t1, f1 = (bbox[0], *bbox[1:-1].astype(float)) + if bbox_groups[index] == 0: + color = 'tab:green' + else: + color = 'k' + # if overlapping_boxes[index] == 0: + # color='tab:green' + # elif overlapping_boxes[index] == 1: + # color = 'tab:olive' + # elif overlapping_boxes[index] == 2: + # color = 'tab:orange' + # elif overlapping_boxes[index] == 3: + # color = 'tab:red' + # color = 'tab:green' if overlapping_boxes[index] == 0 else 'tab:orange' ax.add_patch( Rectangle((t0, f0), (t1 - t0), (f1 - f0), - fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10) + fill=False, color=color, linestyle='--', linewidth=2, zorder=10) ) plt.show() - + # if np.any(overlapping_boxes[mask] >= 2): + # print('yay') + # embed() + # quit() # ToDo: eliminate double rises -- overlap # ToDo: double detections -- non overlap --> the one with higher probability ?! # ToDo: assign rises to traces --> who is at lower right corner +def find_overlapping_bboxes(df_collect): + file_names = np.array(df_collect)[:, 0] + bboxes_overlapping_mask = np.zeros(len(df_collect)) + bboxes = np.array(df_collect)[:, 1:].astype(float) + + overlap_bbox_idxs = [] + + for file_name in tqdm(np.unique(file_names)): + file_bbox_idxs = np.arange(len(file_names))[file_names == file_name] + for ind0, ind1 in itertools.combinations(file_bbox_idxs, r=2): + bb0 = bboxes[ind0] + bb1 = bboxes[ind1] + t0_0, f0_0, t0_1, f0_1 = bb0[:-1] + t1_0, f1_0, t1_1, f1_1 = bb1[:-1] + + bb_times = np.array([t0_0, t0_1, t1_0, t1_1]) + bb_time_associate = np.array([0, 0, 1, 1]) + time_helper = bb_time_associate[np.argsort(bb_times)] + + if time_helper[0] == time_helper[1]: + # no temporal overlap + continue + + # check freq overlap + bb_freqs = np.array([f0_0, f0_1, f1_0, f1_1]) + bb_freq_associate = np.array([0, 0, 1, 1]) + + freq_helper = bb_freq_associate[np.argsort(bb_freqs)] + + if freq_helper[0] == freq_helper[1]: + continue + + bboxes_overlapping_mask[ind0] +=1 + bboxes_overlapping_mask[ind1] +=1 + + overlap_bbox_idxs.append((ind0, ind1)) + + return bboxes_overlapping_mask, np.asarray(overlap_bbox_idxs) + + def main(args): img_paths = sorted(list(pathlib.Path(args.annotations).absolute().rglob('*.png'))) df_collect = [] @@ -76,14 +135,20 @@ def main(args): # convert to time_frequency file_name_str, t_min, t_max, f_min, f_max = extract_time_freq_range_from_filename(img_path) - boxes = bbox_to_data(img_path, t_min, t_max, f_min, f_max ) # t0, t1, f0, f1 - + boxes, scores = bbox_to_data(img_path, t_min, t_max, f_min, f_max ) # t0, t1, f0, f1 # store values in df if not len(boxes) == 0: - for x0, y0, x1, y1 in boxes: - df_collect.append([file_name_str, x0, y0, x1, y1]) + for (t0, f0, t1, f1), s in zip(boxes, scores): + df_collect.append([file_name_str, t0, f0, t1, f1, s]) + + df_collect = np.array(df_collect) - time_frequency_bboxes = pd.DataFrame(data= np.array(df_collect), columns=['file_name', 't0', 'f0', 't1', 'f1']) + bbox_overlapping_mask, overlap_bbox_idxs = find_overlapping_bboxes(df_collect) + + bbox_groups = delete_double_boxes(bbox_overlapping_mask, overlap_bbox_idxs, df_collect) + # embed() + # quit() + time_frequency_bboxes = pd.DataFrame(data= np.array(df_collect), columns=['file_name', 't0', 'f0', 't1', 'f1', 'score']) if args.tracking_data_path: file_paths = sorted(list(pathlib.Path(args.tracking_data_path).absolute().rglob('*.raw'))) @@ -91,10 +156,78 @@ def main(args): if not raw_path.parent.name in time_frequency_bboxes['file_name'].to_list(): continue - assign_rises_to_ids(raw_path, time_frequency_bboxes) + assign_rises_to_ids(raw_path, time_frequency_bboxes, bbox_overlapping_mask, bbox_groups) pass +def delete_double_boxes(bbox_overlapping_mask, overlap_bbox_idxs, df_collect): + def get_connected(non_regarded_bbox_idx, overlap_bbox_idxs): + mask = np.array((np.array(overlap_bbox_idxs) == non_regarded_bbox_idx).sum(1), dtype=bool) + affected_bbox_idxs = np.unique(overlap_bbox_idxs[mask]) + return affected_bbox_idxs + + handled_bbox_idxs = [] + bbox_groups = np.zeros(len(df_collect)) + for Coverlapping_bbox_idx in tqdm(np.unique(overlap_bbox_idxs)): + if Coverlapping_bbox_idx in handled_bbox_idxs: + continue + # if bbox_overlapping_mask[Coverlapping_bbox_idx] >= 3: + # pass + # else: + # continue + + regarded_bbox_idxs = [Coverlapping_bbox_idx] + mask = np.array((np.array(overlap_bbox_idxs) == Coverlapping_bbox_idx).sum(1), dtype=bool) + affected_bbox_idxs = np.unique(overlap_bbox_idxs[mask]) + + non_regarded_bbox_idxs = list(set(affected_bbox_idxs) - set(regarded_bbox_idxs)) + # non_regarded_bbox_idxs = list(set(non_regarded_bbox_idxs) - set(handled_bbox_idxs)) + while len(non_regarded_bbox_idxs) > 0: + non_regarded_bbox_idxs_cp = np.copy(non_regarded_bbox_idxs) + for non_regarded_bbox_idx in non_regarded_bbox_idxs_cp: + Caffected_bbox_idxs = get_connected(non_regarded_bbox_idx, overlap_bbox_idxs) + affected_bbox_idxs = np.unique(np.append(affected_bbox_idxs, Caffected_bbox_idxs)) + + regarded_bbox_idxs.append(non_regarded_bbox_idx) + non_regarded_bbox_idxs = list(set(affected_bbox_idxs) - set(regarded_bbox_idxs)) + + bbox_idx_group = regarded_bbox_idxs + bbox_groups[bbox_idx_group] = np.max(bbox_groups) + 1 + # bbox_scores = df_collect[bbox_idx_group][:, -1] + # overlap_pct = np.full((len(bbox_idx_group), len(bbox_idx_group)), np.nan) + # + # for i, j in itertools.product(range(len(bbox_idx_group)), repeat=2): + # if i == j: + # continue + # bb0_idx = bbox_idx_group[i] + # bb1_idx = bbox_idx_group[j] + # + # bb0_t0, bb0_t1 = df_collect[bb0_idx][1].astype(float), df_collect[bb0_idx][3].astype(float) + # bb1_t0, bb1_t1 = df_collect[bb1_idx][1].astype(float), df_collect[bb1_idx][3].astype(float) + # + # helper = np.array([0, 0, 1, 1]) + # bb_times = np.array([bb0_t0, bb0_t1, bb1_t0, bb1_t1]) + # + # sorted_helper = helper[bb_times.argsort()] + # + # if sorted_helper[0] == sorted_helper[1]: + # continue + # + # elif sorted_helper[1] == sorted_helper[2] == 0: + # overlap_pct[i, j] = 1 + # + # elif sorted_helper[1] == sorted_helper[2] == 1: + # overlap_pct[i, j] = (bb1_t1 - bb1_t0) / (bb0_t1 - bb0_t0) + # + # else: + # overlap_pct[i, j] = np.diff(sorted(bb_times)[1:3])[0] / ((bb0_t1 - bb0_t0)) + # embed() + # quit() + + handled_bbox_idxs.extend(bbox_idx_group) + + return bbox_groups + if __name__ == '__main__': parser = argparse.ArgumentParser(description='Extract time, frequency and identity association of bboxes') parser.add_argument('annotations', nargs='?', type=str, help='path to annotations') diff --git a/inference.py b/inference.py index 8bc83e5..b9d29fd 100644 --- a/inference.py +++ b/inference.py @@ -73,7 +73,7 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8): rel_width = rel_x1 - rel_x0 rel_height = rel_y1 - rel_y0 - yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height]) + yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height, score]) label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt') np.savetxt(label_path, yolo_labels)