diff --git a/extract_from_bbox.py b/extract_from_bbox.py index 92715ae..cde4719 100644 --- a/extract_from_bbox.py +++ b/extract_from_bbox.py @@ -37,7 +37,7 @@ def bbox_to_data(img_path, t_min, t_max, f_min, f_max): boxes[:, 1] = boxes[:, 1] * (f_max - f_min) + f_min boxes[:, 3] = boxes[:, 3] * (f_max - f_min) + f_min - scores = annotations[:, -1] + scores = annotations[:, 5] return boxes, scores @@ -146,21 +146,49 @@ def main(args): bbox_overlapping_mask, overlap_bbox_idxs = find_overlapping_bboxes(df_collect) bbox_groups = delete_double_boxes(bbox_overlapping_mask, overlap_bbox_idxs, df_collect) - # embed() - # quit() + print('got here') + time_frequency_bboxes = pd.DataFrame(data= np.array(df_collect), columns=['file_name', 't0', 'f0', 't1', 'f1', 'score']) + ########################################### + colors = np.random.rand(np.max(bbox_groups).astype(int), 3) + for file_name in time_frequency_bboxes['file_name'].unique(): + fig, ax = plt.subplots() + + mask = time_frequency_bboxes['file_name'] == file_name + for index, bbox in time_frequency_bboxes[mask].iterrows(): + name, t0, f0, t1, f1 = (bbox[0], *bbox[1:-1].astype(float)) + if bbox_groups[index] == 0: + color = 'tab:green' + elif bbox_groups[index] > 0: + color = 'tab:red' + else: + color = 'k' + + ax.add_patch( + Rectangle((t0, f0), + (t1 - t0), + (f1 - f0), + fill=False, color=color, linestyle='--', linewidth=2, zorder=10) + ) + # ax.set_xlim(float(time_frequency_bboxes[mask]['t0'].min()), float(time_frequency_bboxes[mask]['t1'].max())) + ax.set_xlim(0, float(time_frequency_bboxes[mask]['t1'].max())) + # ax.set_ylim(float(time_frequency_bboxes[mask]['f0'].min()), float(time_frequency_bboxes[mask]['f1'].max())) + ax.set_ylim(400, 1200) + plt.show() + exit() + ########################################### + if args.tracking_data_path: file_paths = sorted(list(pathlib.Path(args.tracking_data_path).absolute().rglob('*.raw'))) for raw_path in file_paths: if not raw_path.parent.name in time_frequency_bboxes['file_name'].to_list(): continue - assign_rises_to_ids(raw_path, time_frequency_bboxes, bbox_overlapping_mask, bbox_groups) pass -def delete_double_boxes(bbox_overlapping_mask, overlap_bbox_idxs, df_collect): +def delete_double_boxes(bbox_overlapping_mask, overlap_bbox_idxs, df_collect, overlap_th = 0.2): def get_connected(non_regarded_bbox_idx, overlap_bbox_idxs): mask = np.array((np.array(overlap_bbox_idxs) == non_regarded_bbox_idx).sum(1), dtype=bool) affected_bbox_idxs = np.unique(overlap_bbox_idxs[mask]) @@ -168,13 +196,11 @@ def delete_double_boxes(bbox_overlapping_mask, overlap_bbox_idxs, df_collect): handled_bbox_idxs = [] bbox_groups = np.zeros(len(df_collect)) + detele_bbox_idxs = [] + for Coverlapping_bbox_idx in tqdm(np.unique(overlap_bbox_idxs)): if Coverlapping_bbox_idx in handled_bbox_idxs: continue - # if bbox_overlapping_mask[Coverlapping_bbox_idx] >= 3: - # pass - # else: - # continue regarded_bbox_idxs = [Coverlapping_bbox_idx] mask = np.array((np.array(overlap_bbox_idxs) == Coverlapping_bbox_idx).sum(1), dtype=bool) @@ -191,43 +217,96 @@ def delete_double_boxes(bbox_overlapping_mask, overlap_bbox_idxs, df_collect): regarded_bbox_idxs.append(non_regarded_bbox_idx) non_regarded_bbox_idxs = list(set(affected_bbox_idxs) - set(regarded_bbox_idxs)) - bbox_idx_group = regarded_bbox_idxs + bbox_idx_group = np.array(regarded_bbox_idxs) + bbox_scores = df_collect[bbox_idx_group][:, -1].astype(float) + # bbox_idx_group = bbox_idx_group[bbox_scores.argsort()] + # bbox_scores = bbox_scores[bbox_scores.argsort()] + bbox_groups[bbox_idx_group] = np.max(bbox_groups) + 1 - # bbox_scores = df_collect[bbox_idx_group][:, -1] - # overlap_pct = np.full((len(bbox_idx_group), len(bbox_idx_group)), np.nan) - # - # for i, j in itertools.product(range(len(bbox_idx_group)), repeat=2): - # if i == j: - # continue - # bb0_idx = bbox_idx_group[i] - # bb1_idx = bbox_idx_group[j] - # - # bb0_t0, bb0_t1 = df_collect[bb0_idx][1].astype(float), df_collect[bb0_idx][3].astype(float) - # bb1_t0, bb1_t1 = df_collect[bb1_idx][1].astype(float), df_collect[bb1_idx][3].astype(float) - # - # helper = np.array([0, 0, 1, 1]) - # bb_times = np.array([bb0_t0, bb0_t1, bb1_t0, bb1_t1]) - # - # sorted_helper = helper[bb_times.argsort()] - # - # if sorted_helper[0] == sorted_helper[1]: - # continue - # - # elif sorted_helper[1] == sorted_helper[2] == 0: - # overlap_pct[i, j] = 1 - # - # elif sorted_helper[1] == sorted_helper[2] == 1: - # overlap_pct[i, j] = (bb1_t1 - bb1_t0) / (bb0_t1 - bb0_t0) - # - # else: - # overlap_pct[i, j] = np.diff(sorted(bb_times)[1:3])[0] / ((bb0_t1 - bb0_t0)) - # embed() - # quit() + + remove_idx_combinations = [()] + remove_idx_combinations_scores = [0] + for r in range(1, len(bbox_idx_group)): + remove_idx_combinations.extend(list(itertools.combinations(bbox_idx_group, r=r))) + remove_idx_combinations_scores.extend(list(itertools.combinations(bbox_scores, r=r))) + for enu, combi_score in enumerate(remove_idx_combinations_scores): + remove_idx_combinations_scores[enu] = np.sum(combi_score) + if len(bbox_idx_group) > 2: + print(remove_idx_combinations) + print(remove_idx_combinations_scores) + print('') + + remove_idx_combinations = [remove_idx_combinations[ind] for ind in np.argsort(remove_idx_combinations_scores)] + remove_idx_combinations_scores = [remove_idx_combinations_scores[ind] for ind in np.argsort(remove_idx_combinations_scores)] + + print(remove_idx_combinations) + print(remove_idx_combinations_scores) + print('') + + # time_overlap_pct, freq_overlap_pct = compute_time_frequency_overlap_for_bbox_group(bbox_idx_group, df_collect) + + for remove_idx in remove_idx_combinations: + select_bbox_idx_group = list(set(bbox_idx_group) - set(remove_idx)) + time_overlap_pct, freq_overlap_pct = ( + compute_time_frequency_overlap_for_bbox_group(select_bbox_idx_group,df_collect)) + + if np.all(np.min([time_overlap_pct, freq_overlap_pct], axis=0) < overlap_th): + break + #embed() + #quit() + if len(remove_idx) > 0: + bbox_groups[np.array(remove_idx)] *= -1 handled_bbox_idxs.extend(bbox_idx_group) return bbox_groups +def compute_time_frequency_overlap_for_bbox_group(bbox_idx_group, df_collect): + time_overlap_pct = np.zeros((len(bbox_idx_group), len(bbox_idx_group))) + freq_overlap_pct = np.zeros((len(bbox_idx_group), len(bbox_idx_group))) + + + for i, j in itertools.product(range(len(bbox_idx_group)), repeat=2): + if i == j: + continue + bb0_idx = bbox_idx_group[i] + bb1_idx = bbox_idx_group[j] + + bb0_t0, bb0_t1 = df_collect[bb0_idx][1].astype(float), df_collect[bb0_idx][3].astype(float) + bb1_t0, bb1_t1 = df_collect[bb1_idx][1].astype(float), df_collect[bb1_idx][3].astype(float) + + bb0_f0, bb0_f1 = df_collect[bb0_idx][2].astype(float), df_collect[bb0_idx][4].astype(float) + bb1_f0, bb1_f1 = df_collect[bb1_idx][2].astype(float), df_collect[bb1_idx][4].astype(float) + + bb_times_idx = np.array([0, 0, 1, 1]) + bb_times = np.array([bb0_t0, bb0_t1, bb1_t0, bb1_t1]) + sorted_bb_times_idx = bb_times_idx[bb_times.argsort()] + + if sorted_bb_times_idx[0] == sorted_bb_times_idx[1]: + time_overlap_pct[i, j] = 0 + elif sorted_bb_times_idx[1] == sorted_bb_times_idx[2] == 0: + time_overlap_pct[i, j] = 1 + elif sorted_bb_times_idx[1] == sorted_bb_times_idx[2] == 1: + time_overlap_pct[i, j] = (bb1_t1 - bb1_t0) / (bb0_t1 - bb0_t0) + else: + time_overlap_pct[i, j] = np.diff(sorted(bb_times)[1:3])[0] / ((bb0_t1 - bb0_t0)) + + bb_freqs_idx = np.array([0, 0, 1, 1]) + bb_freqs = np.array([bb0_f0, bb0_f1, bb1_f0, bb1_f1]) + sorted_bb_freqs_idx = bb_freqs_idx[bb_freqs.argsort()] + + if sorted_bb_freqs_idx[0] == sorted_bb_freqs_idx[1]: + freq_overlap_pct[i, j] = 0 + elif sorted_bb_freqs_idx[1] == sorted_bb_freqs_idx[2] == 0: + freq_overlap_pct[i, j] = 1 + elif sorted_bb_freqs_idx[1] == sorted_bb_freqs_idx[2] == 1: + freq_overlap_pct[i, j] = (bb1_f1 - bb1_f0) / (bb0_f1 - bb0_f0) + else: + freq_overlap_pct[i, j] = np.diff(sorted(bb_freqs)[1:3])[0] / ((bb0_f1 - bb0_f0)) + + return time_overlap_pct, freq_overlap_pct + + if __name__ == '__main__': parser = argparse.ArgumentParser(description='Extract time, frequency and identity association of bboxes') parser.add_argument('annotations', nargs='?', type=str, help='path to annotations')