efishSignalDetector/datasets.py

105 lines
3.3 KiB
Python

import os
import glob
import torch
import torchvision
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import Rectangle
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
from confic import (CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE)
from custom_utils import collate_fn
from IPython import embed
class CustomDataset(Dataset):
def __init__(self, dir_path, use_idxs = None):
self.dir_path = dir_path
self.image_paths = glob.glob(f'{self.dir_path}/*.png')
self.all_images = [img_path.split(os.path.sep)[-1] for img_path in self.image_paths]
self.all_images = np.array(sorted(self.all_images), dtype=str)
if hasattr(use_idxs, '__len__'):
self.all_images = self.all_images[use_idxs]
self.bbox_df = pd.read_csv(os.path.join(dir_path, 'bbox_dataset.csv'), sep=',', index_col=0)
def __getitem__(self, idx):
image_name = self.all_images[idx]
image_path = os.path.join(self.dir_path, image_name)
img = Image.open(image_path)
img_tensor = F.to_tensor(img.convert('RGB'))
Cbbox = self.bbox_df[self.bbox_df['image'] == image_name]
labels = np.ones(len(Cbbox), dtype=int)
boxes = Cbbox.loc[:, ['x0', 'x1', 'y0', 'y1']].values
target = {}
target["boxes"] = boxes
target["labels"] = labels
return img_tensor, target
def __len__(self):
return len(self.all_images)
def create_train_test_dataset(path, test_size=0.2):
files = glob.glob(os.path.join(path, '*.png'))
train_test_idx = np.arange(len(files), dtype=int)
np.random.shuffle(train_test_idx)
train_idx = train_test_idx[int(test_size*len(train_test_idx)):]
test_idx = train_test_idx[:int(test_size*len(train_test_idx))]
train_data = CustomDataset(path, use_idxs=train_idx)
test_data = CustomDataset(path, use_idxs=test_idx)
return train_data, test_data
def create_train_loader(train_dataset, num_workers=0):
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=num_workers,
collate_fn=collate_fn
)
return train_loader
def create_valid_loader(valid_dataset, num_workers=0):
valid_loader = DataLoader(
valid_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn
)
return valid_loader
if __name__ == '__main__':
train_data, test_data = create_train_test_dataset(TRAIN_DIR)
train_loader = create_train_loader(train_data)
test_loader = create_valid_loader(test_data)
for samples, targets in train_loader:
for s, t in zip(samples, targets):
fig, ax = plt.subplots()
ax.imshow(s.permute(1, 2, 0), aspect='auto')
for (x0, x1, y0, y1), l in zip(t['boxes'], t['labels']):
print(x0, x1, y0, y1, l)
ax.add_patch(
Rectangle((x0, y0),
(x1 - x0),
(y1 - y0),
fill=False, color="white", linewidth=2, zorder=10)
)
plt.show()