From 324c69084177daca389ab03a1de7b049011cc55f Mon Sep 17 00:00:00 2001 From: Till Raab Date: Thu, 16 Nov 2023 12:19:57 +0100 Subject: [PATCH] model and datasets adapted to new yolo format. reinforced learning is now implemented. correct_bboxes.py now capable of exporting data to train file --- confic.py | 4 +-- correct_bboxes.py | 62 ++++++++++++++++++++++---------------- datasets.py | 73 +++------------------------------------------ generate_dataset.py | 64 ++++++++------------------------------- inference.py | 24 +++++++++++++++ train.py | 5 ---- 6 files changed, 80 insertions(+), 152 deletions(-) diff --git a/confic.py b/confic.py index fbe7e1c..e3bd945 100644 --- a/confic.py +++ b/confic.py @@ -26,8 +26,8 @@ DELTA_TIME = 60*10 TIME_OVERLAP = 60*1 # output parameters -DATA_DIR = 'data/images' -LABEL_DIR = 'data/labels' +DATA_DIR = 'data/rise_training/images' +LABEL_DIR = 'data/rise_training/labels' OUTDIR = 'model_outputs' INFERENCE_OUTDIR = 'inference_outputs' for required_folders in [DATA_DIR, OUTDIR, INFERENCE_OUTDIR, LABEL_DIR]: diff --git a/correct_bboxes.py b/correct_bboxes.py index 0bf024d..5c67890 100644 --- a/correct_bboxes.py +++ b/correct_bboxes.py @@ -87,11 +87,13 @@ class Bbox_correct_UI(QMainWindow): self.data_path = data_path self.files = sorted(list(pathlib.Path(self.data_path).absolute().rglob('*images/*.png'))) + self.close_without_saving_rois = False + rec = QApplication.desktop().screenGeometry() height = rec.height() width = rec.width() self.resize(int(.8 * width), int(.8 * height)) - self.setWindowTitle('efishSignalTracker') # set window title + # widget and layout self.central_widget = QWidget(self) @@ -100,19 +102,23 @@ class Bbox_correct_UI(QMainWindow): self.central_widget.setLayout(self.central_Layout) self.setCentralWidget(self.central_widget) + ########### + self.highlighted_label = None + self.all_labels = [] + self.load_or_create_file_dict() + + new_name, new_file, new_checked = self.file_dict.iloc[0].values + self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}') + # image widget - self.current_img = ImageWithBbox(self.files[0], parent=self) + self.current_img = ImageWithBbox(new_file, parent=self) self.central_Layout.addWidget(self.current_img, 4) + # image select widget self.scroll = QScrollArea() # Scroll Area which contains the widgets, set as the centralWidget self.widget = QWidget() # Widget that contains the collection of Vertical Box self.vbox = QVBoxLayout() - self.highlighted_label = None - self.all_labels = [] - - self.load_or_create_file_dict() - for i in range(len(self.file_dict)): label = QLabel(f'{self.file_dict["name"][i]}') # label.setFrameShape(QLabel.Panel) @@ -198,13 +204,15 @@ class Bbox_correct_UI(QMainWindow): def switch_to_new_file(self, new_file, new_name): self.readout_rois() - self.setWindowTitle(f'efishSignalTracker | {new_name}') + self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}') self.central_Layout.removeWidget(self.current_img) self.current_img = ImageWithBbox(new_file, parent=self) self.central_Layout.insertWidget(0, self.current_img, 4) def readout_rois(self): + if self.close_without_saving_rois: + return new_labels = [] for roi in self.current_img.ROIs: x0, y0 = roi.pos() @@ -225,29 +233,33 @@ class Bbox_correct_UI(QMainWindow): fd = QFileDialog() export_path = fd.getExistingDirectory(self, 'Select Directory') if export_path: - embed() - quit() - # ToDo: copy the validated files and delete from csv - # ToDo: copy entries to new/exten csv - # only files that are not in the training dataset so far - print(export_path) - else: - print('nope') - pass + export_idxs = list(self.file_dict['files'][self.file_dict['checked'] == 1].index) + keep_idxs = [] + for export_file_path, export_idx in zip(list(self.file_dict['files'][self.file_dict['checked'] == 1]), export_idxs): + export_image_path = pathlib.Path(export_file_path) + export_label_path = export_image_path.parent.parent / 'labels' / pathlib.Path(export_image_path.name).with_suffix('.txt') + + target_image_path = pathlib.Path(export_path) / 'images' / export_image_path.name + target_label_path = pathlib.Path(export_path) / 'labels' / export_label_path.name + if not target_image_path.exists(): + os.rename(export_image_path, target_image_path) + os.rename(export_label_path, target_label_path) + else: + keep_idxs.append(export_idx) + + + # self.file_dict.loc[self.file_dict['name'] == self.highlighted_label.text()].index.values[0] + drop_idxs = list(set(export_idxs) - set(keep_idxs)) + self.file_dict = self.file_dict.drop(drop_idxs) + self.save_file_dict() + self.close_without_saving_rois = True + self.close() - def import_data(self): - fd = QFileDialog() - import_path = fd.getExistingDirectory(self, 'Select Directory') - if import_path: - # ToDo: copy the UN validated files and add them to csv - # only files that are not in the current folder - print(import_path) else: print('nope') pass def open(self): - # ToDo: to be implemented pass def init_MenuBar(self): diff --git a/datasets.py b/datasets.py index ad5effd..128037a 100644 --- a/datasets.py +++ b/datasets.py @@ -26,6 +26,8 @@ class InferenceDataset(Dataset): def __init__(self, dir_path): self.dir_path = dir_path self.all_images = sorted(list(Path(self.dir_path).rglob(f'*.png'))) + self.file_names = np.array([Path(x).with_suffix('') for x in self.all_images]) + # self.images = np.array(sorted(os.listdir(DATA_DIR))) def __len__(self): return len(self.all_images) @@ -39,45 +41,6 @@ class InferenceDataset(Dataset): class CustomDataset(Dataset): - def __init__(self, dir_path, bbox_df): - self.dir_path = dir_path - self.bbox_df = bbox_df - - self.all_images = np.array(sorted(pd.unique(self.bbox_df['image'])), dtype=str) - self.image_paths = list(map(lambda x: Path(self.dir_path)/x, self.all_images)) - - def __getitem__(self, idx): - image_name = self.all_images[idx] - image_path = os.path.join(self.dir_path, image_name) - - img = Image.open(image_path) - img_tensor = F.to_tensor(img.convert('RGB')) - - Cbbox = self.bbox_df[self.bbox_df['image'] == image_name] - - labels = np.ones(len(Cbbox), dtype=int) - boxes = torch.as_tensor(Cbbox.loc[:, ['x0', 'y0', 'x1', 'y1']].values, dtype=torch.float32) - - area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) - # no crowd instances - iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64) - - target = {} - target["boxes"] = boxes - target["labels"] = torch.as_tensor(labels, dtype=torch.int64) - target["area"] = area - target["iscrowd"] = iscrowd - image_id = torch.tensor([idx]) - target["image_id"] = image_id - target["image_name"] = image_name #ToDo: implement this as 3rd return... - - return img_tensor, target - - def __len__(self): - return len(self.all_images) - - -class CustomDataset_v2(Dataset): def __init__(self, limited_idxs=None): self.images = np.array(sorted(os.listdir(DATA_DIR))) self.labels = np.array(sorted(os.listdir(LABEL_DIR))) @@ -99,7 +62,6 @@ class CustomDataset_v2(Dataset): boxes, labels, area, iscrowd = self.extract_bboxes(annotations) - target = {} target["boxes"] = boxes target["labels"] = torch.as_tensor(labels, dtype=torch.int64) @@ -147,27 +109,11 @@ def custom_train_test_split(): train_idxs = np.sort(data_idxs[int(0.2 * len(data_idxs)):]) test_idxs = np.sort(data_idxs[:int(0.2 * len(data_idxs))]) - train_data = CustomDataset_v2(limited_idxs=train_idxs) - test_data = CustomDataset_v2(limited_idxs=test_idxs) + train_data = CustomDataset(limited_idxs=train_idxs) + test_data = CustomDataset(limited_idxs=test_idxs) return train_data, test_data -def create_train_or_test_dataset(path, train=True): - if train == True: - pfx='train' - print('Generate train dataset !') - else: - print('Generate test dataset !') - pfx='test' - - csv_candidates = list(Path(path).rglob(f'*{pfx}*.csv')) - if len(csv_candidates) == 0: - print(f'no .csv files for *{pfx}* found in {Path(path)}') - quit() - else: - bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0) - return CustomDataset(path, bboxes) - def create_train_loader(train_dataset, num_workers=0): train_loader = DataLoader( @@ -191,17 +137,6 @@ def create_valid_loader(valid_dataset, num_workers=0): ) return valid_loader -def create_inference_loader(inference_dataset, num_workers=0): - inference_loader = DataLoader( - inference_dataset, - batch_size=BATCH_SIZE, - shuffle=False, - num_workers=num_workers, - collate_fn=collate_fn - ) - return inference_loader - - if __name__ == '__main__': train_data, test_data = custom_train_test_split() diff --git a/generate_dataset.py b/generate_dataset.py index 6f282b7..738865c 100644 --- a/generate_dataset.py +++ b/generate_dataset.py @@ -97,7 +97,7 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, return fig_title -def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,t0, t1, f0, f1): +def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,t0, t1, f0, f1, label_save_folder): times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1)) @@ -185,13 +185,16 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq all_height ]).T - np.savetxt(LABEL_DIR/ Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style) + np.savetxt(label_save_folder / Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style) return bbox_yolo_style def main(args): folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy')) - pic_save_folder = DATA_DIR if not args.inference else (Path('data') / Path(args.folder).name) + pic_save_folder = Path('data') / Path(args.folder).name / 'images' + label_save_folder = Path('data') / Path(args.folder).name / 'labels' + # embed() + # quit() if len(folders) == 0: print('no datasets containing fill_times.npy found') @@ -202,8 +205,10 @@ def main(args): else: print('generate inference dataset ... only image output') - if not (Path('data') / Path(args.folder).name).exists(): - (Path('data') / Path(args.folder).name).mkdir(parents=True, exist_ok=True) + if not pic_save_folder.exists(): + pic_save_folder.mkdir(parents=True, exist_ok=True) + if not label_save_folder.exists(): + label_save_folder.mkdir(parents=True, exist_ok=True) for enu, folder in enumerate(folders): print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}') @@ -253,52 +258,7 @@ def main(args): if not args.inference: bbox_yolo_style = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, - pic_save_str,t0, t1, f0, f1) - - ####################################################################### - # if False: - # if bbox_yolo_style.shape[0] >= 1: - # f_res, t_res = freq[1] - freq[0], times[1] - times[0] - # - # fig_title = ( - # f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace( - # ' ', '0') - # fig = plt.figure(figsize=IMG_SIZE, num=fig_title) - # gs = gridspec.GridSpec(1, 1, bottom=0.1, left=0.1, right=0.95, top=0.95) # - # ax = fig.add_subplot(gs[0, 0]) - # ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower', - # extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res)) - # # ax.invert_yaxis() - # # ax.axis(False) - # - # for i in range(len(bbox_df)): - # # Cbbox = np.array(bbox_df.loc[i, ['x0', 'y0', 'x1', 'y1']].values, dtype=np.float32) - # Cbbox = bbox_df.loc[i, ['t0', 'f0', 't1', 'f1']] - # ax.add_patch( - # Rectangle((float(Cbbox['t0']) / 3600, float(Cbbox['f0'])), - # float(Cbbox['t1']) / 3600 - float(Cbbox['t0']) / 3600, - # float(Cbbox['f1']) - float(Cbbox['f0']), - # fill=False, color="white", linestyle='-', linewidth=2, zorder=10) - # ) - # - # # print(bbox_yolo_style.T) - # - # for bbox in bbox_yolo_style: - # x0 = bbox[1] - bbox[3]/2 # x_center - width/2 - # y0 = 1 - (bbox[2] + bbox[4]/2) # x_center - width/2 - # w = bbox[3] - # h = bbox[4] - # ax.add_patch( - # Rectangle((x0, y0), w, h, - # fill=False, color="k", linestyle='--', linewidth=2, zorder=10, - # transform=ax.transAxes) - # ) - # plt.show() - ####################################################################### - - # if not args.inference: - # print('save bboxes') - # bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',') + pic_save_str,t0, t1, f0, f1, label_save_folder) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') @@ -306,4 +266,6 @@ if __name__ == '__main__': parser.add_argument('-i', "--inference", action="store_true") args = parser.parse_args() + # ToDo: put "images" in image folder + main(args) \ No newline at end of file diff --git a/inference.py b/inference.py index 326398c..d3bb410 100644 --- a/inference.py +++ b/inference.py @@ -54,11 +54,35 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8): with torch.inference_mode(): outputs = model(images) + # ToDo: save outputs in label folder ! for image, img_name, output in zip(images, img_names, outputs): + # x0, y0, x1, y1 + + yolo_labels = [] + # for (x0, y0, x1, y1) in output['boxes'].cpu().numpy(): + for Cbbox, score in zip(output['boxes'].cpu().numpy(), output['scores'].cpu().numpy()): + if score < detection_th: + continue + rel_x0 = Cbbox[0] / image.shape[-2] + rel_y0 = Cbbox[1] / image.shape[-2] + rel_x1 = Cbbox[2] / image.shape[-2] + rel_y1 = Cbbox[3] / image.shape[-2] + + rel_x_center = rel_x1 - (rel_x1 - rel_x0) / 2 + rel_y_center = rel_y1 - (rel_y1 - rel_y0) / 2 + rel_width = rel_x1 - rel_x0 + rel_height = rel_y1 - rel_y0 + + yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height]) + + label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt') + np.savetxt(label_path, yolo_labels) + plot_inference(image, img_name, output, detection_th, dataset_name) + def main(args): model = create_model(num_classes=NUM_CLASSES) checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE) diff --git a/train.py b/train.py index 38b93a8..6ceb49d 100644 --- a/train.py +++ b/train.py @@ -121,11 +121,6 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold): # plt.show() if __name__ == '__main__': - # train_data = create_train_or_test_dataset(DATA_DIR) - # test_data = create_train_or_test_dataset(DATA_DIR, train=False) - # - # train_loader = create_train_loader(train_data) - # test_loader = create_valid_loader(test_data) train_data, test_data = custom_train_test_split()