From cd3a066427a9f3fbd0cf745d594fc14b1dd70b06 Mon Sep 17 00:00:00 2001 From: Till Raab Date: Fri, 20 Oct 2023 08:57:37 +0200 Subject: [PATCH] structure --- confic.py | 16 ++++++++++++++++ data/generate_dataset.py | 5 ++--- model.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 confic.py create mode 100644 model.py diff --git a/confic.py b/confic.py new file mode 100644 index 0000000..acb0397 --- /dev/null +++ b/confic.py @@ -0,0 +1,16 @@ +import torch + +BATCH_SIZE = 4 +RESIZE_TO = 416 +NUM_EPOCHS = 10 +NUM_WORKERS = 4 + +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +TRAIN_DIR = 'data/train' + +CLASSES = ['__backgroud__', '1'] + +NUM_CLASSES = len(CLASSES) + +OUTDIR = 'model_outputs' \ No newline at end of file diff --git a/data/generate_dataset.py b/data/generate_dataset.py index 68ff367..7631203 100644 --- a/data/generate_dataset.py +++ b/data/generate_dataset.py @@ -10,7 +10,6 @@ from pathlib import Path from tqdm.auto import tqdm import itertools - import sys import os @@ -63,7 +62,7 @@ def main(folder): s_trans = transformed(log_s.unsqueeze(0)) fig_title = (f'{Path(folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:.0f}-{f1:.0f}Hz').replace(' ', '0') - fig = plt.figure(figsize=(10, 7), num=fig_title) + fig = plt.figure(figsize=(7, 7), num=fig_title) gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0)# , bottom=0, left=0, right=1, top=1 gs2 = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)# ax = fig.add_subplot(gs2[0, 0]) @@ -74,7 +73,7 @@ def main(folder): # fig.colorbar(im, cax=cax) ax.axis(False) - plt.savefig(fig_title + '.png', dpi=300) + plt.savefig(fig_title + '.png', dpi=256) plt.close() # # ax.imshow(spec[f0:f1, t0:t1], cmap='gray') diff --git a/model.py b/model.py new file mode 100644 index 0000000..9804b6e --- /dev/null +++ b/model.py @@ -0,0 +1,28 @@ +import torch.nn +import torchvision +from torchvision.models.detection.faster_rcnn import 
FastRCNNPredictor + +def create_model(num_classes: int) -> torch.nn.Module: + """ + Create a pretrained Faster R-CNN model and replace its final box predictor to fit + a specific detection task. + + Parameters + ---------- + num_classes: int + Number of classes the model shall detect, plus one. + The extra class is required for the background. + + Returns + ------- + model: torch.nn.Module + Adapted Faster R-CNN model + """ + model = torchvision.models.detection.fasterrcnn_resnet50_fpn( + weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT) + + in_features = model.roi_heads.box_predictor.cls_score.in_features + + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) + + return model \ No newline at end of file