check image save process

This commit is contained in:
Till Raab 2023-10-26 10:06:55 +02:00
parent 7f2d1cfb33
commit 391379551a
2 changed files with 17 additions and 6 deletions

View File

@ -6,6 +6,9 @@ RESIZE_TO = 416
NUM_EPOCHS = 10 NUM_EPOCHS = 10
NUM_WORKERS = 4 NUM_WORKERS = 4
IMG_SIZE = (7, 7) # inches
IMG_DPI = 256
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
TRAIN_DIR = 'data/train' TRAIN_DIR = 'data/train'
@ -15,7 +18,9 @@ CLASSES = ['__backgroud__', '1']
NUM_CLASSES = len(CLASSES) NUM_CLASSES = len(CLASSES)
OUTDIR = 'model_outputs' OUTDIR = 'model_outputs'
if not pathlib.Path(OUTDIR).exists(): if not pathlib.Path(OUTDIR).exists():
pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True) pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True)
INFERENCE_OUTDIR = 'inference_outputs'
if not pathlib.Path(INFERENCE_OUTDIR).exists():
pathlib.Path(INFERENCE_OUTDIR).mkdir(parents=True, exist_ok=True)

View File

@ -6,16 +6,17 @@ import os
from PIL import Image from PIL import Image
from model import create_model from model import create_model
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
from datasets import create_train_or_test_dataset, create_valid_loader from datasets import create_train_or_test_dataset, create_valid_loader
from IPython import embed from IPython import embed
from pathlib import Path
from tqdm.auto import tqdm from tqdm.auto import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle from matplotlib.patches import Rectangle
def plot_inference(img_tensor, img_name, output, target, detection_threshold): def plot_inference(img_tensor, img_name, output, target, detection_threshold):
fig, ax = plt.subplots(figsize=(7, 7), num=img_name) fig, ax = plt.subplots(figsize=IMG_SIZE, num=img_name)
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto') ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
if score < detection_threshold: if score < detection_threshold:
@ -36,7 +37,12 @@ def plot_inference(img_tensor, img_name, output, target, detection_threshold):
fill=False, color="white", linewidth=2, zorder=9) fill=False, color="white", linewidth=2, zorder=9)
) )
plt.show() ax.set_axis_off()
embed()
quit()
plt.savefig(Path(INFERENCE_OUTDIR)/img_name/'_inferred.png', IMG_DPI)
plt.close()
# plt.show()
def infere_model(test_loader, model, detection_th=0.8): def infere_model(test_loader, model, detection_th=0.8):