Compare commits: datastruct ... master (17 commits)

Commits (SHA1):
573665d3f1, 8cd422461a, cc4408d75a, 22b05aec76, 59ffd101f1, 5c74ac463b,
866fd3081d, f70e74f5e1, 93269a96a1, 0f0128439e, 324c690841, 3ed6f1102c,
3f235f02f3, 33f6bb8683, f286156607, 397be5c51b, c08ede1e76

@@ -6,6 +6,8 @@ spectrogram images. The model itself is a pretrained **FasterRCNN** Model with a
**ResNet50** backbone. Only the final predictor is replaced, so that instead of the 91 classes
present in the COCO dataset the model was trained on, it predicts the (currently) 1 category it should detect.

+HINT: Trained datasets can be found at raab@polarbear:projects/efishSignalDetector/data/rise_training
+
## Long-Term and major ToDos:

* implement gui to correct bounding boxes
* implement reinforced learning
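The head swap the README describes is not part of this diff; a minimal sketch of how it typically looks with torchvision (the `num_classes=2` value, background plus one "rise" class, and the weights argument are assumptions, not taken from the repository's model code):

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def create_model(num_classes=2):
    # Faster R-CNN with a ResNet50-FPN backbone, pretrained on COCO
    # (torchvision >= 0.13 uses the weights= argument; older versions use pretrained=True)
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
    # swap the 91-class COCO box predictor for a head with our class count
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

model = create_model(num_classes=2)  # background + "rise"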
@@ -26,8 +26,8 @@ DELTA_TIME = 60*10
TIME_OVERLAP = 60*1

# output parameters
-DATA_DIR = 'data/images'
-LABEL_DIR = 'data/labels'
+DATA_DIR = 'data/rise_training/images'
+LABEL_DIR = 'data/rise_training/labels'
OUTDIR = 'model_outputs'
INFERENCE_OUTDIR = 'inference_outputs'
for required_folders in [DATA_DIR, OUTDIR, INFERENCE_OUTDIR, LABEL_DIR]:

correct_bboxes.py (new file, 308 lines)
@@ -0,0 +1,308 @@
import os
import sys
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
import pyqtgraph as pg
import shutil

from PIL import Image, ImageOps
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as F

from IPython import embed
import pathlib

NONE_CHECKED_CHECKER_COLORS = ['red', 'black']


class ImageWithBbox(pg.GraphicsLayoutWidget):
    def __init__(self, img_path, parent=None):
        super(ImageWithBbox, self).__init__(parent)
        self.img_path = img_path
        self.label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt')

        # image setup
        self.imgItem = pg.ImageItem(ColorMap='viridis')
        self.plot_widget = self.addPlot(title="")
        self.plot_widget.addItem(self.imgItem, colorMap='viridis')
        self.plot_widget.scene().sigMouseClicked.connect(self.addROI)

        self.imgItem.getViewBox().invertY(True)
        self.imgItem.getViewBox().setMouseEnabled(x=False, y=False)

        # get image and labels
        img = Image.open(img_path)
        img_gray = ImageOps.grayscale(img)
        self.img_array = np.array(img_gray).T
        self.imgItem.setImage(np.array(self.img_array))
        self.plot_widget.setYRange(0, self.img_array.shape[0], padding=0)
        self.plot_widget.setXRange(0, self.img_array.shape[1], padding=0)

        self.labels = np.loadtxt(self.label_path, delimiter=' ')
        if len(self.labels) > 0 and len(self.labels.shape) == 1:
            self.labels = np.expand_dims(self.labels, 0)

        # add ROIs
        self.ROIs = []
        for enu, l in enumerate(self.labels):
            x_center = l[1] * self.img_array.shape[0]
            y_center = l[2] * self.img_array.shape[1]
            width = l[3] * self.img_array.shape[0]
            height = l[4] * self.img_array.shape[1]

            x0, y0 = x_center - width/2, y_center - height/2
            ROI = pg.RectROI((x0, y0), size=(width, height), removable=True, sideScalers=True)
            ROI.sigRemoveRequested.connect(self.removeROI)

            self.ROIs.append(ROI)
            self.plot_widget.addItem(ROI)

    def removeROI(self, roi):
        if roi in self.ROIs:
            self.ROIs.remove(roi)
            self.plot_widget.removeItem(roi)

    def addROI(self, event):
        # check if the event is a double-click event
        if event.double():
            pos = event.pos()
            # transform the mouse position to data coordinates
            pos_data = self.plot_widget.getViewBox().mapToView(pos)
            x, y = pos_data.x(), pos_data.y()
            # create a new ROI at the double-clicked position
            new_ROI = pg.RectROI(pos=(x, y), size=(self.img_array.shape[0]*0.05, self.img_array.shape[1]*0.05), removable=True)
            new_ROI.sigRemoveRequested.connect(self.removeROI)
            self.ROIs.append(new_ROI)
            self.plot_widget.addItem(new_ROI)


class Bbox_correct_UI(QMainWindow):
    def __init__(self, data_path, parent=None):
        super(Bbox_correct_UI, self).__init__(parent)

        self.data_path = data_path
        self.files = sorted(list(pathlib.Path(self.data_path).absolute().rglob('*images/*.png')))

        self.close_without_saving_rois = False

        rec = QApplication.desktop().screenGeometry()
        height = rec.height()
        width = rec.width()
        self.resize(int(.8 * width), int(.8 * height))

        # widget and layout
        self.central_widget = QWidget(self)
        self.central_gridLayout = QGridLayout()
        self.central_Layout = QHBoxLayout()
        self.central_widget.setLayout(self.central_Layout)
        self.setCentralWidget(self.central_widget)

        ###########
        self.highlighted_label = None
        self.all_labels = []
        self.load_or_create_file_dict()

        new_name, new_file, new_checked = self.file_dict.iloc[0].values
        self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}')

        # image widget
        self.current_img = ImageWithBbox(new_file, parent=self)
        self.central_Layout.addWidget(self.current_img, 4)

        # image select widget
        self.scroll = QScrollArea()  # scroll area which contains the widgets, set as the centralWidget
        self.widget = QWidget()  # widget that contains the collection of vertical boxes
        self.vbox = QVBoxLayout()

        for i in range(len(self.file_dict)):
            label = QLabel(f'{self.file_dict["name"][i]}')
            label.setAlignment(Qt.AlignRight)
            label.mousePressEvent = lambda event, label=label: self.label_clicked(label)

            if i == 0:
                label.setStyleSheet("border: 2px solid black; "
                                    "color : %s;" % (NONE_CHECKED_CHECKER_COLORS[self.file_dict['checked'][i]]))
                self.highlighted_label = label
            else:
                label.setStyleSheet("border: 1px solid gray; "
                                    "color : %s;" % (NONE_CHECKED_CHECKER_COLORS[self.file_dict['checked'][i]]))
            self.vbox.addWidget(label)
            self.all_labels.append(label)

        self.widget.setLayout(self.vbox)

        self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.scroll.setWidgetResizable(True)
        self.scroll.setWidget(self.widget)

        self.central_Layout.addWidget(self.scroll, 1)

        self.add_actions()

        self.init_MenuBar()

    def load_or_create_file_dict(self):
        csvs_im_data_path = list(pathlib.Path(self.data_path).absolute().rglob('*file_dict.csv'))
        if len(csvs_im_data_path) == 0:
            self.file_dict = pd.DataFrame(
                {'name': [f.name for f in self.files],
                 'files': self.files,
                 'checked': np.zeros(len(self.files), dtype=int)}  # ToDo: change this to "locked"
            )
        else:
            self.file_dict = pd.read_csv(csvs_im_data_path[0], sep=',')

    def save_file_dict(self):
        self.file_dict.to_csv(pathlib.Path(self.data_path)/'file_dict.csv', sep=',', index=False)

    def label_clicked(self, clicked_label):
        if self.highlighted_label:
            hl_mask = self.file_dict['name'] == self.highlighted_label.text()
            hl_label_name, _, hl_checked = self.file_dict[hl_mask].values[0]
            self.highlighted_label.setStyleSheet("border: 1px solid gray; "
                                                 "color: %s;" % (NONE_CHECKED_CHECKER_COLORS[hl_checked]))

        mask = self.file_dict['name'] == clicked_label.text()
        new_name, new_file, new_checked = self.file_dict[mask].values[0]
        clicked_label.setStyleSheet("border: 2px solid black;"
                                    "color: %s;" % (NONE_CHECKED_CHECKER_COLORS[new_checked]))
        self.highlighted_label = clicked_label

        self.switch_to_new_file(new_file, new_name)

    def lock_file(self):
        hl_mask = self.file_dict['name'] == self.highlighted_label.text()
        hl_label_name, _, hl_checked = self.file_dict[hl_mask].values[0]

        # ToDo: do everything with the index instead of a mask
        df_idx = self.file_dict.loc[self.file_dict['name'] == self.highlighted_label.text()].index.values[0]
        self.file_dict.at[df_idx, 'checked'] = 1

        self.highlighted_label.setStyleSheet("border: 1px solid gray; "
                                             "color: 'black';")

        new_idx = df_idx + 1 if df_idx < len(self.file_dict)-1 else 0
        new_name, new_file, new_checked = self.file_dict.iloc[new_idx].values

        self.all_labels[new_idx].setStyleSheet("border: 2px solid black;"
                                               "color: %s;" % (NONE_CHECKED_CHECKER_COLORS[new_checked]))
        self.highlighted_label = self.all_labels[new_idx]

        self.switch_to_new_file(new_file, new_name)

    def switch_to_new_file(self, new_file, new_name):
        self.readout_rois()

        self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}')

        self.central_Layout.removeWidget(self.current_img)
        self.current_img = ImageWithBbox(new_file, parent=self)
        self.central_Layout.insertWidget(0, self.current_img, 4)

    def readout_rois(self):
        if self.close_without_saving_rois:
            return
        new_labels = []
        for roi in self.current_img.ROIs:
            x0, y0 = roi.pos()
            x0 /= self.current_img.img_array.shape[0]
            y0 /= self.current_img.img_array.shape[1]

            w, h = roi.size()
            w /= self.current_img.img_array.shape[0]
            h /= self.current_img.img_array.shape[1]

            x_center = x0 + w/2
            y_center = y0 + h/2
            new_labels.append([1, x_center, y_center, w, h])
        new_labels = np.array(new_labels)
        np.savetxt(self.current_img.label_path, new_labels)

    def export_validated_data(self):
        fd = QFileDialog()
        export_path = fd.getExistingDirectory(self, 'Select Directory')
        if export_path:
            export_idxs = list(self.file_dict['files'][self.file_dict['checked'] == 1].index)
            keep_idxs = []
            for export_file_path, export_idx in zip(list(self.file_dict['files'][self.file_dict['checked'] == 1]), export_idxs):
                export_image_path = pathlib.Path(export_file_path)
                export_label_path = export_image_path.parent.parent / 'labels' / pathlib.Path(export_image_path.name).with_suffix('.txt')

                target_image_path = pathlib.Path(export_path) / 'images' / export_image_path.name
                target_label_path = pathlib.Path(export_path) / 'labels' / export_label_path.name
                if not target_image_path.exists():
                    # ToDo: this is not tested but should work
                    shutil.copy(export_image_path, target_image_path)
                    shutil.copy(export_label_path, target_label_path)
                else:
                    keep_idxs.append(export_idx)

            drop_idxs = list(set(export_idxs) - set(keep_idxs))
            # self.file_dict = self.file_dict.drop(drop_idxs)
            self.save_file_dict()
            self.close_without_saving_rois = True
            self.close()
        else:
            print('nope')
            pass

    def open(self):
        pass

    def init_MenuBar(self):
        menubar = self.menuBar()  # needs QMainWindow ?!
        file = menubar.addMenu('&File')  # create file menu ... accessible with alt+F
        file.addActions([self.Act_open, self.Act_export, self.Act_exit])

        edit = menubar.addMenu('&Help')

    def add_actions(self):
        self.lock_file_act = QAction('lock', self)
        self.lock_file_act.triggered.connect(self.lock_file)
        self.lock_file_act.setShortcut(Qt.Key_Space)
        self.addAction(self.lock_file_act)

        self.Act_open = QAction('&Open', self)
        self.Act_open.triggered.connect(self.open)  # ToDo: implement this fn

        self.Act_export = QAction('&Export', self)
        self.Act_export.triggered.connect(self.export_validated_data)

        self.Act_exit = QAction('&Exit', self)  # trigger with alt+E
        self.Act_exit.setShortcut(Qt.Key_Q)
        self.Act_exit.triggered.connect(self.close)

    def closeEvent(self, *args, **kwargs):
        super(QMainWindow, self).closeEvent(*args, **kwargs)
        self.readout_rois()
        self.save_file_dict()


def main_UI():
    app = QApplication(sys.argv)  # create application
    data_path = sys.argv[1]
    w = Bbox_correct_UI(data_path)  # create window
    w.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main_UI()
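For reference, the label files that readout_rois writes and ImageWithBbox.__init__ reads back are plain whitespace-separated text, one row per box: class x_center y_center width height, all normalized to [0, 1]. A small round-trip sketch (file name and values are made up):

import numpy as np

# one box: class 1, centered at (0.30, 0.35), covering 50% of each axis
labels = np.array([[1, 0.30, 0.35, 0.50, 0.50]])
np.savetxt('example_label.txt', labels)

read_back = np.loadtxt('example_label.txt', delimiter=' ')
if read_back.ndim == 1:                        # single-row files come back 1-D,
    read_back = np.expand_dims(read_back, 0)   # as handled in ImageWithBbox.__init__
assert np.allclose(read_back, labels)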

datasets.py (73 lines changed)
@@ -26,6 +26,8 @@ class InferenceDataset(Dataset):
    def __init__(self, dir_path):
        self.dir_path = dir_path
        self.all_images = sorted(list(Path(self.dir_path).rglob(f'*.png')))
        self.file_names = np.array([Path(x).with_suffix('') for x in self.all_images])
        # self.images = np.array(sorted(os.listdir(DATA_DIR)))

    def __len__(self):
        return len(self.all_images)

@@ -39,45 +41,6 @@ class InferenceDataset(Dataset):

-class CustomDataset(Dataset):
-    def __init__(self, dir_path, bbox_df):
-        self.dir_path = dir_path
-        self.bbox_df = bbox_df
-
-        self.all_images = np.array(sorted(pd.unique(self.bbox_df['image'])), dtype=str)
-        self.image_paths = list(map(lambda x: Path(self.dir_path)/x, self.all_images))
-
-    def __getitem__(self, idx):
-        image_name = self.all_images[idx]
-        image_path = os.path.join(self.dir_path, image_name)
-
-        img = Image.open(image_path)
-        img_tensor = F.to_tensor(img.convert('RGB'))
-
-        Cbbox = self.bbox_df[self.bbox_df['image'] == image_name]
-
-        labels = np.ones(len(Cbbox), dtype=int)
-        boxes = torch.as_tensor(Cbbox.loc[:, ['x0', 'y0', 'x1', 'y1']].values, dtype=torch.float32)
-
-        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
-        # no crowd instances
-        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
-
-        target = {}
-        target["boxes"] = boxes
-        target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
-        target["area"] = area
-        target["iscrowd"] = iscrowd
-        image_id = torch.tensor([idx])
-        target["image_id"] = image_id
-        target["image_name"] = image_name  # ToDo: implement this as a 3rd return value
-
-        return img_tensor, target
-
-    def __len__(self):
-        return len(self.all_images)


class CustomDataset_v2(Dataset):
    def __init__(self, limited_idxs=None):
        self.images = np.array(sorted(os.listdir(DATA_DIR)))
        self.labels = np.array(sorted(os.listdir(LABEL_DIR)))

@@ -99,7 +62,6 @@ class CustomDataset_v2(Dataset):
        boxes, labels, area, iscrowd = self.extract_bboxes(annotations)

        target = {}
        target["boxes"] = boxes
        target["labels"] = torch.as_tensor(labels, dtype=torch.int64)

@@ -147,27 +109,11 @@ def custom_train_test_split():
    train_idxs = np.sort(data_idxs[int(0.2 * len(data_idxs)):])
    test_idxs = np.sort(data_idxs[:int(0.2 * len(data_idxs))])

-    train_data = CustomDataset_v2(limited_idxs=train_idxs)
-    test_data = CustomDataset_v2(limited_idxs=test_idxs)
+    train_data = CustomDataset(limited_idxs=train_idxs)
+    test_data = CustomDataset(limited_idxs=test_idxs)

    return train_data, test_data

-def create_train_or_test_dataset(path, train=True):
-    if train == True:
-        pfx = 'train'
-        print('Generate train dataset!')
-    else:
-        print('Generate test dataset!')
-        pfx = 'test'
-
-    csv_candidates = list(Path(path).rglob(f'*{pfx}*.csv'))
-    if len(csv_candidates) == 0:
-        print(f'no .csv files for *{pfx}* found in {Path(path)}')
-        quit()
-    else:
-        bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0)
-        return CustomDataset(path, bboxes)


def create_train_loader(train_dataset, num_workers=0):
    train_loader = DataLoader(

@@ -191,17 +137,6 @@ def create_valid_loader(valid_dataset, num_workers=0):
    )
    return valid_loader

-def create_inference_loader(inference_dataset, num_workers=0):
-    inference_loader = DataLoader(
-        inference_dataset,
-        batch_size=BATCH_SIZE,
-        shuffle=False,
-        num_workers=num_workers,
-        collate_fn=collate_fn
-    )
-    return inference_loader


if __name__ == '__main__':
    train_data, test_data = custom_train_test_split()
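The loaders above pass a collate_fn, which detection datasets need because each image carries a different number of boxes, so the targets cannot be stacked into a single tensor. The helper itself lies outside the shown hunks; the conventional minimal version, as used in the torchvision detection reference (an assumption for this repository), is:

def collate_fn(batch):
    # keep images and targets as parallel tuples instead of stacking,
    # since the number of boxes differs from image to image
    return tuple(zip(*batch))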

extract_from_bbox.py (new file, 347 lines)
@@ -0,0 +1,347 @@
import itertools
import pathlib
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image
from tqdm.auto import tqdm

from IPython import embed


def extract_time_freq_range_from_filename(img_path):
    file_name_str, time_span_str, freq_span_str = str(img_path.with_suffix('').name).split('__')
    time_span_str = time_span_str.replace('s', '')
    freq_span_str = freq_span_str.replace('Hz', '')

    t0, t1 = np.array(time_span_str.split('-'), dtype=float)
    f0, f1 = np.array(freq_span_str.split('-'), dtype=float)
    return file_name_str, t0, t1, f0, f1


def bbox_to_data(img_path, t_min, t_max, f_min, f_max):
    label_path = img_path.parent.parent / 'labels' / img_path.with_suffix('.txt').name

    annotations = np.loadtxt(label_path, delimiter=' ')
    if len(annotations.shape) == 1:
        annotations = np.array([annotations])
    if annotations.shape[1] == 0:
        print('no rises detected in this window')
        return [], []

    boxes = np.array([[x[1] - x[3] / 2, 1 - (x[2] + x[4] / 2), x[1] + x[3] / 2, 1 - (x[2] - x[4] / 2)] for x in annotations])  # x0, y0, x1, y1
    boxes[:, 0] = boxes[:, 0] * (t_max - t_min) + t_min
    boxes[:, 2] = boxes[:, 2] * (t_max - t_min) + t_min

    boxes[:, 1] = boxes[:, 1] * (f_max - f_min) + f_min
    boxes[:, 3] = boxes[:, 3] * (f_max - f_min) + f_min

    scores = annotations[:, 5]

    return boxes, scores


def load_wavetracker_data(raw_path):
    fund_v = np.load(raw_path.parent / 'fund_v.npy')
    ident_v = np.load(raw_path.parent / 'ident_v.npy')
    idx_v = np.load(raw_path.parent / 'idx_v.npy')
    times = np.load(raw_path.parent / 'times.npy')

    return fund_v, ident_v, idx_v, times


def assign_rises_to_ids(raw_path, time_frequency_bboxes, bbox_groups):
    def identify_most_likely_rise_id(possible_ids, t0, t1, f0, f1, fund_v, ident_v, times, idx_v):
        mean_id_box_f_rel_to_bbox = []
        for id in possible_ids:
            id_box_f = fund_v[(ident_v == id) & (times[idx_v] >= t0) & (times[idx_v] <= t1)]
            id_box_f_rel_to_bbox = (id_box_f - f0) / (f1 - f0)
            mean_id_box_f_rel_to_bbox.append(np.mean(id_box_f_rel_to_bbox))

        most_likely_id = possible_ids[np.argsort(mean_id_box_f_rel_to_bbox)[0]]
        return most_likely_id

    fund_v, ident_v, idx_v, times = load_wavetracker_data(raw_path)

    fig, ax = plt.subplots()
    ax.plot(times[idx_v[~np.isnan(ident_v)]], fund_v[~np.isnan(ident_v)], '.')

    mask = time_frequency_bboxes['file_name'] == raw_path.parent.name
    for index, bbox in time_frequency_bboxes[mask].iterrows():
        name, t0, f0, t1, f1, score = (bbox[0], *bbox[1:-2].astype(float))
        if bbox_groups[index] == 0:
            color = 'tab:green'
        elif bbox_groups[index] > 0:
            color = 'tab:red'
        else:
            color = 'k'

        ax.add_patch(
            Rectangle((t0, f0),
                      (t1 - t0),
                      (f1 - f0),
                      fill=False, color=color, linestyle='--', linewidth=2, zorder=10)
        )
        ax.text(t1, f1, f'{score:.1%}', ha='right', va='bottom')

        possible_ids = np.unique(
            ident_v[~np.isnan(ident_v) &
                    (t0 <= times[idx_v]) &
                    (t1 >= times[idx_v]) &
                    (f0 <= fund_v) &
                    (f1 >= fund_v)]
        )
        if len(possible_ids) == 1:
            assigned_id = possible_ids[0]
            time_frequency_bboxes.at[index, 'id'] = assigned_id
        elif len(possible_ids) > 1:
            assigned_id = identify_most_likely_rise_id(possible_ids, t0, t1, f0, f1, fund_v, ident_v, times, idx_v)
            time_frequency_bboxes.at[index, 'id'] = assigned_id
        else:
            continue

        rise_start_freq_th = f0 + (f1 - f0) * 0.37
        wavetracker_mask = np.arange(len(fund_v))[
            (times[idx_v] >= t0) &
            (times[idx_v] <= t1) &
            (ident_v == assigned_id)
        ]
        if np.sum(fund_v[wavetracker_mask] > rise_start_freq_th) > 0:
            # rise start time = moment where the rise frequency exceeds 37% of the bbox frequency range
            rise_start_idx = wavetracker_mask[fund_v[wavetracker_mask] > rise_start_freq_th][0]
            rise_time = times[idx_v[rise_start_idx]]
        else:
            # if this never happens, use the point of largest slope instead
            rise_start_idx = wavetracker_mask[np.argmax(np.diff(fund_v[wavetracker_mask]))]
            rise_time = times[idx_v[rise_start_idx]]
        time_frequency_bboxes.at[index, 'event_time'] = rise_time

        ax.plot(rise_time, fund_v[rise_start_idx], 'ok')
    plt.close()
    return time_frequency_bboxes


def find_overlapping_bboxes(df_collect):
    file_names = np.array(df_collect)[:, 0]
    bboxes = np.array(df_collect)[:, 1:].astype(float)

    overlap_bbox_idxs = []

    for file_name in tqdm(np.unique(file_names)):
        file_bbox_idxs = np.arange(len(file_names))[file_names == file_name]
        for ind0, ind1 in itertools.combinations(file_bbox_idxs, r=2):
            bb0 = bboxes[ind0]
            bb1 = bboxes[ind1]
            t0_0, f0_0, t0_1, f0_1 = bb0[:-1]
            t1_0, f1_0, t1_1, f1_1 = bb1[:-1]

            # check temporal overlap
            bb_times = np.array([t0_0, t0_1, t1_0, t1_1])
            bb_time_associate = np.array([0, 0, 1, 1])
            time_helper = bb_time_associate[np.argsort(bb_times)]

            if time_helper[0] == time_helper[1]:
                # no temporal overlap
                continue

            # check frequency overlap
            bb_freqs = np.array([f0_0, f0_1, f1_0, f1_1])
            bb_freq_associate = np.array([0, 0, 1, 1])
            freq_helper = bb_freq_associate[np.argsort(bb_freqs)]

            if freq_helper[0] == freq_helper[1]:
                continue

            overlap_bbox_idxs.append((ind0, ind1))

    return np.asarray(overlap_bbox_idxs)


def main(args):
    img_paths = sorted(list(pathlib.Path(args.annotations).absolute().rglob('*.png')))
    df_collect = []

    for img_path in img_paths:
        # convert to time and frequency coordinates
        file_name_str, t_min, t_max, f_min, f_max = extract_time_freq_range_from_filename(img_path)

        boxes, scores = bbox_to_data(img_path, t_min, t_max, f_min, f_max)  # t0, t1, f0, f1
        # store values in df
        if not len(boxes) == 0:
            for (t0, f0, t1, f1), s in zip(boxes, scores):
                df_collect.append([file_name_str, t0, f0, t1, f1, s])

    df_collect = np.array(df_collect)

    overlap_bbox_idxs = find_overlapping_bboxes(df_collect)

    bbox_groups = delete_double_boxes(overlap_bbox_idxs, df_collect)

    time_frequency_bboxes = pd.DataFrame(data=np.array(df_collect), columns=['file_name', 't0', 'f0', 't1', 'f1', 'score'])
    time_frequency_bboxes['id'] = np.full(len(time_frequency_bboxes), np.nan)
    time_frequency_bboxes['event_time'] = np.full(len(time_frequency_bboxes), np.nan)

    # (optional debugging plots of the bboxes per file, commented out here)

    if args.tracking_data_path:
        file_paths = sorted(list(pathlib.Path(args.tracking_data_path).absolute().rglob('*.raw')))
        for raw_path in file_paths:
            if not raw_path.parent.name in time_frequency_bboxes['file_name'].to_list():
                continue
            time_frequency_bboxes = assign_rises_to_ids(raw_path, time_frequency_bboxes, bbox_groups)
        for raw_path in file_paths:
            mask = ((time_frequency_bboxes['file_name'] == raw_path.parent.name) & (~np.isnan(time_frequency_bboxes['id'])))
            save_df = pd.DataFrame(time_frequency_bboxes[mask][['t0', 't1', 'f0', 'f1', 'score', 'id', 'event_time']].values, columns=['t0', 't1', 'f0', 'f1', 'score', 'id', 'event_time'])
            save_df['label'] = np.ones(len(save_df), dtype=int)
            save_df.to_csv(raw_path.parent / 'risedetector_bboxes.csv', sep=',', index=False)
    quit()


def delete_double_boxes(overlap_bbox_idxs, df_collect, overlap_th=0.2):
    def get_connected(non_regarded_bbox_idx, overlap_bbox_idxs):
        mask = np.array((np.array(overlap_bbox_idxs) == non_regarded_bbox_idx).sum(1), dtype=bool)
        affected_bbox_idxs = np.unique(overlap_bbox_idxs[mask])
        return affected_bbox_idxs

    handled_bbox_idxs = []
    bbox_groups = np.zeros(len(df_collect))

    for Coverlapping_bbox_idx in tqdm(np.unique(overlap_bbox_idxs)):
        if Coverlapping_bbox_idx in handled_bbox_idxs:
            continue

        regarded_bbox_idxs = [Coverlapping_bbox_idx]
        mask = np.array((np.array(overlap_bbox_idxs) == Coverlapping_bbox_idx).sum(1), dtype=bool)
        affected_bbox_idxs = np.unique(overlap_bbox_idxs[mask])

        non_regarded_bbox_idxs = list(set(affected_bbox_idxs) - set(regarded_bbox_idxs))
        while len(non_regarded_bbox_idxs) > 0:
            non_regarded_bbox_idxs_cp = np.copy(non_regarded_bbox_idxs)
            for non_regarded_bbox_idx in non_regarded_bbox_idxs_cp:
                Caffected_bbox_idxs = get_connected(non_regarded_bbox_idx, overlap_bbox_idxs)
                affected_bbox_idxs = np.unique(np.append(affected_bbox_idxs, Caffected_bbox_idxs))

                regarded_bbox_idxs.append(non_regarded_bbox_idx)
            non_regarded_bbox_idxs = list(set(affected_bbox_idxs) - set(regarded_bbox_idxs))

        bbox_idx_group = np.array(regarded_bbox_idxs)
        bbox_scores = df_collect[bbox_idx_group][:, -1].astype(float)

        bbox_groups[bbox_idx_group] = np.max(bbox_groups) + 1

        remove_idx_combinations = [()]
        remove_idx_combinations_scores = [0]
        for r in range(1, len(bbox_idx_group)):
            remove_idx_combinations.extend(list(itertools.combinations(bbox_idx_group, r=r)))
            remove_idx_combinations_scores.extend(list(itertools.combinations(bbox_scores, r=r)))
        for enu, combi_score in enumerate(remove_idx_combinations_scores):
            remove_idx_combinations_scores[enu] = np.sum(combi_score)
        if len(bbox_idx_group) > 1:
            remove_idx_combinations = [remove_idx_combinations[ind] for ind in np.argsort(remove_idx_combinations_scores)]
            remove_idx_combinations_scores = [remove_idx_combinations_scores[ind] for ind in np.argsort(remove_idx_combinations_scores)]

        for remove_idx in remove_idx_combinations:
            select_bbox_idx_group = list(set(bbox_idx_group) - set(remove_idx))
            time_overlap_pct, freq_overlap_pct = (
                compute_time_frequency_overlap_for_bbox_group(select_bbox_idx_group, df_collect))

            if np.all(np.min([time_overlap_pct, freq_overlap_pct], axis=0) < overlap_th):
                break

        if len(remove_idx) > 0:
            bbox_groups[np.array(remove_idx)] *= -1

        handled_bbox_idxs.extend(bbox_idx_group)

    return bbox_groups


def compute_time_frequency_overlap_for_bbox_group(bbox_idx_group, df_collect):
    time_overlap_pct = np.zeros((len(bbox_idx_group), len(bbox_idx_group)))
    freq_overlap_pct = np.zeros((len(bbox_idx_group), len(bbox_idx_group)))

    for i, j in itertools.product(range(len(bbox_idx_group)), repeat=2):
        if i == j:
            continue
        bb0_idx = bbox_idx_group[i]
        bb1_idx = bbox_idx_group[j]

        bb0_t0, bb0_t1 = df_collect[bb0_idx][1].astype(float), df_collect[bb0_idx][3].astype(float)
        bb1_t0, bb1_t1 = df_collect[bb1_idx][1].astype(float), df_collect[bb1_idx][3].astype(float)

        bb0_f0, bb0_f1 = df_collect[bb0_idx][2].astype(float), df_collect[bb0_idx][4].astype(float)
        bb1_f0, bb1_f1 = df_collect[bb1_idx][2].astype(float), df_collect[bb1_idx][4].astype(float)

        bb_times_idx = np.array([0, 0, 1, 1])
        bb_times = np.array([bb0_t0, bb0_t1, bb1_t0, bb1_t1])
        sorted_bb_times_idx = bb_times_idx[bb_times.argsort()]

        if sorted_bb_times_idx[0] == sorted_bb_times_idx[1]:
            time_overlap_pct[i, j] = 0
        elif sorted_bb_times_idx[1] == sorted_bb_times_idx[2] == 0:
            time_overlap_pct[i, j] = 1
        elif sorted_bb_times_idx[1] == sorted_bb_times_idx[2] == 1:
            time_overlap_pct[i, j] = (bb1_t1 - bb1_t0) / (bb0_t1 - bb0_t0)
        else:
            time_overlap_pct[i, j] = np.diff(sorted(bb_times)[1:3])[0] / (bb0_t1 - bb0_t0)

        bb_freqs_idx = np.array([0, 0, 1, 1])
        bb_freqs = np.array([bb0_f0, bb0_f1, bb1_f0, bb1_f1])
        sorted_bb_freqs_idx = bb_freqs_idx[bb_freqs.argsort()]

        if sorted_bb_freqs_idx[0] == sorted_bb_freqs_idx[1]:
            freq_overlap_pct[i, j] = 0
        elif sorted_bb_freqs_idx[1] == sorted_bb_freqs_idx[2] == 0:
            freq_overlap_pct[i, j] = 1
        elif sorted_bb_freqs_idx[1] == sorted_bb_freqs_idx[2] == 1:
            freq_overlap_pct[i, j] = (bb1_f1 - bb1_f0) / (bb0_f1 - bb0_f0)
        else:
            freq_overlap_pct[i, j] = np.diff(sorted(bb_freqs)[1:3])[0] / (bb0_f1 - bb0_f0)

    return time_overlap_pct, freq_overlap_pct


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract time, frequency and identity association of bboxes')
    parser.add_argument('annotations', nargs='?', type=str, help='path to annotations')
    parser.add_argument('-t', '--tracking_data_path', type=str, help='path to tracking data')
    args = parser.parse_args()
    main(args)
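find_overlapping_bboxes tests overlap per axis with an endpoint-sorting trick: label the four sorted interval endpoints by the box they belong to; if the two smallest endpoints come from the same box, that box ends before the other begins. A minimal, self-contained illustration of that 1-D test (hypothetical helper, not part of the repository):

import numpy as np

def intervals_overlap(a0, a1, b0, b1):
    # label each endpoint with its interval and sort by value;
    # if the two smallest endpoints belong to the same interval,
    # that interval lies entirely before the other -> no overlap
    owners = np.array([0, 0, 1, 1])
    order = np.argsort([a0, a1, b0, b1])
    return owners[order[0]] != owners[order[1]]

print(intervals_overlap(0, 2, 1, 3))  # True: [0,2] and [1,3] overlap
print(intervals_overlap(0, 1, 2, 3))  # False: disjoint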

@@ -97,7 +97,7 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1,
    return fig_title


-def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str, t0, t1, f0, f1):
+def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str, t0, t1, f0, f1, label_save_folder):

    times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))

@@ -185,13 +185,16 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
        all_height
    ]).T

-    np.savetxt(LABEL_DIR / Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
+    np.savetxt(label_save_folder / Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
    return bbox_yolo_style


def main(args):
    folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
-    pic_save_folder = DATA_DIR if not args.inference else (Path('data') / Path(args.folder).name)
+    pic_save_folder = Path('data') / Path(args.folder).name / 'images'
+    label_save_folder = Path('data') / Path(args.folder).name / 'labels'

    if len(folders) == 0:
        print('no datasets containing fill_times.npy found')

@@ -202,8 +205,10 @@ def main(args):

    else:
        print('generate inference dataset ... only image output')
-    if not (Path('data') / Path(args.folder).name).exists():
-        (Path('data') / Path(args.folder).name).mkdir(parents=True, exist_ok=True)
+    if not pic_save_folder.exists():
+        pic_save_folder.mkdir(parents=True, exist_ok=True)
+    if not label_save_folder.exists():
+        label_save_folder.mkdir(parents=True, exist_ok=True)

    for enu, folder in enumerate(folders):
        print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')

@@ -253,52 +258,7 @@ def main(args):
    if not args.inference:
        bbox_yolo_style = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
                                           fish_baseline_freq_time, fish_baseline_freq,
-                                           pic_save_str, t0, t1, f0, f1)
+                                           pic_save_str, t0, t1, f0, f1, label_save_folder)
-    # ... (~45 lines of commented-out debugging plots for the generated bboxes) ...

@@ -306,4 +266,6 @@ if __name__ == '__main__':
    parser.add_argument('-i', "--inference", action="store_true")
    args = parser.parse_args()

+    # ToDo: put "images" in an image folder

    main(args)
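For orientation, the image file names this script produces encode the plotted time and frequency window, and extract_from_bbox.py parses them back out. A round-trip sketch of that naming convention (values are made up; the format string follows the commented-out fig_title code in this file):

from pathlib import Path
import numpy as np

# build a file name the way save_spec_pic does: zero-padded time and
# frequency spans, with double underscores as field separators
name, t0, t1, f0, f1 = 'rec2023', 300., 900., 650., 1050.
fig_title = f'{name}__{t0:5.0f}s-{t1:5.0f}s__{f0:4.0f}-{f1:4.0f}Hz.png'.replace(' ', '0')
print(fig_title)  # rec2023__00300s-00900s__0650-1050Hz.png

# extract_time_freq_range_from_filename() inverts this
stem = Path(fig_title).with_suffix('').name
name_str, t_span, f_span = stem.split('__')
t0_, t1_ = np.array(t_span.replace('s', '').split('-'), dtype=float)
f0_, f1_ = np.array(f_span.replace('Hz', '').split('-'), dtype=float)
assert (t0_, t1_, f0_, f1_) == (t0, t1, f0, f1)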

inference.py (49 lines changed)
@@ -19,22 +19,28 @@ from matplotlib.patches import Rectangle


def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name):
    fig = plt.figure(figsize=IMG_SIZE, num=img_name)
    gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)
    ax = fig.add_subplot(gs[0, 0])

-    ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
+    ax.imshow(img_tensor.cpu().squeeze()[0], aspect='auto', cmap='afmhot', vmin=.2)

    for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
        if score < detection_threshold:
            continue
-        ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white')
+        ax.text(x0 + (x1 - x0) / 2, y0, f'{score:.2f}', ha='center', va='bottom', fontsize=12, color='tab:gray', rotation=90)
        ax.add_patch(
            Rectangle((x0, y0),
                      (x1 - x0),
                      (y1 - y0),
-                      fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10)
+                      fill=False, color="tab:gray", linestyle='-', linewidth=1, zorder=10, alpha=0.8)
        )

    ax.set_axis_off()

@@ -42,7 +48,7 @@ def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name
    plt.close()

-def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
+def infere_model(inference_loader, model, dataset_name, detection_th=0.8, figures_only=False):

    print(f'Inference on dataset: {dataset_name}')

@@ -56,9 +62,33 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
    outputs = model(images)

    for image, img_name, output in zip(images, img_names, outputs):
+        yolo_labels = []
+        for Cbbox, score in zip(output['boxes'].cpu().numpy(), output['scores'].cpu().numpy()):
+            if score < detection_th:
+                continue
+            rel_x0 = Cbbox[0] / image.shape[-2]
+            rel_y0 = Cbbox[1] / image.shape[-2]
+            rel_x1 = Cbbox[2] / image.shape[-2]
+            rel_y1 = Cbbox[3] / image.shape[-2]
+
+            rel_x_center = rel_x1 - (rel_x1 - rel_x0) / 2
+            rel_y_center = rel_y1 - (rel_y1 - rel_y0) / 2
+            rel_width = rel_x1 - rel_x0
+            rel_height = rel_y1 - rel_y0
+
+            yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height, score])
+
+        if not figures_only:
+            label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt')
+            np.savetxt(label_path, yolo_labels)

        plot_inference(image, img_name, output, detection_th, dataset_name)


def main(args):
    model = create_model(num_classes=NUM_CLASSES)
    checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)

@@ -73,7 +103,13 @@ def main(args):
    if not (Path(INFERENCE_OUTDIR)/dataset_name).exists():
        Path(Path(INFERENCE_OUTDIR)/dataset_name).mkdir(parents=True, exist_ok=True)

-    infere_model(inference_loader, model, dataset_name)
+    infere_model(inference_loader, model, dataset_name, figures_only=args.figures_only)
+
+    if not args.figures_only:
+        if (Path('data').absolute() / dataset_name / 'file_dict.csv').exists():
+            (Path('data').absolute() / dataset_name / 'file_dict.csv').unlink()

@@ -96,6 +132,7 @@ def main(args):
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate electrode array recordings with multiple fish.')
    parser.add_argument('folder', type=str, help='folder with pictures to run inference on', default='')
+    parser.add_argument('-f', '--figures_only', action='store_true', help='only generate figures; keep possibly existing labels')
    args = parser.parse_args()

    main(args)
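One detail worth noting in the label-writing hunk above: all four box coordinates are normalized by image.shape[-2], the height of a (C, H, W) tensor, which is only exact while the spectrogram images are square. A hedged sketch of the per-axis version (hypothetical helper, not from the repository):

import torch

def boxes_to_yolo_rows(boxes, scores, img_hw, detection_th=0.8, cls=1):
    # boxes: (N, 4) tensor of absolute (x0, y0, x1, y1) pixel coordinates
    # img_hw: (height, width) of the image tensor, i.e. image.shape[-2:]
    h, w = img_hw
    rows = []
    for (x0, y0, x1, y1), score in zip(boxes.tolist(), scores.tolist()):
        if score < detection_th:
            continue
        rows.append([cls,
                     (x0 + x1) / 2 / w,   # x_center, normalized by width
                     (y0 + y1) / 2 / h,   # y_center, normalized by height
                     (x1 - x0) / w,       # width
                     (y1 - y0) / h,       # height
                     score])
    return rows

boxes = torch.tensor([[10., 20., 110., 120.]])
scores = torch.tensor([0.93])
print(boxes_to_yolo_rows(boxes, scores, img_hw=(200, 400)))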

train.py (5 lines changed)
@@ -121,11 +121,6 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold):
    # plt.show()

if __name__ == '__main__':
-    # train_data = create_train_or_test_dataset(DATA_DIR)
-    # test_data = create_train_or_test_dataset(DATA_DIR, train=False)
-    #
-    # train_loader = create_train_loader(train_data)
-    # test_loader = create_valid_loader(test_data)

    train_data, test_data = custom_train_test_split()