structure
This commit is contained in:
parent
fe9886bb9a
commit
cd3a066427
16
confic.py
Normal file
16
confic.py
Normal file
@ -0,0 +1,16 @@
|
||||
import torch
|
||||
|
||||
BATCH_SIZE = 4
|
||||
RESIZE_TO = 416
|
||||
NUM_EPOCHS = 10
|
||||
NUM_WORKERS = 4
|
||||
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
TRAIN_DIR = 'data/train'
|
||||
|
||||
CLASSES = ['__backgroud__', '1']
|
||||
|
||||
NUM_CLASSES = len(CLASSES)
|
||||
|
||||
OUTDIR = 'model_outputs'
|
@ -10,7 +10,6 @@ from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import itertools
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
@ -63,7 +62,7 @@ def main(folder):
|
||||
s_trans = transformed(log_s.unsqueeze(0))
|
||||
|
||||
fig_title = (f'{Path(folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:.0f}-{f1:.0f}Hz').replace(' ', '0')
|
||||
fig = plt.figure(figsize=(10, 7), num=fig_title)
|
||||
fig = plt.figure(figsize=(7, 7), num=fig_title)
|
||||
gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0)# , bottom=0, left=0, right=1, top=1
|
||||
gs2 = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)#
|
||||
ax = fig.add_subplot(gs2[0, 0])
|
||||
@ -74,7 +73,7 @@ def main(folder):
|
||||
# fig.colorbar(im, cax=cax)
|
||||
ax.axis(False)
|
||||
|
||||
plt.savefig(fig_title + '.png', dpi=300)
|
||||
plt.savefig(fig_title + '.png', dpi=256)
|
||||
plt.close()
|
||||
# # ax.imshow(spec[f0:f1, t0:t1], cmap='gray')
|
||||
|
||||
|
28
model.py
Normal file
28
model.py
Normal file
@ -0,0 +1,28 @@
|
||||
import torch.nn
|
||||
import torchvision
|
||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||
|
||||
def create_model(num_classes: int) -> torch.nn.Module:
|
||||
"""
|
||||
Create a pretrained Faster RCNN Model and replaces the final predictor in order to fit
|
||||
to a specific detection task.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes: int
|
||||
Number of classes (+1) that shall be detected with the model.
|
||||
One more class is required because of background.
|
||||
|
||||
Returns
|
||||
-------
|
||||
model: torch.nn.Module
|
||||
Adapted FasterRCNN Model
|
||||
"""
|
||||
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
|
||||
weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
|
||||
|
||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||
|
||||
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
||||
|
||||
return model
|
Loading…
Reference in New Issue
Block a user