inference function with dataloader to optimize inference. now generate plots ...
parent f507877624
commit fe6e51d625

inference.py: 48 changed lines
@@ -33,7 +33,19 @@ def show_sample(img_tensor, outputs, detection_threshold):
     plt.show()
 
 def infere_model(test_loader, model, detection_th=0.8):
-    pass
+
+    print('Validation')
+
+    prog_bar = tqdm(test_loader, total=len(test_loader))
+    for samples, targets in prog_bar:
+        images = list(image.to(DEVICE) for image in samples)
+        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+
+        with torch.inference_mode():
+            outputs = model(images)
+
+        embed()
+
 
 if __name__ == '__main__':
     model = create_model(num_classes=NUM_CLASSES)
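Note: the new infere_model currently stops at embed(), an IPython breakpoint, rather than producing the plots mentioned in the commit message. A minimal sketch of how the batched outputs could instead be routed through the existing show_sample helper (assuming show_sample takes a batched image tensor plus the model's output list, as the commented call further down suggests; this is a sketch, not the committed code):

    # Sketch only: relies on DEVICE, tqdm, torch and show_sample already
    # present in inference.py.
    def infere_model(test_loader, model, detection_th=0.8):
        print('Validation')
        prog_bar = tqdm(test_loader, total=len(test_loader))
        for samples, targets in prog_bar:
            images = list(image.to(DEVICE) for image in samples)
            with torch.inference_mode():
                outputs = model(images)
            # Plot every image in the batch with its detections instead of
            # dropping into an interactive shell.
            for img, out in zip(images, outputs):
                show_sample(img.unsqueeze(0).cpu(), [out], detection_th)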
@@ -41,26 +53,26 @@ if __name__ == '__main__':
     model.load_state_dict(checkpoint["model_state_dict"])
     model.to(DEVICE).eval()
 
-    # test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
-    # test_loader = create_valid_loader(test_data)
+    test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
+    test_loader = create_valid_loader(test_data)
 
-    # infere_model(test_loader, model)
+    infere_model(test_loader, model)
 
-    detection_threshold = 0.8
-    frame_count = 0
-    total_fps = 0
-    test_images = glob.glob(f"{TRAIN_DIR}/*.png")
+    # detection_threshold = 0.8
+    # frame_count = 0
+    # total_fps = 0
+    # test_images = glob.glob(f"{TRAIN_DIR}/*.png")
 
-    for i in tqdm(np.arange(len(test_images))):
-        image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
-
-        img = Image.open(test_images[i])
-        img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
-
-        with torch.inference_mode():
-            outputs = model(img_tensor.to(DEVICE))
-
-        print(len(outputs[0]['boxes']))
+    # for i in tqdm(np.arange(len(test_images))):
+    #     image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
+    #
+    #     img = Image.open(test_images[i])
+    #     img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
+    #
+    #     with torch.inference_mode():
+    #         outputs = model(img_tensor.to(DEVICE))
+    #
+    #     print(len(outputs[0]['boxes']))
 
         # show_sample(img_tensor, outputs, detection_threshold)
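Since inference now goes through create_valid_loader, each batch must arrive as parallel lists of images and target dicts (infere_model iterates "for samples, targets in prog_bar" and wraps them in lists). Torchvision detection models need a collate function that keeps samples as lists rather than stacking them, because the images can differ in size. create_valid_loader is defined elsewhere in the repo; the sketch below is only an assumption about what such a loader typically looks like, not the repo's actual code:

    from torch.utils.data import DataLoader

    def collate_fn(batch):
        # Keep images and targets as parallel tuples instead of stacking,
        # since detection images may have different spatial sizes.
        return tuple(zip(*batch))

    def create_valid_loader(dataset, batch_size=4, num_workers=2):
        # Hypothetical signature and defaults; the real create_valid_loader
        # in this repo may differ.
        return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, collate_fn=collate_fn)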