inference function with dataloader to optimizer inference. now generate plots ...
This commit is contained in:
parent
f507877624
commit
fe6e51d625
48
inference.py
48
inference.py
@ -33,7 +33,19 @@ def show_sample(img_tensor, outputs, detection_threshold):
|
||||
plt.show()
|
||||
|
||||
def infere_model(test_loader, model, detection_th=0.8):
|
||||
pass
|
||||
|
||||
print('Validation')
|
||||
|
||||
prog_bar = tqdm(test_loader, total=len(test_loader))
|
||||
for samples, targets in prog_bar:
|
||||
images = list(image.to(DEVICE) for image in samples)
|
||||
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(images)
|
||||
|
||||
embed()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = create_model(num_classes=NUM_CLASSES)
|
||||
@ -41,26 +53,26 @@ if __name__ == '__main__':
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.to(DEVICE).eval()
|
||||
|
||||
# test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
||||
# test_loader = create_valid_loader(test_data)
|
||||
|
||||
|
||||
# infere_model(test_loader, model)
|
||||
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
||||
test_loader = create_valid_loader(test_data)
|
||||
|
||||
detection_threshold = 0.8
|
||||
frame_count = 0
|
||||
total_fps = 0
|
||||
test_images = glob.glob(f"{TRAIN_DIR}/*.png")
|
||||
|
||||
for i in tqdm(np.arange(len(test_images))):
|
||||
image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
|
||||
infere_model(test_loader, model)
|
||||
|
||||
img = Image.open(test_images[i])
|
||||
img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(img_tensor.to(DEVICE))
|
||||
# detection_threshold = 0.8
|
||||
# frame_count = 0
|
||||
# total_fps = 0
|
||||
# test_images = glob.glob(f"{TRAIN_DIR}/*.png")
|
||||
|
||||
print(len(outputs[0]['boxes']))
|
||||
# for i in tqdm(np.arange(len(test_images))):
|
||||
# image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
|
||||
#
|
||||
# img = Image.open(test_images[i])
|
||||
# img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
|
||||
#
|
||||
# with torch.inference_mode():
|
||||
# outputs = model(img_tensor.to(DEVICE))
|
||||
#
|
||||
# print(len(outputs[0]['boxes']))
|
||||
|
||||
# show_sample(img_tensor, outputs, detection_threshold)
|
Loading…
Reference in New Issue
Block a user