check image save process
This commit is contained in:
parent
7f2d1cfb33
commit
391379551a
11
confic.py
11
confic.py
@ -6,6 +6,9 @@ RESIZE_TO = 416
|
|||||||
NUM_EPOCHS = 10
|
NUM_EPOCHS = 10
|
||||||
NUM_WORKERS = 4
|
NUM_WORKERS = 4
|
||||||
|
|
||||||
|
IMG_SIZE = (7, 7) # inches
|
||||||
|
IMG_DPI = 256
|
||||||
|
|
||||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
|
|
||||||
TRAIN_DIR = 'data/train'
|
TRAIN_DIR = 'data/train'
|
||||||
@ -15,7 +18,9 @@ CLASSES = ['__backgroud__', '1']
|
|||||||
NUM_CLASSES = len(CLASSES)
|
NUM_CLASSES = len(CLASSES)
|
||||||
|
|
||||||
OUTDIR = 'model_outputs'
|
OUTDIR = 'model_outputs'
|
||||||
|
|
||||||
|
|
||||||
if not pathlib.Path(OUTDIR).exists():
|
if not pathlib.Path(OUTDIR).exists():
|
||||||
pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True)
|
pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
INFERENCE_OUTDIR = 'inference_outputs'
|
||||||
|
if not pathlib.Path(INFERENCE_OUTDIR).exists():
|
||||||
|
pathlib.Path(INFERENCE_OUTDIR).mkdir(parents=True, exist_ok=True)
|
12
inference.py
12
inference.py
@ -6,16 +6,17 @@ import os
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from model import create_model
|
from model import create_model
|
||||||
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR
|
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
|
||||||
from datasets import create_train_or_test_dataset, create_valid_loader
|
from datasets import create_train_or_test_dataset, create_valid_loader
|
||||||
|
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
|
from pathlib import Path
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.patches import Rectangle
|
from matplotlib.patches import Rectangle
|
||||||
|
|
||||||
def plot_inference(img_tensor, img_name, output, target, detection_threshold):
|
def plot_inference(img_tensor, img_name, output, target, detection_threshold):
|
||||||
fig, ax = plt.subplots(figsize=(7, 7), num=img_name)
|
fig, ax = plt.subplots(figsize=IMG_SIZE, num=img_name)
|
||||||
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
|
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
|
||||||
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
|
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
|
||||||
if score < detection_threshold:
|
if score < detection_threshold:
|
||||||
@ -36,7 +37,12 @@ def plot_inference(img_tensor, img_name, output, target, detection_threshold):
|
|||||||
fill=False, color="white", linewidth=2, zorder=9)
|
fill=False, color="white", linewidth=2, zorder=9)
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.show()
|
ax.set_axis_off()
|
||||||
|
embed()
|
||||||
|
quit()
|
||||||
|
plt.savefig(Path(INFERENCE_OUTDIR)/img_name/'_inferred.png', IMG_DPI)
|
||||||
|
plt.close()
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
def infere_model(test_loader, model, detection_th=0.8):
|
def infere_model(test_loader, model, detection_th=0.8):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user