Model and datasets adapted to the new YOLO format. Reinforcement learning is now implemented. correct_bboxes.py is now capable of exporting data to the train file.
This commit is contained in:
parent 3ed6f1102c
commit 324c690841
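For context (not part of the diff): a YOLO-format label file holds one line per box — a class id followed by x-center, y-center, width, and height, all normalized to the image size. A minimal sketch of the convention this commit adopts, assuming absolute pixel corners as input (function name hypothetical):

def to_yolo_line(x0, y0, x1, y1, img_w, img_h, class_id=1):
    # Convert absolute pixel corners to one normalized YOLO label line.
    x_center = (x0 + x1) / 2 / img_w
    y_center = (y0 + y1) / 2 / img_h
    width = (x1 - x0) / img_w
    height = (y1 - y0) / img_h
    return f'{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}'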
correct_bboxes.py

@@ -26,8 +26,8 @@ DELTA_TIME = 60*10
 TIME_OVERLAP = 60*1
 
 # output parameters
-DATA_DIR = 'data/images'
-LABEL_DIR = 'data/labels'
+DATA_DIR = 'data/rise_training/images'
+LABEL_DIR = 'data/rise_training/labels'
 OUTDIR = 'model_outputs'
 INFERENCE_OUTDIR = 'inference_outputs'
 for required_folders in [DATA_DIR, OUTDIR, INFERENCE_OUTDIR, LABEL_DIR]:
@@ -87,11 +87,13 @@ class Bbox_correct_UI(QMainWindow):
         self.data_path = data_path
         self.files = sorted(list(pathlib.Path(self.data_path).absolute().rglob('*images/*.png')))
 
+        self.close_without_saving_rois = False
+
         rec = QApplication.desktop().screenGeometry()
         height = rec.height()
         width = rec.width()
         self.resize(int(.8 * width), int(.8 * height))
         self.setWindowTitle('efishSignalTracker')  # set window title
 
         # widget and layout
         self.central_widget = QWidget(self)
@@ -100,19 +102,23 @@ class Bbox_correct_UI(QMainWindow):
         self.central_widget.setLayout(self.central_Layout)
         self.setCentralWidget(self.central_widget)
+
+        ###########
+        self.highlighted_label = None
+        self.all_labels = []
+        self.load_or_create_file_dict()
+
+        new_name, new_file, new_checked = self.file_dict.iloc[0].values
+        self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}')
+
         # image widget
-        self.current_img = ImageWithBbox(self.files[0], parent=self)
+        self.current_img = ImageWithBbox(new_file, parent=self)
         self.central_Layout.addWidget(self.current_img, 4)
 
         # image select widget
         self.scroll = QScrollArea()  # Scroll Area which contains the widgets, set as the centralWidget
         self.widget = QWidget()  # Widget that contains the collection of Vertical Box
         self.vbox = QVBoxLayout()
 
-        self.highlighted_label = None
-        self.all_labels = []
-
-        self.load_or_create_file_dict()
-
         for i in range(len(self.file_dict)):
             label = QLabel(f'{self.file_dict["name"][i]}')
             # label.setFrameShape(QLabel.Panel)
@@ -198,13 +204,15 @@ class Bbox_correct_UI(QMainWindow):
     def switch_to_new_file(self, new_file, new_name):
         self.readout_rois()
 
-        self.setWindowTitle(f'efishSignalTracker | {new_name}')
+        self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}')
 
         self.central_Layout.removeWidget(self.current_img)
         self.current_img = ImageWithBbox(new_file, parent=self)
         self.central_Layout.insertWidget(0, self.current_img, 4)
 
     def readout_rois(self):
+        if self.close_without_saving_rois:
+            return
         new_labels = []
         for roi in self.current_img.ROIs:
             x0, y0 = roi.pos()
@@ -225,29 +233,33 @@ class Bbox_correct_UI(QMainWindow):
         fd = QFileDialog()
         export_path = fd.getExistingDirectory(self, 'Select Directory')
         if export_path:
-            embed()
-            quit()
             # ToDo: copy the validated files and delete from csv
             # ToDo: copy entries to new/exten csv
             # only files that are not in the training dataset so far
             print(export_path)
+            export_idxs = list(self.file_dict['files'][self.file_dict['checked'] == 1].index)
+            keep_idxs = []
+            for export_file_path, export_idx in zip(list(self.file_dict['files'][self.file_dict['checked'] == 1]), export_idxs):
+                export_image_path = pathlib.Path(export_file_path)
+                export_label_path = export_image_path.parent.parent / 'labels' / pathlib.Path(export_image_path.name).with_suffix('.txt')
+
+                target_image_path = pathlib.Path(export_path) / 'images' / export_image_path.name
+                target_label_path = pathlib.Path(export_path) / 'labels' / export_label_path.name
+                if not target_image_path.exists():
+                    os.rename(export_image_path, target_image_path)
+                    os.rename(export_label_path, target_label_path)
+                else:
+                    print('nope')
+                    pass
+                    keep_idxs.append(export_idx)
+
+            # self.file_dict.loc[self.file_dict['name'] == self.highlighted_label.text()].index.values[0]
+            drop_idxs = list(set(export_idxs) - set(keep_idxs))
+            self.file_dict = self.file_dict.drop(drop_idxs)
+            self.save_file_dict()
+            self.close_without_saving_rois = True
+            self.close()
 
     def import_data(self):
         fd = QFileDialog()
         import_path = fd.getExistingDirectory(self, 'Select Directory')
         if import_path:
             # ToDo: copy the UN validated files and add them to csv
             # only files that are not in the current folder
             print(import_path)
         else:
             print('nope')
             pass
 
     def open(self):
         # ToDo: to be implemented
         pass
 
     def init_MenuBar(self):
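Note on the export step above (reviewer sketch, not part of the commit): os.rename raises OSError when source and target sit on different filesystems, so choosing an export directory on another drive would crash the export. shutil.move handles that case; move_pair below is a hypothetical helper mirroring the loop body:

import shutil
from pathlib import Path

def move_pair(image_path: Path, label_path: Path, target_dir: Path) -> bool:
    # Move an image/label pair into target_dir/images and target_dir/labels.
    # Returns False if the target image already exists (entry should be kept).
    target_image = target_dir / 'images' / image_path.name
    target_label = target_dir / 'labels' / label_path.name
    if target_image.exists():
        return False
    shutil.move(str(image_path), str(target_image))  # works across filesystems
    shutil.move(str(label_path), str(target_label))
    return True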
73  datasets.py
@@ -26,6 +26,8 @@ class InferenceDataset(Dataset):
     def __init__(self, dir_path):
         self.dir_path = dir_path
         self.all_images = sorted(list(Path(self.dir_path).rglob(f'*.png')))
+        self.file_names = np.array([Path(x).with_suffix('') for x in self.all_images])
+        # self.images = np.array(sorted(os.listdir(DATA_DIR)))
     def __len__(self):
         return len(self.all_images)
 
@@ -39,45 +41,6 @@ class InferenceDataset(Dataset):
 
 
-class CustomDataset(Dataset):
-    def __init__(self, dir_path, bbox_df):
-        self.dir_path = dir_path
-        self.bbox_df = bbox_df
-
-        self.all_images = np.array(sorted(pd.unique(self.bbox_df['image'])), dtype=str)
-        self.image_paths = list(map(lambda x: Path(self.dir_path)/x, self.all_images))
-
-    def __getitem__(self, idx):
-        image_name = self.all_images[idx]
-        image_path = os.path.join(self.dir_path, image_name)
-
-        img = Image.open(image_path)
-        img_tensor = F.to_tensor(img.convert('RGB'))
-
-        Cbbox = self.bbox_df[self.bbox_df['image'] == image_name]
-
-        labels = np.ones(len(Cbbox), dtype=int)
-        boxes = torch.as_tensor(Cbbox.loc[:, ['x0', 'y0', 'x1', 'y1']].values, dtype=torch.float32)
-
-        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
-        # no crowd instances
-        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
-
-        target = {}
-        target["boxes"] = boxes
-        target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
-        target["area"] = area
-        target["iscrowd"] = iscrowd
-        image_id = torch.tensor([idx])
-        target["image_id"] = image_id
-        target["image_name"] = image_name  # ToDo: implement this as 3rd return...
-
-        return img_tensor, target
-
-    def __len__(self):
-        return len(self.all_images)
-
-
 class CustomDataset_v2(Dataset):
     def __init__(self, limited_idxs=None):
         self.images = np.array(sorted(os.listdir(DATA_DIR)))
         self.labels = np.array(sorted(os.listdir(LABEL_DIR)))
@@ -99,7 +62,6 @@ class CustomDataset_v2(Dataset):
 
         boxes, labels, area, iscrowd = self.extract_bboxes(annotations)
 
-
         target = {}
         target["boxes"] = boxes
         target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
@@ -147,27 +109,11 @@ def custom_train_test_split():
     train_idxs = np.sort(data_idxs[int(0.2 * len(data_idxs)):])
     test_idxs = np.sort(data_idxs[:int(0.2 * len(data_idxs))])
 
-    train_data = CustomDataset_v2(limited_idxs=train_idxs)
-    test_data = CustomDataset_v2(limited_idxs=test_idxs)
+    train_data = CustomDataset(limited_idxs=train_idxs)
+    test_data = CustomDataset(limited_idxs=test_idxs)
 
     return train_data, test_data
 
-def create_train_or_test_dataset(path, train=True):
-    if train == True:
-        pfx='train'
-        print('Generate train dataset !')
-    else:
-        print('Generate test dataset !')
-        pfx='test'
-
-    csv_candidates = list(Path(path).rglob(f'*{pfx}*.csv'))
-    if len(csv_candidates) == 0:
-        print(f'no .csv files for *{pfx}* found in {Path(path)}')
-        quit()
-    else:
-        bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0)
-        return CustomDataset(path, bboxes)
-
-
 def create_train_loader(train_dataset, num_workers=0):
     train_loader = DataLoader(
@@ -191,17 +137,6 @@ def create_valid_loader(valid_dataset, num_workers=0):
     )
     return valid_loader
 
-def create_inference_loader(inference_dataset, num_workers=0):
-    inference_loader = DataLoader(
-        inference_dataset,
-        batch_size=BATCH_SIZE,
-        shuffle=False,
-        num_workers=num_workers,
-        collate_fn=collate_fn
-    )
-    return inference_loader
-
-
 if __name__ == '__main__':
     train_data, test_data = custom_train_test_split()
 
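Aside: collate_fn is referenced here but defined elsewhere in the repo. For detection loaders like these it is presumably the usual identity-style collate from the torchvision detection examples, since each image carries a different number of boxes and the batch cannot be stacked into one tensor:

def collate_fn(batch):
    # Keep (image, target) pairs as tuples of lists instead of stacked tensors,
    # because the number of bounding boxes varies per image.
    return tuple(zip(*batch))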
@@ -97,7 +97,7 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1,
     return fig_title
 
 
-def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,t0, t1, f0, f1):
+def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,t0, t1, f0, f1, label_save_folder):
 
     times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
 
@@ -185,13 +185,16 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
                                all_height
                                ]).T
 
-    np.savetxt(LABEL_DIR/ Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
+    np.savetxt(label_save_folder / Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
     return bbox_yolo_style
 
 
 def main(args):
     folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
-    pic_save_folder = DATA_DIR if not args.inference else (Path('data') / Path(args.folder).name)
+    pic_save_folder = Path('data') / Path(args.folder).name / 'images'
+    label_save_folder = Path('data') / Path(args.folder).name / 'labels'
+    # embed()
+    # quit()
 
     if len(folders) == 0:
         print('no datasets containing fill_times.npy found')
@@ -202,8 +205,10 @@ def main(args):
 
     else:
         print('generate inference dataset ... only image output')
-        if not (Path('data') / Path(args.folder).name).exists():
-            (Path('data') / Path(args.folder).name).mkdir(parents=True, exist_ok=True)
+        if not pic_save_folder.exists():
+            pic_save_folder.mkdir(parents=True, exist_ok=True)
+        if not label_save_folder.exists():
+            label_save_folder.mkdir(parents=True, exist_ok=True)
 
     for enu, folder in enumerate(folders):
         print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
@@ -253,52 +258,7 @@ def main(args):
         if not args.inference:
             bbox_yolo_style = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
                                                fish_baseline_freq_time, fish_baseline_freq,
-                                               pic_save_str,t0, t1, f0, f1)
-
-        #######################################################################
-        # if False:
-        #     if bbox_yolo_style.shape[0] >= 1:
-        #         f_res, t_res = freq[1] - freq[0], times[1] - times[0]
-        #
-        #         fig_title = (
-        #             f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(
-        #             ' ', '0')
-        #         fig = plt.figure(figsize=IMG_SIZE, num=fig_title)
-        #         gs = gridspec.GridSpec(1, 1, bottom=0.1, left=0.1, right=0.95, top=0.95)  #
-        #         ax = fig.add_subplot(gs[0, 0])
-        #         ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
-        #                   extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res))
-        #         # ax.invert_yaxis()
-        #         # ax.axis(False)
-        #
-        #         for i in range(len(bbox_df)):
-        #             # Cbbox = np.array(bbox_df.loc[i, ['x0', 'y0', 'x1', 'y1']].values, dtype=np.float32)
-        #             Cbbox = bbox_df.loc[i, ['t0', 'f0', 't1', 'f1']]
-        #             ax.add_patch(
-        #                 Rectangle((float(Cbbox['t0']) / 3600, float(Cbbox['f0'])),
-        #                           float(Cbbox['t1']) / 3600 - float(Cbbox['t0']) / 3600,
-        #                           float(Cbbox['f1']) - float(Cbbox['f0']),
-        #                           fill=False, color="white", linestyle='-', linewidth=2, zorder=10)
-        #             )
-        #
-        #         # print(bbox_yolo_style.T)
-        #
-        #         for bbox in bbox_yolo_style:
-        #             x0 = bbox[1] - bbox[3]/2  # x_center - width/2
-        #             y0 = 1 - (bbox[2] + bbox[4]/2)  # x_center - width/2
-        #             w = bbox[3]
-        #             h = bbox[4]
-        #             ax.add_patch(
-        #                 Rectangle((x0, y0), w, h,
-        #                           fill=False, color="k", linestyle='--', linewidth=2, zorder=10,
-        #                           transform=ax.transAxes)
-        #             )
-        #         plt.show()
-        #######################################################################
-
-        # if not args.inference:
-        #     print('save bboxes')
-        #     bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',')
+                                               pic_save_str,t0, t1, f0, f1, label_save_folder)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
@@ -306,4 +266,6 @@ if __name__ == '__main__':
     parser.add_argument('-i', "--inference", action="store_true")
     args = parser.parse_args()
 
+    # ToDo: put "images" in image folder
+
     main(args)
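One detail worth flagging in the label writing above (observation, not a change in this commit): np.savetxt applies a float format to every column by default, so the class-id column comes out as 1.000000000000000000e+00, while most YOLO tooling expects an integer class followed by four floats. A per-column fmt keeps the files clean (example row is hypothetical):

import numpy as np

# Assumed column order: class, x_center, y_center, width, height.
bbox_yolo_style = np.array([[1, 0.52, 0.48, 0.20, 0.10]])
np.savetxt('example_label.txt', bbox_yolo_style,
           fmt=['%d', '%.6f', '%.6f', '%.6f', '%.6f'])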
24  inference.py
@@ -54,11 +54,35 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
 
         with torch.inference_mode():
             outputs = model(images)
         # ToDo: save outputs in label folder !
 
         for image, img_name, output in zip(images, img_names, outputs):
             # x0, y0, x1, y1
 
+            yolo_labels = []
+            # for (x0, y0, x1, y1) in output['boxes'].cpu().numpy():
+            for Cbbox, score in zip(output['boxes'].cpu().numpy(), output['scores'].cpu().numpy()):
+                if score < detection_th:
+                    continue
+                rel_x0 = Cbbox[0] / image.shape[-2]
+                rel_y0 = Cbbox[1] / image.shape[-2]
+                rel_x1 = Cbbox[2] / image.shape[-2]
+                rel_y1 = Cbbox[3] / image.shape[-2]
+
+                rel_x_center = rel_x1 - (rel_x1 - rel_x0) / 2
+                rel_y_center = rel_y1 - (rel_y1 - rel_y0) / 2
+                rel_width = rel_x1 - rel_x0
+                rel_height = rel_y1 - rel_y0
+
+                yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height])
+
+            label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt')
+            np.savetxt(label_path, yolo_labels)
+
             plot_inference(image, img_name, output, detection_th, dataset_name)
 
 
 def main(args):
     model = create_model(num_classes=NUM_CLASSES)
     checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
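A caveat on the conversion above, relevant only if the spectrogram images are not square: all four coordinates are divided by image.shape[-2], which is the height of a CHW tensor, so the x-coordinates should instead be normalized by the width, image.shape[-1]. A sketch of the width/height-aware normalization (helper name hypothetical):

import torch

def to_relative(cbbox, image: torch.Tensor):
    # cbbox holds absolute pixel corners (x0, y0, x1, y1);
    # x is normalized by image width, y by image height.
    img_h, img_w = image.shape[-2], image.shape[-1]
    x0, y0, x1, y1 = cbbox
    return x0 / img_w, y0 / img_h, x1 / img_w, y1 / img_h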
5  train.py
@@ -121,11 +121,6 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold):
     # plt.show()
 
 if __name__ == '__main__':
-    # train_data = create_train_or_test_dataset(DATA_DIR)
-    # test_data = create_train_or_test_dataset(DATA_DIR, train=False)
-    #
-    # train_loader = create_train_loader(train_data)
-    # test_loader = create_valid_loader(test_data)
 
     train_data, test_data = custom_train_test_split()
 
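With create_train_or_test_dataset removed from datasets.py, custom_train_test_split is the remaining entry point; a minimal usage sketch, assuming the loader helpers that datasets.py still provides:

from datasets import custom_train_test_split, create_train_loader, create_valid_loader

# 80/20 split of the YOLO-style dataset, wrapped in detection loaders.
train_data, test_data = custom_train_test_split()
train_loader = create_train_loader(train_data)
test_loader = create_valid_loader(test_data)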