diff --git a/datasets.py b/datasets.py index 54e992c..df0ee01 100644 --- a/datasets.py +++ b/datasets.py @@ -20,14 +20,21 @@ from custom_utils import collate_fn from IPython import embed class CustomDataset(Dataset): - def __init__(self, dir_path, use_idxs = None): + def __init__(self, dir_path, bbox_df): self.dir_path = dir_path - self.image_paths = glob.glob(f'{self.dir_path}/*.png') - self.all_images = [img_path.split(os.path.sep)[-1] for img_path in self.image_paths] - self.all_images = np.array(sorted(self.all_images), dtype=str) - if hasattr(use_idxs, '__len__'): - self.all_images = self.all_images[use_idxs] - self.bbox_df = pd.read_csv(os.path.join(dir_path, 'bbox_dataset.csv'), sep=',', index_col=0) + self.bbox_df = bbox_df + + self.all_images = np.array(sorted(self.bbox_df['image']), dtype=str) + self.image_paths = list(map(lambda x: Path(self.dir_path)/x, self.all_images)) + # embed() + # quit() + + # self.image_paths = glob.glob(f'{self.dir_path}/*.png') + # self.all_images = [img_path.split(os.path.sep)[-1] for img_path in self.image_paths] + # self.all_images = np.array(sorted(self.all_images), dtype=str) + # if hasattr(use_idxs, '__len__'): + # self.all_images = self.all_images[use_idxs] + # self.bbox_df = pd.read_csv(os.path.join(dir_path, 'bbox_dataset.csv'), sep=',', index_col=0) def __getitem__(self, idx): image_name = self.all_images[idx] @@ -66,11 +73,27 @@ def create_train_test_dataset(path, test_size=0.2): train_idx = train_test_idx[int(test_size*len(train_test_idx)):] test_idx = train_test_idx[:int(test_size*len(train_test_idx))] - train_data = CustomDataset(path, use_idxs=train_idx) - test_data = CustomDataset(path, use_idxs=test_idx) + train_data = CustomDataset(path) + test_data = CustomDataset(path) return train_data, test_data +def create_train_or_test_dataset(path, train=True): + if train == True: + pfx='train' + print('Generate train dataset !') + else: + print('Generate test dataset !') + pfx='test' + + csv_candidates = list(Path(path).rglob(f'*{pfx}*.csv')) + if len(csv_candidates) == 0: + print(f'no .csv files for *{pfx}* found in {Path(path)}') + quit() + else: + bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0) + return CustomDataset(path, bboxes) + def create_train_loader(train_dataset, num_workers=0): train_loader = DataLoader( train_dataset, @@ -93,12 +116,14 @@ def create_valid_loader(valid_dataset, num_workers=0): if __name__ == '__main__': - train_data, test_data = create_train_test_dataset(TRAIN_DIR) + # train_data, test_data = create_train_test_dataset(TRAIN_DIR) + train_data = create_train_or_test_dataset(TRAIN_DIR) + test_data = create_train_or_test_dataset(TRAIN_DIR, train=False) train_loader = create_train_loader(train_data) test_loader = create_valid_loader(test_data) - for samples, targets in train_loader: + for samples, targets in test_loader: for s, t in zip(samples, targets): fig, ax = plt.subplots() ax.imshow(s.permute(1, 2, 0), aspect='auto') diff --git a/train.py b/train.py index 33ea54e..b7bc197 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,7 @@ from model import create_model from tqdm.auto import tqdm -from datasets import create_train_test_dataset, create_train_loader, create_valid_loader +from datasets import create_train_loader, create_valid_loader, create_train_or_test_dataset from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot import torch @@ -62,7 +62,9 @@ def validate(test_loader, model, val_loss): return val_loss if __name__ == '__main__': - train_data, test_data = create_train_test_dataset(TRAIN_DIR) + train_data = create_train_or_test_dataset(TRAIN_DIR) + test_data = create_train_or_test_dataset(TRAIN_DIR, train=False) + train_loader = create_train_loader(train_data) test_loader = create_train_loader(test_data)