Add inference function that runs the model over a dataloader to optimize inference; next: generate plots ...

This commit is contained in:
Till Raab 2023-10-26 07:35:41 +02:00
parent f507877624
commit fe6e51d625

View File

@ -33,7 +33,19 @@ def show_sample(img_tensor, outputs, detection_threshold):
plt.show()
def infere_model(test_loader, model, detection_th=0.8):
    """Run *model* in inference mode over every batch of *test_loader*.

    Parameters
    ----------
    test_loader : iterable
        Yields ``(samples, targets)`` batches; samples are image tensors,
        targets are dicts of tensors (detection-style annotations).
    model : callable
        A torch detection model, already moved to ``DEVICE`` and in eval mode.
    detection_th : float, default 0.8
        Detection score threshold. NOTE(review): currently unused in this
        body — presumably meant to filter ``outputs`` by score; confirm.
    """
    print('Validation')
    # Progress bar over the loader so long validation runs show progress.
    prog_bar = tqdm(test_loader, total=len(test_loader))
    for samples, targets in prog_bar:
        # Move every image and every target tensor onto the configured device.
        images = list(image.to(DEVICE) for image in samples)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
        # inference_mode() disables autograd bookkeeping for a faster forward pass.
        with torch.inference_mode():
            outputs = model(images)
        # NOTE(review): embed() drops into an interactive IPython shell on
        # EVERY batch — a temporary debugging hook; remove before real runs.
        embed()
if __name__ == '__main__':
model = create_model(num_classes=NUM_CLASSES)
@ -41,26 +53,26 @@ if __name__ == '__main__':
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()
# test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
# test_loader = create_valid_loader(test_data)
# infere_model(test_loader, model)
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
test_loader = create_valid_loader(test_data)
detection_threshold = 0.8
frame_count = 0
total_fps = 0
test_images = glob.glob(f"{TRAIN_DIR}/*.png")
for i in tqdm(np.arange(len(test_images))):
image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
infere_model(test_loader, model)
img = Image.open(test_images[i])
img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
with torch.inference_mode():
outputs = model(img_tensor.to(DEVICE))
# detection_threshold = 0.8
# frame_count = 0
# total_fps = 0
# test_images = glob.glob(f"{TRAIN_DIR}/*.png")
print(len(outputs[0]['boxes']))
# for i in tqdm(np.arange(len(test_images))):
# image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
#
# img = Image.open(test_images[i])
# img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
#
# with torch.inference_mode():
# outputs = model(img_tensor.to(DEVICE))
#
# print(len(outputs[0]['boxes']))
# show_sample(img_tensor, outputs, detection_threshold)