fixed image name problematic

This commit is contained in:
Till Raab 2023-10-26 09:35:15 +02:00
parent 0adfc22ec4
commit 7f2d1cfb33
2 changed files with 11 additions and 9 deletions

View File

@ -14,8 +14,8 @@ from tqdm.auto import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle from matplotlib.patches import Rectangle
def plot_inference(img_tensor, output, target, detection_threshold): def plot_inference(img_tensor, img_name, output, target, detection_threshold):
fig, ax = plt.subplots(figsize=(7, 7), num=target['image_id']) fig, ax = plt.subplots(figsize=(7, 7), num=img_name)
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto') ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
if score < detection_threshold: if score < detection_threshold:
@ -46,15 +46,14 @@ def infere_model(test_loader, model, detection_th=0.8):
for samples, targets in prog_bar: for samples, targets in prog_bar:
images = list(image.to(DEVICE) for image in samples) images = list(image.to(DEVICE) for image in samples)
embed() img_names = [t['image_name'] for t in targets]
quit() targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets]
targets = [{k: v for k, v in t.items()} for t in targets]
with torch.inference_mode(): with torch.inference_mode():
outputs = model(images) outputs = model(images)
for image, output, target in zip(images, outputs, targets): for image, img_name, output, target in zip(images, img_names, outputs, targets):
plot_inference(image, output, target, detection_th) plot_inference(image, img_name, output, target, detection_th)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -19,7 +19,8 @@ def train(train_loader, model, optimizer, train_loss):
for samples, targets in prog_bar: for samples, targets in prog_bar:
images = list(image.to(DEVICE) for image in samples) images = list(image.to(DEVICE) for image in samples)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
loss_dict = model(images, targets) loss_dict = model(images, targets)
@ -43,7 +44,9 @@ def validate(test_loader, model, val_loss):
for samples, targets in prog_bar: for samples, targets in prog_bar:
images = list(image.to(DEVICE) for image in samples) images = list(image.to(DEVICE) for image in samples)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
with torch.inference_mode(): with torch.inference_mode():
loss_dict = model(images, targets) loss_dict = model(images, targets)