training loop runs ... no feed it
Parent: 3f66b4a39a
Commit: 30a9e71e76
@@ -99,12 +99,19 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
    Crise_size = rise_size_oi[enu]
    Cblf = closest_baseline_freq[enu]

-   rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
+   rise_end_t = times_v[(times_v > Ct_oi) &
+                        (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
    if len(rise_end_t) == 0:
        right_time_bound[enu] = np.nan
    else:
        right_time_bound[enu] = rise_end_t[0]

    mask = (~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > 1.))
    left_time_bound = left_time_bound[mask]
    right_time_bound = right_time_bound[mask]
    lower_freq_bound = lower_freq_bound[mask]
    upper_freq_bound = upper_freq_bound[mask]

    dt_bbox = right_time_bound - left_time_bound
    df_bbox = upper_freq_bound - lower_freq_bound
    left_time_bound -= dt_bbox * 0.1
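For context, the hunk above fixes the time extent of each rise bounding box: a rise is treated as finished at the first time after its onset where the frequency has fallen back below the baseline plus 0.37 times the rise size, boxes whose end was never found or that span less than 1 s are masked out, and the surviving boxes are then padded (the hunk shows the left edge being widened by 10 % of the box width). A minimal standalone sketch of that logic; names such as rise_time_bounds, t_onset, and pad are illustrative and do not appear in the repository:

import numpy as np

def rise_time_bounds(times, freq, t_onset, baseline, rise_size, pad=0.1):
    # end of the rise: first sample after onset where freq < baseline + 0.37 * rise size
    decayed = (times > t_onset) & (freq < baseline + 0.37 * rise_size)
    if not decayed.any():
        return None                    # no end found -> box is dropped by the mask
    t_end = times[decayed][0]
    if t_end - t_onset <= 1.:
        return None                    # shorter than 1 s -> box is dropped by the mask
    dt = t_end - t_onset
    return t_onset - pad * dt, t_end   # left edge padded by 10 % of the width, as in the hunk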
@@ -154,7 +161,7 @@ def main(args):
    for f in pd.unique(bbox_df['image']):
        eval_files.append(f.split('__')[0])

-   folders = [args.folders]
+   folders = [args.folder]

    for enu, folder in enumerate(folders):
        print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
datasets.py (12 changed lines)
@@ -39,11 +39,19 @@ class CustomDataset(Dataset):
    Cbbox = self.bbox_df[self.bbox_df['image'] == image_name]

    labels = np.ones(len(Cbbox), dtype=int)
-   boxes = Cbbox.loc[:, ['x0', 'x1', 'y0', 'y1']].values
+   boxes = torch.as_tensor(Cbbox.loc[:, ['x0', 'y0', 'x1', 'y1']].values, dtype=torch.float32)

    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
    # no crowd instances
    iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)

    target = {}
    target["boxes"] = boxes
-   target["labels"] = labels
+   target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
    target["area"] = area
    target["iscrowd"] = iscrowd
    image_id = torch.tensor([idx])
    target["image_id"] = image_id

    return img_tensor, target
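The change above converts the boxes to an (N, 4) float32 tensor in [x0, y0, x1, y1] order and the labels to int64, which is the target layout torchvision's detection models expect. A minimal example of a well-formed target dict, with made-up values rather than data from this dataset:

import torch

boxes = torch.tensor([[10., 20., 110., 220.]], dtype=torch.float32)  # [x0, y0, x1, y1], x0 < x1, y0 < y1
target = {
    "boxes": boxes,                                   # (N, 4) float32
    "labels": torch.tensor([1], dtype=torch.int64),   # class ids; 0 is reserved for background
    "area": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
    "iscrowd": torch.zeros((1,), dtype=torch.int64),
    "image_id": torch.tensor([0]),
}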
model.py (2 changed lines)
@@ -23,6 +23,6 @@ def create_model(num_classes: int) -> torch.nn.Module:

    in_features = model.roi_heads.box_predictor.cls_score.in_features

-   model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+   model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1)

    return model
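The +1 accounts for the implicit background class that torchvision's Faster R-CNN head keeps at index 0, so create_model(num_classes=1) in train.py yields a two-way classifier: background vs. the single foreground class. A minimal sketch of such a factory, assuming a fasterrcnn_resnet50_fpn backbone; the detector actually constructed in model.py is not shown in this diff:

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def create_model(num_classes: int) -> torch.nn.Module:
    # pretrained detector; the concrete backbone here is an assumption
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # num_classes foreground classes + 1 background class at index 0
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes + 1)
    return model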
train.py (42 changed lines)
@@ -1,3 +1,43 @@
-from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS)
+from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, TRAIN_DIR)
from model import create_model

+from tqdm.auto import tqdm
+from datasets import create_train_test_dataset, create_train_loader, create_valid_loader
+
+import torch
+import matplotlib.pyplot as plt
+import time
+
+from IPython import embed
+
+if __name__ == '__main__':
+    train_data, test_data = create_train_test_dataset(TRAIN_DIR)
+    train_loader = create_train_loader(train_data)
+    test_loader = create_train_loader(test_data)
+
+    model = create_model(num_classes=1)
+    model = model.to(DEVICE)
+
+    params = [p for p in model.parameters() if p.requires_grad]
+    optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
+
+    for epoch in range(NUM_EPOCHS):
+        prog_bar = tqdm(train_loader, total=len(train_loader))
+        for samples, targets in prog_bar:
+            images = list(image.to(DEVICE) for image in samples)
+
+            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+            try:
+                loss_dict = model(images, targets)
+            except:
+                embed()
+                quit()
+
+            losses = sum(loss for loss in loss_dict.values())
+            loss_value = losses.item()
+
+            optimizer.zero_grad()
+            losses.backward()
+            optimizer.step()
+
+            prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
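create_train_loader and create_valid_loader are not part of this diff; for the loop above to receive samples and targets as per-image sequences (note the list(...) over images), the loaders presumably wrap a DataLoader with a collate function that keeps each batch as tuples instead of stacking tensors, since detection targets have per-image shapes. A rough sketch under that assumption, with batch_size and num_workers as placeholder parameters:

from torch.utils.data import DataLoader

def collate_fn(batch):
    # batch is a list of (img_tensor, target) pairs; transpose it into
    # (images, targets) tuples without stacking the variable-sized targets
    return tuple(zip(*batch))

def create_train_loader(dataset, batch_size=8, num_workers=0):
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, collate_fn=collate_fn)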