139 lines
5.1 KiB
Python
139 lines
5.1 KiB
Python
import numpy as np
|
|
import torch
|
|
import torchvision.transforms.functional as F
|
|
import glob
|
|
import os
|
|
from PIL import Image
|
|
import argparse
|
|
|
|
from model import create_model
|
|
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, DATA_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
|
|
from datasets import InferenceDataset, create_valid_loader
|
|
|
|
from IPython import embed
|
|
from pathlib import Path
|
|
from tqdm.auto import tqdm
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.gridspec as gridspec
|
|
from matplotlib.patches import Rectangle
|
|
|
|
|
|
def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name):
    """Render one image with its detections and save the figure.

    Draws the first channel of ``img_tensor`` with the 'afmhot' colormap and
    overlays, for every detection in ``output`` scoring at least
    ``detection_threshold``, a gray rectangle plus the score printed sideways
    above it.  The result is written to
    ``INFERENCE_OUTDIR/<dataset_name>/<img stem>_inferred.png``.
    """
    fig = plt.figure(figsize=IMG_SIZE, num=img_name)
    grid = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)
    ax = fig.add_subplot(grid[0, 0])

    # Channel 0 only; vmin clips the dark background for better contrast.
    ax.imshow(img_tensor.cpu().squeeze()[0], aspect='auto', cmap='afmhot', vmin=.2)

    boxes = output['boxes'].cpu()
    box_labels = output['labels'].cpu()
    scores = output['scores'].cpu()
    for (x0, y0, x1, y1), _label, score in zip(boxes, box_labels, scores):
        if score < detection_threshold:
            continue
        # Score text: rotated 90°, centered on the box's top edge.
        ax.text(x0 + (x1 - x0) / 2, y0, f'{score:.2f}', ha='center',
                va='bottom', fontsize=12, color='tab:gray', rotation=90)
        frame = Rectangle((x0, y0), (x1 - x0), (y1 - y0), fill=False,
                          color="tab:gray", linestyle='-', linewidth=1,
                          zorder=10, alpha=0.8)
        ax.add_patch(frame)

    ax.set_axis_off()
    out_name = os.path.splitext(img_name)[0] + '_inferred.png'
    plt.savefig(Path(INFERENCE_OUTDIR) / dataset_name / out_name, dpi=IMG_DPI)
    plt.close()
|
|
|
|
def infere_model(inference_loader, model, dataset_name, detection_th=0.8, figures_only=False):
    """Run the detector over a dataset, writing YOLO labels and figures.

    For every image, detections scoring at least ``detection_th`` are
    converted to YOLO rows ``[class, x_center, y_center, width, height,
    score]`` with coordinates relative to the image size, and saved to
    ``data/<dataset_name>/labels/<image>.txt`` (skipped when
    ``figures_only`` is True).  A rendered overlay figure is always saved
    via ``plot_inference``.

    Parameters
    ----------
    inference_loader : iterable yielding ``(samples, img_names)`` batches.
    model : detection model; returns per-image dicts with 'boxes',
        'labels' and 'scores' tensors.
    dataset_name : dataset sub-folder name used for output paths.
    detection_th : minimum score for a detection to be kept.
    figures_only : when True, only figures are produced and any existing
        label files are left untouched.
    """
    print(f'Inference on dataset: {dataset_name}')

    prog_bar = tqdm(inference_loader, total=len(inference_loader))
    for samples, img_names in prog_bar:
        images = [image.to(DEVICE) for image in samples]

        with torch.inference_mode():
            outputs = model(images)

        for image, img_name, output in zip(images, img_names, outputs):
            # image is (C, H, W): height = shape[-2], width = shape[-1].
            img_height = image.shape[-2]
            img_width = image.shape[-1]

            yolo_labels = []
            for (x0, y0, x1, y1), score in zip(output['boxes'].cpu().numpy(),
                                               output['scores'].cpu().numpy()):
                if score < detection_th:
                    continue
                # BUGFIX: x coordinates are normalized by the image WIDTH.
                # The original divided all four coordinates by the height
                # (shape[-2]), which is only correct for square images.
                rel_x0 = x0 / img_width
                rel_x1 = x1 / img_width
                rel_y0 = y0 / img_height
                rel_y1 = y1 / img_height

                rel_x_center = (rel_x0 + rel_x1) / 2
                rel_y_center = (rel_y0 + rel_y1) / 2
                rel_width = rel_x1 - rel_x0
                rel_height = rel_y1 - rel_y0

                # NOTE(review): class id is hard-coded to 1 and
                # output['labels'] is ignored -- presumably a single-class
                # detector; confirm before adding more classes.
                yolo_labels.append([1, rel_x_center, rel_y_center,
                                    rel_width, rel_height, score])

            if not figures_only:
                label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt')
                np.savetxt(label_path, yolo_labels)

            plot_inference(image, img_name, output, detection_th, dataset_name)
|
|
|
|
|
|
|
|
def main(args):
    """Load the best checkpoint and run inference on ``args.folder``.

    Builds the model, restores ``OUTDIR/best_model.pth``, prepares the
    output directories, runs :func:`infere_model` over the folder and —
    unless ``--figures_only`` was given — removes a stale
    ``data/<dataset>/file_dict.csv`` so it is regenerated from the fresh
    labels.
    """
    model = create_model(num_classes=NUM_CLASSES)
    checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(DEVICE).eval()

    inference_data = InferenceDataset(args.folder)
    inference_loader = create_valid_loader(inference_data)

    dataset_name = Path(args.folder).name

    # exist_ok=True already tolerates an existing directory, so no separate
    # .exists() check is needed.
    (Path(INFERENCE_OUTDIR) / dataset_name).mkdir(parents=True, exist_ok=True)
    if not args.figures_only:
        # Robustness: infere_model writes label files into this directory
        # but does not create it -- make sure it exists up front.
        (Path('data') / dataset_name / 'labels').mkdir(parents=True, exist_ok=True)

    infere_model(inference_loader, model, dataset_name, figures_only=args.figures_only)

    if not args.figures_only:
        file_dict = Path('data').absolute() / dataset_name / 'file_dict.csv'
        if file_dict.exists():
            file_dict.unlink()
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Evaluated electrode array recordings with multiple fish.')
    # Typos fixed in help texts ('picutes' -> 'pictures', 'keek' -> 'keep');
    # the useless default='' was dropped (argparse ignores defaults on
    # required positionals without nargs='?').
    parser.add_argument('folder', type=str, help='folder to infer pictures')
    parser.add_argument('-f', '--figures_only', action='store_true',
                        help='only generate figures; keep possible existing labels')
    args = parser.parse_args()

    main(args)
|