inference rework
This commit is contained in:
parent
43f0e0a27d
commit
3690907f93
33
datasets.py
33
datasets.py
@ -19,6 +19,25 @@ from custom_utils import collate_fn
|
||||
|
||||
from IPython import embed
|
||||
|
||||
|
||||
class InferenceDataset(Dataset):
|
||||
def __init__(self, dir_path):
|
||||
self.dir_path = dir_path
|
||||
self.all_images = sorted(list(Path(self.dir_path).rglob(f'*.png')))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.all_images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image_name = self.all_images[idx]
|
||||
image_path = os.path.join(self.dir_path, image_name)
|
||||
|
||||
img = Image.open(image_path)
|
||||
img_tensor = F.to_tensor(img.convert('RGB'))
|
||||
|
||||
return img_tensor
|
||||
|
||||
|
||||
class CustomDataset(Dataset):
|
||||
def __init__(self, dir_path, bbox_df):
|
||||
self.dir_path = dir_path
|
||||
@ -57,6 +76,7 @@ class CustomDataset(Dataset):
|
||||
def __len__(self):
|
||||
return len(self.all_images)
|
||||
|
||||
|
||||
def create_train_or_test_dataset(path, train=True):
|
||||
if train == True:
|
||||
pfx='train'
|
||||
@ -73,6 +93,7 @@ def create_train_or_test_dataset(path, train=True):
|
||||
bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0)
|
||||
return CustomDataset(path, bboxes)
|
||||
|
||||
|
||||
def create_train_loader(train_dataset, num_workers=0):
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
@ -82,6 +103,8 @@ def create_train_loader(train_dataset, num_workers=0):
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
return train_loader
|
||||
|
||||
|
||||
def create_valid_loader(valid_dataset, num_workers=0):
|
||||
valid_loader = DataLoader(
|
||||
valid_dataset,
|
||||
@ -92,6 +115,16 @@ def create_valid_loader(valid_dataset, num_workers=0):
|
||||
)
|
||||
return valid_loader
|
||||
|
||||
def create_inference_loader(inference_dataset, num_workers=0)
|
||||
inference_loader = DataLoader(
|
||||
inference_dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
return inference_loader
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
30
inference.py
30
inference.py
@ -4,10 +4,11 @@ import torchvision.transforms.functional as F
|
||||
import glob
|
||||
import os
|
||||
from PIL import Image
|
||||
import argparse
|
||||
|
||||
from model import create_model
|
||||
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, DATA_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
|
||||
from datasets import create_train_or_test_dataset, create_valid_loader
|
||||
from datasets import InferenceDataset, create_inference_loader
|
||||
|
||||
from IPython import embed
|
||||
from pathlib import Path
|
||||
@ -41,11 +42,11 @@ def plot_inference(img_tensor, img_name, output, detection_threshold):
|
||||
plt.close()
|
||||
# plt.show()
|
||||
|
||||
def infere_model(test_loader, model, detection_th=0.8):
|
||||
def infere_model(inference_loader, model, detection_th=0.8):
|
||||
|
||||
print('Validation')
|
||||
print('Inference')
|
||||
|
||||
prog_bar = tqdm(test_loader, total=len(test_loader))
|
||||
prog_bar = tqdm(inference_loader, total=len(inference_loader))
|
||||
for samples, targets in prog_bar:
|
||||
images = list(image.to(DEVICE) for image in samples)
|
||||
|
||||
@ -56,19 +57,21 @@ def infere_model(test_loader, model, detection_th=0.8):
|
||||
outputs = model(images)
|
||||
|
||||
for image, img_name, output, target in zip(images, img_names, outputs, targets):
|
||||
plot_inference(image, img_name, output, target, detection_th)
|
||||
plot_inference(image, img_name, output, detection_th)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def main(args):
|
||||
model = create_model(num_classes=NUM_CLASSES)
|
||||
checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.to(DEVICE).eval()
|
||||
|
||||
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
|
||||
test_loader = create_valid_loader(test_data)
|
||||
inference_data = InferenceDataset(args.folder)
|
||||
inference_loader = create_inference_loader(inference_data)
|
||||
|
||||
infere_model(test_loader, model)
|
||||
embed()
|
||||
quit()
|
||||
infere_model(inference_loader, model)
|
||||
|
||||
# detection_threshold = 0.8
|
||||
# frame_count = 0
|
||||
@ -86,4 +89,11 @@ if __name__ == '__main__':
|
||||
#
|
||||
# print(len(outputs[0]['boxes']))
|
||||
|
||||
# show_sample(img_tensor, outputs, detection_threshold)
|
||||
# show_sample(img_tensor, outputs, detection_threshold)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
|
||||
parser.add_argument('folder', type=str, help='folder to infer picutes', default='')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
Loading…
Reference in New Issue
Block a user