the model is working. however, in need to generate it and split data in test and train beforehand ...
This commit is contained in:
parent
8a6e7df57f
commit
ad74322f94
@ -1,9 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 8
|
||||||
RESIZE_TO = 416
|
RESIZE_TO = 416
|
||||||
NUM_EPOCHS = 20
|
NUM_EPOCHS = 10
|
||||||
NUM_WORKERS = 4
|
NUM_WORKERS = 4
|
||||||
|
|
||||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
|
26
inference.py
26
inference.py
@ -10,6 +10,28 @@ from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR
|
|||||||
|
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.patches import Rectangle
|
||||||
|
|
||||||
|
def show_sample(img_tensor, outputs, detection_threshold):
|
||||||
|
|
||||||
|
# embed()
|
||||||
|
# quit()
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
|
||||||
|
for (x0, y0, x1, y1), l, score in zip(outputs[0]['boxes'].cpu(), outputs[0]['labels'].cpu(), outputs[0]['scores'].cpu()):
|
||||||
|
|
||||||
|
if score < detection_threshold:
|
||||||
|
continue
|
||||||
|
# print(x0, y0, x1, y1, l)
|
||||||
|
ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white')
|
||||||
|
ax.add_patch(
|
||||||
|
Rectangle((x0, y0),
|
||||||
|
(x1 - x0),
|
||||||
|
(y1 - y0),
|
||||||
|
fill=False, color="white", linewidth=2, zorder=10)
|
||||||
|
)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = create_model(num_classes=NUM_CLASSES)
|
model = create_model(num_classes=NUM_CLASSES)
|
||||||
@ -18,6 +40,8 @@ if __name__ == '__main__':
|
|||||||
model.to(DEVICE).eval()
|
model.to(DEVICE).eval()
|
||||||
|
|
||||||
DIR_TEST = 'data/train'
|
DIR_TEST = 'data/train'
|
||||||
|
|
||||||
|
|
||||||
test_images = glob.glob(f"{DIR_TEST}/*.png")
|
test_images = glob.glob(f"{DIR_TEST}/*.png")
|
||||||
|
|
||||||
detection_threshold = 0.8
|
detection_threshold = 0.8
|
||||||
@ -35,3 +59,5 @@ if __name__ == '__main__':
|
|||||||
outputs = model(img_tensor.to(DEVICE))
|
outputs = model(img_tensor.to(DEVICE))
|
||||||
|
|
||||||
print(len(outputs[0]['boxes']))
|
print(len(outputs[0]['boxes']))
|
||||||
|
|
||||||
|
show_sample(img_tensor, outputs, detection_threshold)
|
2
train.py
2
train.py
@ -45,8 +45,6 @@ def validate(test_loader, model, val_loss):
|
|||||||
|
|
||||||
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
|
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
|
||||||
|
|
||||||
embed()
|
|
||||||
quit()
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
loss_dict = model(images, targets)
|
loss_dict = model(images, targets)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user