transfere bboxes to time-frequency points
This commit is contained in:
		
							parent
							
								
									0f0128439e
								
							
						
					
					
						commit
						93269a96a1
					
				
							
								
								
									
										103
									
								
								extract_from_bbox.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								extract_from_bbox.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,103 @@ | |||||||
|  | import pathlib | ||||||
|  | import argparse | ||||||
|  | import numpy as np | ||||||
|  | import pandas as pd | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  | from matplotlib.patches import Rectangle | ||||||
|  | from PIL import Image | ||||||
|  | from tqdm.auto import tqdm | ||||||
|  | 
 | ||||||
|  | from IPython import embed | ||||||
|  | 
 | ||||||
|  | def extract_time_freq_range_from_filename(img_path): | ||||||
|  |     file_name_str, time_span_str, freq_span_str = str(img_path.with_suffix('').name).split('__') | ||||||
|  |     time_span_str = time_span_str.replace('s', '') | ||||||
|  |     freq_span_str = freq_span_str.replace('Hz', '') | ||||||
|  | 
 | ||||||
|  |     t0, t1 = np.array(time_span_str.split('-'), dtype=float) | ||||||
|  |     f0, f1 = np.array(freq_span_str.split('-'), dtype=float) | ||||||
|  |     return file_name_str, t0, t1, f0, f1 | ||||||
|  | 
 | ||||||
|  | def bbox_to_data(img_path, t_min, t_max, f_min, f_max): | ||||||
|  |     label_path = img_path.parent.parent / 'labels' / img_path.with_suffix('.txt').name | ||||||
|  | 
 | ||||||
|  |     annotations = np.loadtxt(label_path, delimiter=' ') | ||||||
|  |     if len(annotations.shape) == 1: | ||||||
|  |         annotations = np.array([annotations]) | ||||||
|  |     if annotations.shape[1] == 0: | ||||||
|  |         print('no rises detected in this window') | ||||||
|  |         return [] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     boxes = np.array([[x[1] - x[3] / 2, 1 - (x[2] + x[4] / 2), x[1] + x[3] / 2, 1 - (x[2] - x[4] / 2)] for x in annotations]) # x0, y0, x1, y1 | ||||||
|  |     boxes[:, 0] = boxes[:, 0] * (t_max - t_min) + t_min | ||||||
|  |     boxes[:, 2] = boxes[:, 2] * (t_max - t_min) + t_min | ||||||
|  | 
 | ||||||
|  |     boxes[:, 1] = boxes[:, 1] * (f_max - f_min) + f_min | ||||||
|  |     boxes[:, 3] = boxes[:, 3] * (f_max - f_min) + f_min | ||||||
|  | 
 | ||||||
|  |     return boxes | ||||||
|  | 
 | ||||||
|  | def load_wavetracker_data(raw_path): | ||||||
|  |     fund_v = np.load(raw_path.parent / 'fund_v.npy') | ||||||
|  |     ident_v = np.load(raw_path.parent / 'ident_v.npy') | ||||||
|  |     idx_v = np.load(raw_path.parent / 'idx_v.npy') | ||||||
|  |     times = np.load(raw_path.parent / 'times.npy') | ||||||
|  | 
 | ||||||
|  |     return fund_v, ident_v, idx_v, times | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def assign_rises_to_ids(raw_path, time_frequency_bboxes): | ||||||
|  |     fund_v, ident_v, idx_v, times = load_wavetracker_data(raw_path) | ||||||
|  | 
 | ||||||
|  |     fig, ax = plt.subplots() | ||||||
|  |     ax.plot(times[idx_v[~np.isnan(ident_v)]], fund_v[~np.isnan(ident_v)], '.') | ||||||
|  | 
 | ||||||
|  |     mask = time_frequency_bboxes['file_name'] == raw_path.parent.name | ||||||
|  |     for index, bbox in time_frequency_bboxes[mask].iterrows(): | ||||||
|  |         name, t0, f0, t1, f1 = (bbox[0], *bbox[1:].astype(float)) | ||||||
|  |         ax.add_patch( | ||||||
|  |             Rectangle((t0, f0), | ||||||
|  |                       (t1 - t0), | ||||||
|  |                       (f1 - f0), | ||||||
|  |                       fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10) | ||||||
|  |         ) | ||||||
|  |     plt.show() | ||||||
|  | 
 | ||||||
|  |     # ToDo: eliminate double rises -- overlap | ||||||
|  |     # ToDo: double detections -- non overlap --> the one with higher probability ?! | ||||||
|  |     # ToDo: assign rises to traces --> who is at lower right corner | ||||||
|  | 
 | ||||||
|  | def main(args): | ||||||
|  |     img_paths = sorted(list(pathlib.Path(args.annotations).absolute().rglob('*.png'))) | ||||||
|  |     df_collect = [] | ||||||
|  | 
 | ||||||
|  |     for img_path in img_paths: | ||||||
|  |         # convert to time_frequency | ||||||
|  |         file_name_str, t_min, t_max, f_min, f_max = extract_time_freq_range_from_filename(img_path) | ||||||
|  | 
 | ||||||
|  |         boxes = bbox_to_data(img_path, t_min, t_max, f_min, f_max ) # t0, t1, f0, f1 | ||||||
|  | 
 | ||||||
|  |         # store values in df | ||||||
|  |         if not len(boxes) == 0: | ||||||
|  |             for x0, y0, x1, y1 in boxes: | ||||||
|  |                 df_collect.append([file_name_str, x0, y0, x1, y1]) | ||||||
|  | 
 | ||||||
|  |     time_frequency_bboxes = pd.DataFrame(data= np.array(df_collect), columns=['file_name', 't0', 'f0', 't1', 'f1']) | ||||||
|  | 
 | ||||||
|  |     if args.tracking_data_path: | ||||||
|  |         file_paths = sorted(list(pathlib.Path(args.tracking_data_path).absolute().rglob('*.raw'))) | ||||||
|  |         for raw_path in file_paths: | ||||||
|  |             if not raw_path.parent.name in time_frequency_bboxes['file_name'].to_list(): | ||||||
|  |                 continue | ||||||
|  | 
 | ||||||
|  |             assign_rises_to_ids(raw_path, time_frequency_bboxes) | ||||||
|  | 
 | ||||||
|  |     pass | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     parser = argparse.ArgumentParser(description='Extract time, frequency and identity association of bboxes') | ||||||
|  |     parser.add_argument('annotations', nargs='?', type=str, help='path to annotations') | ||||||
|  |     parser.add_argument('-t', '--tracking_data_path', type=str, help='path to tracking dataa') | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     main(args) | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user