check image save process

This commit is contained in:
Till Raab 2023-10-26 10:06:55 +02:00
parent 7f2d1cfb33
commit 391379551a
2 changed files with 17 additions and 6 deletions

View File

@ -6,6 +6,9 @@ RESIZE_TO = 416
NUM_EPOCHS = 10
NUM_WORKERS = 4
IMG_SIZE = (7, 7) # inches
IMG_DPI = 256
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
TRAIN_DIR = 'data/train'
@ -15,7 +18,9 @@ CLASSES = ['__backgroud__', '1']
NUM_CLASSES = len(CLASSES)
OUTDIR = 'model_outputs'
if not pathlib.Path(OUTDIR).exists():
pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True)
pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True)
INFERENCE_OUTDIR = 'inference_outputs'
if not pathlib.Path(INFERENCE_OUTDIR).exists():
pathlib.Path(INFERENCE_OUTDIR).mkdir(parents=True, exist_ok=True)

View File

@ -6,16 +6,17 @@ import os
from PIL import Image
from model import create_model
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
from datasets import create_train_or_test_dataset, create_valid_loader
from IPython import embed
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
def plot_inference(img_tensor, img_name, output, target, detection_threshold):
fig, ax = plt.subplots(figsize=(7, 7), num=img_name)
fig, ax = plt.subplots(figsize=IMG_SIZE, num=img_name)
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
if score < detection_threshold:
@ -36,7 +37,12 @@ def plot_inference(img_tensor, img_name, output, target, detection_threshold):
fill=False, color="white", linewidth=2, zorder=9)
)
plt.show()
ax.set_axis_off()
embed()
quit()
plt.savefig(Path(INFERENCE_OUTDIR)/img_name/'_inferred.png', IMG_DPI)
plt.close()
# plt.show()
def infere_model(test_loader, model, detection_th=0.8):