Train/test split implemented using separate CSV files
This commit is contained in:
parent
ecf110e051
commit
e1a97ac493
47
datasets.py
47
datasets.py
@ -20,14 +20,21 @@ from custom_utils import collate_fn
|
|||||||
from IPython import embed
|
from IPython import embed
|
||||||
|
|
||||||
class CustomDataset(Dataset):
|
class CustomDataset(Dataset):
|
||||||
def __init__(self, dir_path, use_idxs = None):
|
def __init__(self, dir_path, bbox_df):
|
||||||
self.dir_path = dir_path
|
self.dir_path = dir_path
|
||||||
self.image_paths = glob.glob(f'{self.dir_path}/*.png')
|
self.bbox_df = bbox_df
|
||||||
self.all_images = [img_path.split(os.path.sep)[-1] for img_path in self.image_paths]
|
|
||||||
self.all_images = np.array(sorted(self.all_images), dtype=str)
|
self.all_images = np.array(sorted(self.bbox_df['image']), dtype=str)
|
||||||
if hasattr(use_idxs, '__len__'):
|
self.image_paths = list(map(lambda x: Path(self.dir_path)/x, self.all_images))
|
||||||
self.all_images = self.all_images[use_idxs]
|
# embed()
|
||||||
self.bbox_df = pd.read_csv(os.path.join(dir_path, 'bbox_dataset.csv'), sep=',', index_col=0)
|
# quit()
|
||||||
|
|
||||||
|
# self.image_paths = glob.glob(f'{self.dir_path}/*.png')
|
||||||
|
# self.all_images = [img_path.split(os.path.sep)[-1] for img_path in self.image_paths]
|
||||||
|
# self.all_images = np.array(sorted(self.all_images), dtype=str)
|
||||||
|
# if hasattr(use_idxs, '__len__'):
|
||||||
|
# self.all_images = self.all_images[use_idxs]
|
||||||
|
# self.bbox_df = pd.read_csv(os.path.join(dir_path, 'bbox_dataset.csv'), sep=',', index_col=0)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
image_name = self.all_images[idx]
|
image_name = self.all_images[idx]
|
||||||
@ -66,11 +73,27 @@ def create_train_test_dataset(path, test_size=0.2):
|
|||||||
train_idx = train_test_idx[int(test_size*len(train_test_idx)):]
|
train_idx = train_test_idx[int(test_size*len(train_test_idx)):]
|
||||||
test_idx = train_test_idx[:int(test_size*len(train_test_idx))]
|
test_idx = train_test_idx[:int(test_size*len(train_test_idx))]
|
||||||
|
|
||||||
train_data = CustomDataset(path, use_idxs=train_idx)
|
train_data = CustomDataset(path)
|
||||||
test_data = CustomDataset(path, use_idxs=test_idx)
|
test_data = CustomDataset(path)
|
||||||
|
|
||||||
return train_data, test_data
|
return train_data, test_data
|
||||||
|
|
||||||
|
def create_train_or_test_dataset(path, train=True):
|
||||||
|
if train == True:
|
||||||
|
pfx='train'
|
||||||
|
print('Generate train dataset !')
|
||||||
|
else:
|
||||||
|
print('Generate test dataset !')
|
||||||
|
pfx='test'
|
||||||
|
|
||||||
|
csv_candidates = list(Path(path).rglob(f'*{pfx}*.csv'))
|
||||||
|
if len(csv_candidates) == 0:
|
||||||
|
print(f'no .csv files for *{pfx}* found in {Path(path)}')
|
||||||
|
quit()
|
||||||
|
else:
|
||||||
|
bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0)
|
||||||
|
return CustomDataset(path, bboxes)
|
||||||
|
|
||||||
def create_train_loader(train_dataset, num_workers=0):
|
def create_train_loader(train_dataset, num_workers=0):
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
@ -93,12 +116,14 @@ def create_valid_loader(valid_dataset, num_workers=0):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
train_data, test_data = create_train_test_dataset(TRAIN_DIR)
|
# train_data, test_data = create_train_test_dataset(TRAIN_DIR)
|
||||||
|
train_data = create_train_or_test_dataset(TRAIN_DIR)
|
||||||
|
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
||||||
|
|
||||||
train_loader = create_train_loader(train_data)
|
train_loader = create_train_loader(train_data)
|
||||||
test_loader = create_valid_loader(test_data)
|
test_loader = create_valid_loader(test_data)
|
||||||
|
|
||||||
for samples, targets in train_loader:
|
for samples, targets in test_loader:
|
||||||
for s, t in zip(samples, targets):
|
for s, t in zip(samples, targets):
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
ax.imshow(s.permute(1, 2, 0), aspect='auto')
|
ax.imshow(s.permute(1, 2, 0), aspect='auto')
|
||||||
|
6
train.py
6
train.py
@ -3,7 +3,7 @@ from model import create_model
|
|||||||
|
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from datasets import create_train_test_dataset, create_train_loader, create_valid_loader
|
from datasets import create_train_loader, create_valid_loader, create_train_or_test_dataset
|
||||||
from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot
|
from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -62,7 +62,9 @@ def validate(test_loader, model, val_loss):
|
|||||||
return val_loss
|
return val_loss
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train_data, test_data = create_train_test_dataset(TRAIN_DIR)
|
train_data = create_train_or_test_dataset(TRAIN_DIR)
|
||||||
|
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
||||||
|
|
||||||
train_loader = create_train_loader(train_data)
|
train_loader = create_train_loader(train_data)
|
||||||
test_loader = create_train_loader(test_data)
|
test_loader = create_train_loader(test_data)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user