structure

This commit is contained in:
Till Raab 2023-10-20 08:57:37 +02:00
parent fe9886bb9a
commit cd3a066427
3 changed files with 46 additions and 3 deletions

16
confic.py Normal file
View File

@ -0,0 +1,16 @@
import torch
BATCH_SIZE = 4
RESIZE_TO = 416
NUM_EPOCHS = 10
NUM_WORKERS = 4
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
TRAIN_DIR = 'data/train'
CLASSES = ['__backgroud__', '1']
NUM_CLASSES = len(CLASSES)
OUTDIR = 'model_outputs'

View File

@ -10,7 +10,6 @@ from pathlib import Path
from tqdm.auto import tqdm
import itertools
import sys
import os
@ -63,7 +62,7 @@ def main(folder):
s_trans = transformed(log_s.unsqueeze(0))
fig_title = (f'{Path(folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:.0f}-{f1:.0f}Hz').replace(' ', '0')
fig = plt.figure(figsize=(10, 7), num=fig_title)
fig = plt.figure(figsize=(7, 7), num=fig_title)
gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0)# , bottom=0, left=0, right=1, top=1
gs2 = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)#
ax = fig.add_subplot(gs2[0, 0])
@ -74,7 +73,7 @@ def main(folder):
# fig.colorbar(im, cax=cax)
ax.axis(False)
plt.savefig(fig_title + '.png', dpi=300)
plt.savefig(fig_title + '.png', dpi=256)
plt.close()
# # ax.imshow(spec[f0:f1, t0:t1], cmap='gray')

28
model.py Normal file
View File

@ -0,0 +1,28 @@
import torch.nn
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def create_model(num_classes: int) -> torch.nn.Module:
"""
Create a pretrained Faster RCNN Model and replaces the final predictor in order to fit
to a specific detection task.
Parameters
----------
num_classes: int
Number of classes (+1) that shall be detected with the model.
One more class is required because of background.
Returns
-------
model: torch.nn.Module
Adapted FasterRCNN Model
"""
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
return model