inference.py will be rewritten to really infere images without csv files

This commit is contained in:
Till Raab 2023-10-27 09:33:47 +02:00
parent cc6e97c2c8
commit 1042e44f37
2 changed files with 10 additions and 10 deletions

View File

@ -24,8 +24,8 @@ Use the script **./data/train_test_split.py** to split the original .csv file in
training and one for testing (both also stored in ./data/dataset).
### ToDos:
* FIX: name of generated png images. HINT {XXX:6.0f}.replace(' ', '0')
* transfere images from ./data/train to ./data/dataset
* FIX: name of generated png images. HINT: {XXX:6.0f}.replace(' ', '0')
* on a long scale: only save raw file bounding boxes in frequency and time (t0, t1, f0, f1) and the hyperparameters of the corresponding spectrogram. USE THESE PARAMETERS IN DATASET_FN.
## model.py
@ -51,6 +51,13 @@ im2 = ImageOps.grayscale(im1)
* check other pretrained models from torchvision.models.detection, e.g. fasterrcnn_resnet50_fpn_v2
## dataset.py
Contains custom datasets and dataloader. These are based on the images that are stored in
./data/dataset.
### ToDos:
* load/compute spectrogram directly and perform signal detection. E.g. spectrogram calculation as part of __getitem__
## config.py
Containes Hyperparameters used by the scripts.

View File

@ -17,7 +17,7 @@ import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
def plot_inference(img_tensor, img_name, output, target, detection_threshold):
def plot_inference(img_tensor, img_name, output, detection_threshold):
fig = plt.figure(figsize=IMG_SIZE, num=img_name)
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
@ -35,13 +35,6 @@ def plot_inference(img_tensor, img_name, output, target, detection_threshold):
(y1 - y0),
fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10)
)
for (x0, y0, x1, y1), l in zip(target['boxes'], target['labels']):
ax.add_patch(
Rectangle((x0, y0),
(x1 - x0),
(y1 - y0),
fill=False, color="white", linewidth=2, zorder=9)
)
ax.set_axis_off()
plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)