check image save process
This commit is contained in:
parent
7f2d1cfb33
commit
391379551a
11
confic.py
11
confic.py
@ -6,6 +6,9 @@ RESIZE_TO = 416
|
||||
NUM_EPOCHS = 10
|
||||
NUM_WORKERS = 4
|
||||
|
||||
IMG_SIZE = (7, 7) # inches
|
||||
IMG_DPI = 256
|
||||
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
TRAIN_DIR = 'data/train'
|
||||
@ -15,7 +18,9 @@ CLASSES = ['__backgroud__', '1']
|
||||
NUM_CLASSES = len(CLASSES)
|
||||
|
||||
OUTDIR = 'model_outputs'
|
||||
|
||||
|
||||
if not pathlib.Path(OUTDIR).exists():
|
||||
pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True)
|
||||
pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
INFERENCE_OUTDIR = 'inference_outputs'
|
||||
if not pathlib.Path(INFERENCE_OUTDIR).exists():
|
||||
pathlib.Path(INFERENCE_OUTDIR).mkdir(parents=True, exist_ok=True)
|
12
inference.py
12
inference.py
@ -6,16 +6,17 @@ import os
|
||||
from PIL import Image
|
||||
|
||||
from model import create_model
|
||||
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR
|
||||
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
|
||||
from datasets import create_train_or_test_dataset, create_valid_loader
|
||||
|
||||
from IPython import embed
|
||||
from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
def plot_inference(img_tensor, img_name, output, target, detection_threshold):
|
||||
fig, ax = plt.subplots(figsize=(7, 7), num=img_name)
|
||||
fig, ax = plt.subplots(figsize=IMG_SIZE, num=img_name)
|
||||
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
|
||||
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
|
||||
if score < detection_threshold:
|
||||
@ -36,7 +37,12 @@ def plot_inference(img_tensor, img_name, output, target, detection_threshold):
|
||||
fill=False, color="white", linewidth=2, zorder=9)
|
||||
)
|
||||
|
||||
plt.show()
|
||||
ax.set_axis_off()
|
||||
embed()
|
||||
quit()
|
||||
plt.savefig(Path(INFERENCE_OUTDIR)/img_name/'_inferred.png', IMG_DPI)
|
||||
plt.close()
|
||||
# plt.show()
|
||||
|
||||
def infere_model(test_loader, model, detection_th=0.8):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user