Compare commits

...

17 Commits

Author SHA1 Message Date
573665d3f1 hint added where rise training dataset is located 2024-01-19 11:24:23 +01:00
8cd422461a detection of rise time based on surpassing the tau freq or on the steepest slope 2023-12-04 12:22:16 +01:00
cc4408d75a inference plots reworked 2023-12-04 11:25:56 +01:00
22b05aec76 corrected labels that are checked are now exported to the train dataset by copying instead of moving them. 2023-12-04 08:52:26 +01:00
59ffd101f1 when labels are generated for an already existing dataset, delete the file_dict.csv file used by correct_bboxes 2023-12-04 08:50:53 +01:00
5c74ac463b assignment of bboxes to tracks complete. Correct assignment of chirp start still necessary... increase to >37% of f1-f0 range suggested 2023-11-30 15:19:33 +01:00
866fd3081d elimination of double detections works 2023-11-29 15:00:02 +01:00
f70e74f5e1 inference.py now also writes bbox scores; extract_from_bbox.py detects overlapping bbox groups and eliminates them based on score until all bboxes have overlap below th... detection of groups currently uses time only; also implement frequency, then sort them as stated previously 2023-11-28 14:53:37 +01:00
93269a96a1 transfer bboxes to time-frequency points 2023-11-21 15:20:44 +01:00
0f0128439e bf 2023-11-16 13:30:26 +01:00
324c690841 model and datasets adapted to new yolo format. reinforcement learning is now implemented. correct_bboxes.py now capable of exporting data to train file 2023-11-16 12:19:57 +01:00
3ed6f1102c done for today. gui is more advanced. work on import/export 2023-11-15 17:32:43 +01:00
3f235f02f3 file dict saved and loaded 2023-11-15 09:22:41 +01:00
33f6bb8683 working on bbox correct gui 2023-11-14 12:22:48 +01:00
f286156607 work in progress ... 2023-11-13 15:10:37 +01:00
397be5c51b bbox corrector 2023-11-08 16:04:04 +01:00
c08ede1e76 working on bbox corrector 2023-11-08 15:31:17 +01:00
8 changed files with 719 additions and 133 deletions

View File

@@ -6,6 +6,8 @@ spectrogram images. The model itself is a pretrained **FasterRCNN** model with a
**ResNet50** backbone. Only the final predictor is replaced so that, instead of the 91 classes
of the COCO dataset the model was trained on, it predicts the (currently) 1 category it should detect.
HINT: Training datasets can be found at raab@polarbear:projects/efishSignalDetector/data/rise_training
## Long-Term and major ToDos:
* implement GUI to correct bounding boxes
* implement reinforcement learning
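The predictor swap described above is presumably the standard torchvision pattern (a create_model is called with NUM_CLASSES later in this diff); a minimal sketch, with num_classes=2 (background + the 1 rise category) as an assumed value:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def create_model(num_classes):
    # FasterRCNN with ResNet50-FPN backbone, pretrained on COCO
    # (torchvision >= 0.13; older versions take pretrained=True instead)
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
    # replace only the box predictor head: 91 COCO classes -> num_classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

model = create_model(num_classes=2)  # assumed: background + 1 rise category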

View File

@@ -26,8 +26,8 @@ DELTA_TIME = 60*10
TIME_OVERLAP = 60*1
# output parameters
DATA_DIR = 'data/images'
LABEL_DIR = 'data/labels'
DATA_DIR = 'data/rise_training/images'
LABEL_DIR = 'data/rise_training/labels'
OUTDIR = 'model_outputs'
INFERENCE_OUTDIR = 'inference_outputs'
for required_folders in [DATA_DIR, OUTDIR, INFERENCE_OUTDIR, LABEL_DIR]:

308 correct_bboxes.py Normal file
View File

@@ -0,0 +1,308 @@
import os
import sys
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
import pyqtgraph as pg
import shutil
from PIL import Image, ImageOps
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as F
from IPython import embed
import pathlib
NONE_CHECKED_CHECKER_COLORS = ['red', 'black']
class ImageWithBbox(pg.GraphicsLayoutWidget):
# class ImageWithBbox(pg.ImageView):
def __init__(self, img_path, parent=None):
super(ImageWithBbox, self).__init__(parent)
self.img_path = img_path
self.label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt')
# image setup
self.imgItem = pg.ImageItem(ColorMap='viridis')
self.plot_widget = self.addPlot(title="")
self.plot_widget.addItem(self.imgItem, colorMap='viridis')
self.plot_widget.scene().sigMouseClicked.connect(self.addROI)
# self.imgItem.mouseClickEvent()
self.imgItem.getViewBox().invertY(True)
self.imgItem.getViewBox().setMouseEnabled(x=False, y=False)
# get image and labels
img = Image.open(img_path)
img_gray = ImageOps.grayscale(img)
self.img_array = np.array(img_gray).T
self.imgItem.setImage(np.array(self.img_array))
self.plot_widget.setYRange(0, self.img_array.shape[0], padding=0)
self.plot_widget.setXRange(0, self.img_array.shape[1], padding=0)
# label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt')
self.labels = np.loadtxt(self.label_path, delimiter=' ')
if len(self.labels) > 0 and len(self.labels.shape) == 1:
self.labels = np.expand_dims(self.labels, 0)
# add ROIS
self.ROIs = []
for enu, l in enumerate(self.labels):
# x_center, y_center, width, height = l[1:] * self.img_array.shape[1]
x_center = l[1] * self.img_array.shape[0]
y_center = l[2] * self.img_array.shape[1]
width = l[3] * self.img_array.shape[0]
height = l[4] * self.img_array.shape[1]
x0, y0, = x_center-width/2, y_center-height/2
ROI = pg.RectROI((x0, y0), size=(width, height), removable=True, sideScalers=True)
ROI.sigRemoveRequested.connect(self.removeROI)
self.ROIs.append(ROI)
self.plot_widget.addItem(ROI)
def removeROI(self, roi):
if roi in self.ROIs:
self.ROIs.remove(roi)
self.plot_widget.removeItem(roi)
def addROI(self, event):
# Check if the event is a double-click event
if event.double():
pos = event.pos()
# Transform the mouse position to data coordinates
pos_data = self.plot_widget.getViewBox().mapToView(pos)
x, y = pos_data.x(), pos_data.y()
# Create a new ROI at the double-clicked position
new_ROI = pg.RectROI(pos=(x, y), size=(self.img_array.shape[0]*0.05, self.img_array.shape[1]*0.05), removable=True)
new_ROI.sigRemoveRequested.connect(self.removeROI)
self.ROIs.append(new_ROI)
self.plot_widget.addItem(new_ROI)
class Bbox_correct_UI(QMainWindow):
def __init__(self, data_path, parent=None):
super(Bbox_correct_UI, self).__init__(parent)
self.data_path = data_path
self.files = sorted(list(pathlib.Path(self.data_path).absolute().rglob('*images/*.png')))
self.close_without_saving_rois = False
rec = QApplication.desktop().screenGeometry()
height = rec.height()
width = rec.width()
self.resize(int(.8 * width), int(.8 * height))
# widget and layout
self.central_widget = QWidget(self)
self.central_gridLayout = QGridLayout()
self.central_Layout = QHBoxLayout()
self.central_widget.setLayout(self.central_Layout)
self.setCentralWidget(self.central_widget)
###########
self.highlighted_label = None
self.all_labels = []
self.load_or_create_file_dict()
new_name, new_file, new_checked = self.file_dict.iloc[0].values
self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}')
# image widget
self.current_img = ImageWithBbox(new_file, parent=self)
self.central_Layout.addWidget(self.current_img, 4)
# image select widget
self.scroll = QScrollArea() # Scroll Area which contains the widgets, set as the centralWidget
self.widget = QWidget() # Widget that contains the collection of Vertical Box
self.vbox = QVBoxLayout()
for i in range(len(self.file_dict)):
label = QLabel(f'{self.file_dict["name"][i]}')
# label.setFrameShape(QLabel.Panel)
# label.setFrameShadow(QLabel.Sunken)
label.setAlignment(Qt.AlignRight)
label.mousePressEvent = lambda event, label=label: self.label_clicked(label)
# label.mousePressEvent = lambda event, label: self.label_clicked(label)
if i == 0:
label.setStyleSheet("border: 2px solid black; "
"color : %s;" % (NONE_CHECKED_CHECKER_COLORS[self.file_dict['checked'][i]]))
self.highlighted_label = label
else:
label.setStyleSheet("border: 1px solid gray; "
"color : %s;" % (NONE_CHECKED_CHECKER_COLORS[self.file_dict['checked'][i]]))
self.vbox.addWidget(label)
self.all_labels.append(label)
self.widget.setLayout(self.vbox)
self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
self.scroll.setWidgetResizable(True)
self.scroll.setWidget(self.widget)
self.central_Layout.addWidget(self.scroll, 1)
self.add_actions()
self.init_MenuBar()
def load_or_create_file_dict(self):
csvs_im_data_path = list(pathlib.Path(self.data_path).absolute().rglob('*file_dict.csv'))
if len(csvs_im_data_path) == 0:
self.file_dict = pd.DataFrame(
{'name': [f.name for f in self.files],
'files': self.files,
'checked': np.zeros(len(self.files), dtype=int)} # change this to locked
)
else:
self.file_dict = pd.read_csv(csvs_im_data_path[0], sep=',')
def save_file_dict(self):
self.file_dict.to_csv(pathlib.Path(self.data_path)/'file_dict.csv', sep=',', index=False)
def label_clicked(self, clicked_label):
if self.highlighted_label:
hl_mask = self.file_dict['name'] == self.highlighted_label.text()
hl_label_name, _, hl_checked = self.file_dict[hl_mask].values[0]
self.highlighted_label.setStyleSheet("border: 1px solid gray; "
"color: %s;" % (NONE_CHECKED_CHECKER_COLORS[hl_checked]))
mask = self.file_dict['name'] == clicked_label.text()
new_name, new_file, new_checked = self.file_dict[mask].values[0]
clicked_label.setStyleSheet("border: 2px solid black;"
"color: %s;" % (NONE_CHECKED_CHECKER_COLORS[new_checked]))
self.highlighted_label = clicked_label
self.switch_to_new_file(new_file, new_name)
def lock_file(self):
hl_mask = self.file_dict['name'] == self.highlighted_label.text()
hl_label_name, _, hl_checked = self.file_dict[hl_mask].values[0]
# ToDo: do everything with the index instead of mask
df_idx = self.file_dict.loc[self.file_dict['name'] == self.highlighted_label.text()].index.values[0]
self.file_dict.at[df_idx, 'checked'] = 1
self.highlighted_label.setStyleSheet("border: 1px solid gray; "
"color: 'black';")
new_idx = df_idx + 1 if df_idx < len(self.file_dict)-1 else 0
new_name, new_file, new_checked = self.file_dict.iloc[new_idx].values
self.all_labels[new_idx].setStyleSheet("border: 2px solid black;"
"color: %s;" % (NONE_CHECKED_CHECKER_COLORS[new_checked]))
self.highlighted_label = self.all_labels[new_idx]
self.switch_to_new_file(new_file, new_name)
def switch_to_new_file(self, new_file, new_name):
self.readout_rois()
self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}')
self.central_Layout.removeWidget(self.current_img)
self.current_img = ImageWithBbox(new_file, parent=self)
self.central_Layout.insertWidget(0, self.current_img, 4)
def readout_rois(self):
if self.close_without_saving_rois:
return
new_labels = []
for roi in self.current_img.ROIs:
x0, y0 = roi.pos()
x0 /= self.current_img.img_array.shape[0]
y0 /= self.current_img.img_array.shape[1]
w, h = roi.size()
w /= self.current_img.img_array.shape[0]
h /= self.current_img.img_array.shape[1]
x_center = x0 + w/2
y_center = y0 + h/2
new_labels.append([1, x_center, y_center, w, h])
new_labels = np.array(new_labels)
np.savetxt(self.current_img.label_path, new_labels)
def export_validated_data(self):
fd = QFileDialog()
export_path = fd.getExistingDirectory(self, 'Select Directory')
if export_path:
export_idxs = list(self.file_dict['files'][self.file_dict['checked'] == 1].index)
keep_idxs = []
for export_file_path, export_idx in zip(list(self.file_dict['files'][self.file_dict['checked'] == 1]), export_idxs):
export_image_path = pathlib.Path(export_file_path)
export_label_path = export_image_path.parent.parent / 'labels' / pathlib.Path(export_image_path.name).with_suffix('.txt')
target_image_path = pathlib.Path(export_path) / 'images' / export_image_path.name
target_label_path = pathlib.Path(export_path) / 'labels' / export_label_path.name
if not target_image_path.exists():
# ToDo: this is not tested but should work
# os.rename(export_image_path, target_image_path)
shutil.copy(export_image_path, target_image_path)
# os.rename(export_label_path, target_label_path)
shutil.copy(export_label_path, target_label_path)
else:
keep_idxs.append(export_idx)
# self.file_dict.loc[self.file_dict['name'] == self.highlighted_label.text()].index.values[0]
drop_idxs = list(set(export_idxs) - set(keep_idxs))
# self.file_dict = self.file_dict.drop(drop_idxs)
self.save_file_dict()
self.close_without_saving_rois = True
self.close()
else:
print('nope')
pass
def open(self):
pass
def init_MenuBar(self):
menubar = self.menuBar() # needs QMainWindow ?!
file = menubar.addMenu('&File') # create file menu ... accessible with Alt+F
file.addActions([self.Act_open, self.Act_export, self.Act_exit])
edit = menubar.addMenu('&Help')
# edit.addActions([self.Act_undo, self.Act_unassigned_funds])
def add_actions(self):
self.lock_file_act = QAction('lock', self)
self.lock_file_act.triggered.connect(self.lock_file)
self.lock_file_act.setShortcut(Qt.Key_Space)
self.addAction(self.lock_file_act)
self.Act_open = QAction('&Open', self)
# self.Act_open.setStatusTip('Open file')
self.Act_open.triggered.connect(self.open) # ToDo: implement this fn
self.Act_export = QAction('&Export', self)
# self.Act_export.setStatusTip('Open file')
self.Act_export.triggered.connect(self.export_validated_data)
self.Act_exit = QAction('&Exit', self) # trigger with alt+E
self.Act_exit.setShortcut(Qt.Key_Q)
self.Act_exit.triggered.connect(self.close)
def closeEvent(self, *args, **kwargs):
super(QMainWindow, self).closeEvent(*args, **kwargs)
self.readout_rois()
self.save_file_dict()
def main_UI():
app = QApplication(sys.argv) # create application
data_path = sys.argv[1]
w = Bbox_correct_UI(data_path) # create window
w.show()
sys.exit(app.exec_())
if __name__ == '__main__':
main_UI()
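For reference, main_UI reads the dataset root from the first command-line argument, so the GUI is presumably launched as python correct_bboxes.py data/rise_training (path as configured above). Space locks the current file and advances to the next one, Q quits, and File -> Export copies all checked image/label pairs to a chosen directory.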

View File

@@ -26,6 +26,8 @@ class InferenceDataset(Dataset):
def __init__(self, dir_path):
self.dir_path = dir_path
self.all_images = sorted(list(Path(self.dir_path).rglob(f'*.png')))
self.file_names = np.array([Path(x).with_suffix('') for x in self.all_images])
# self.images = np.array(sorted(os.listdir(DATA_DIR)))
def __len__(self):
return len(self.all_images)
@@ -39,45 +41,6 @@ class InferenceDataset(Dataset):
class CustomDataset(Dataset):
def __init__(self, dir_path, bbox_df):
self.dir_path = dir_path
self.bbox_df = bbox_df
self.all_images = np.array(sorted(pd.unique(self.bbox_df['image'])), dtype=str)
self.image_paths = list(map(lambda x: Path(self.dir_path)/x, self.all_images))
def __getitem__(self, idx):
image_name = self.all_images[idx]
image_path = os.path.join(self.dir_path, image_name)
img = Image.open(image_path)
img_tensor = F.to_tensor(img.convert('RGB'))
Cbbox = self.bbox_df[self.bbox_df['image'] == image_name]
labels = np.ones(len(Cbbox), dtype=int)
boxes = torch.as_tensor(Cbbox.loc[:, ['x0', 'y0', 'x1', 'y1']].values, dtype=torch.float32)
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# no crowd instances
iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
target["area"] = area
target["iscrowd"] = iscrowd
image_id = torch.tensor([idx])
target["image_id"] = image_id
target["image_name"] = image_name #ToDo: implement this as 3rd return...
return img_tensor, target
def __len__(self):
return len(self.all_images)
class CustomDataset_v2(Dataset):
def __init__(self, limited_idxs=None):
self.images = np.array(sorted(os.listdir(DATA_DIR)))
self.labels = np.array(sorted(os.listdir(LABEL_DIR)))
@@ -99,7 +62,6 @@ class CustomDataset_v2(Dataset):
boxes, labels, area, iscrowd = self.extract_bboxes(annotations)
target = {}
target["boxes"] = boxes
target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
@@ -147,27 +109,11 @@ def custom_train_test_split():
train_idxs = np.sort(data_idxs[int(0.2 * len(data_idxs)):])
test_idxs = np.sort(data_idxs[:int(0.2 * len(data_idxs))])
train_data = CustomDataset_v2(limited_idxs=train_idxs)
test_data = CustomDataset_v2(limited_idxs=test_idxs)
train_data = CustomDataset(limited_idxs=train_idxs)
test_data = CustomDataset(limited_idxs=test_idxs)
return train_data, test_data
def create_train_or_test_dataset(path, train=True):
if train == True:
pfx='train'
print('Generate train dataset !')
else:
print('Generate test dataset !')
pfx='test'
csv_candidates = list(Path(path).rglob(f'*{pfx}*.csv'))
if len(csv_candidates) == 0:
print(f'no .csv files for *{pfx}* found in {Path(path)}')
quit()
else:
bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0)
return CustomDataset(path, bboxes)
def create_train_loader(train_dataset, num_workers=0):
train_loader = DataLoader(
@@ -191,17 +137,6 @@ def create_valid_loader(valid_dataset, num_workers=0):
)
return valid_loader
def create_inference_loader(inference_dataset, num_workers=0):
inference_loader = DataLoader(
inference_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn
)
return inference_loader
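collate_fn is referenced above but defined outside this hunk; in torchvision detection pipelines it is conventionally the tuple-zip collate, which is presumably what the repo uses:

def collate_fn(batch):
    # keep images and targets as per-sample tuples instead of stacking,
    # since detection images and box counts can differ within a batch
    return tuple(zip(*batch))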
if __name__ == '__main__':
train_data, test_data = custom_train_test_split()

347 extract_from_bbox.py Normal file
View File

@@ -0,0 +1,347 @@
import itertools
import pathlib
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image
from tqdm.auto import tqdm
from IPython import embed
def extract_time_freq_range_from_filename(img_path):
file_name_str, time_span_str, freq_span_str = str(img_path.with_suffix('').name).split('__')
time_span_str = time_span_str.replace('s', '')
freq_span_str = freq_span_str.replace('Hz', '')
t0, t1 = np.array(time_span_str.split('-'), dtype=float)
f0, f1 = np.array(freq_span_str.split('-'), dtype=float)
return file_name_str, t0, t1, f0, f1
def bbox_to_data(img_path, t_min, t_max, f_min, f_max):
label_path = img_path.parent.parent / 'labels' / img_path.with_suffix('.txt').name
annotations = np.loadtxt(label_path, delimiter=' ')
if len(annotations.shape) == 1:
annotations = np.array([annotations])
if annotations.shape[1] == 0:
print('no rises detected in this window')
return [], []
boxes = np.array([[x[1] - x[3] / 2, 1 - (x[2] + x[4] / 2), x[1] + x[3] / 2, 1 - (x[2] - x[4] / 2)] for x in annotations]) # x0, y0, x1, y1
boxes[:, 0] = boxes[:, 0] * (t_max - t_min) + t_min
boxes[:, 2] = boxes[:, 2] * (t_max - t_min) + t_min
boxes[:, 1] = boxes[:, 1] * (f_max - f_min) + f_min
boxes[:, 3] = boxes[:, 3] * (f_max - f_min) + f_min
scores = annotations[:, 5]
return boxes, scores
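For illustration, a worked example of this mapping with hypothetical numbers, for a window spanning 0-600 s and 400-1200 Hz:

import numpy as np

# one YOLO label row: class, x_center, y_center, width, height, score
row = np.array([1, 0.50, 0.25, 0.10, 0.10, 0.93])
t_min, t_max, f_min, f_max = 0.0, 600.0, 400.0, 1200.0

x0, x1 = row[1] - row[3] / 2, row[1] + row[3] / 2              # 0.45, 0.55
y0, y1 = 1 - (row[2] + row[4] / 2), 1 - (row[2] - row[4] / 2)  # 0.70, 0.80 (image y-axis flipped)

t0, t1 = x0 * (t_max - t_min) + t_min, x1 * (t_max - t_min) + t_min  # 270.0 s, 330.0 s
f0, f1 = y0 * (f_max - f_min) + f_min, y1 * (f_max - f_min) + f_min  # 960.0 Hz, 1040.0 Hz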
def load_wavetracker_data(raw_path):
fund_v = np.load(raw_path.parent / 'fund_v.npy')
ident_v = np.load(raw_path.parent / 'ident_v.npy')
idx_v = np.load(raw_path.parent / 'idx_v.npy')
times = np.load(raw_path.parent / 'times.npy')
return fund_v, ident_v, idx_v, times
def assign_rises_to_ids(raw_path, time_frequency_bboxes, bbox_groups):
def identify_most_likely_rise_id(possible_ids, t0, t1, f0, f1, fund_v, ident_v, times, idx_v):
mean_id_box_f_rel_to_bbox = []
for id in possible_ids:
id_box_f = fund_v[(ident_v == id) & (times[idx_v] >= t0) & (times[idx_v] <= t1)]
id_box_f_rel_to_bbox = (id_box_f - f0) / (f1 - f0)
mean_id_box_f_rel_to_bbox.append(np.mean(id_box_f_rel_to_bbox))
# print(id, np.mean(id_box_f), f0, f1, np.mean(id_box_f_rel_to_bbox))
most_likely_id = possible_ids[np.argsort(mean_id_box_f_rel_to_bbox)[0]]
return most_likely_id
fund_v, ident_v, idx_v, times = load_wavetracker_data(raw_path)
fig, ax = plt.subplots()
ax.plot(times[idx_v[~np.isnan(ident_v)]], fund_v[~np.isnan(ident_v)], '.')
mask = time_frequency_bboxes['file_name'] == raw_path.parent.name
for index, bbox in time_frequency_bboxes[mask].iterrows():
name, t0, f0, t1, f1, score = (bbox[0], *bbox[1:-2].astype(float))
if bbox_groups[index] == 0:
color = 'tab:green'
elif bbox_groups[index] > 0:
color = 'tab:red'
else:
color = 'k'
ax.add_patch(
Rectangle((t0, f0),
(t1 - t0),
(f1 - f0),
fill=False, color=color, linestyle='--', linewidth=2, zorder=10)
)
ax.text(t1, f1, f'{score:.1%}', ha='right', va='bottom')
possible_ids = np.unique(
ident_v[~np.isnan(ident_v) &
(t0 <= times[idx_v]) &
(t1 >= times[idx_v]) &
(f0 <= fund_v) &
(f1 >= fund_v)]
)
if len(possible_ids) == 1:
assigned_id = possible_ids[0]
time_frequency_bboxes.at[index, 'id'] = assigned_id
elif len(possible_ids) > 1:
assigned_id = identify_most_likely_rise_id(possible_ids, t0, t1, f0, f1, fund_v, ident_v, times, idx_v)
time_frequency_bboxes.at[index, 'id'] = assigned_id
# rise_id[index] = identify_most_likely_rise_id(possible_ids, t0, t1, f0, f1, fund_v, ident_v, times, idx_v)
else:
continue
rise_start_freq_th = f0 + (f1 - f0) * 0.37
wavetracker_mask = np.arange(len(fund_v))[
(times[idx_v] >= t0) &
(times[idx_v] <= t1) &
(ident_v == assigned_id)
]
if np.sum(fund_v[wavetracker_mask] > rise_start_freq_th) > 0:
# rise start time = moment where rise freq exceeds 37% of bbox freq range ...
rise_start_idx = wavetracker_mask[fund_v[wavetracker_mask] > rise_start_freq_th][0]
rise_time = times[idx_v[rise_start_idx]]
else:
### if this is never the case use the largest slope
rise_start_idx = wavetracker_mask[np.argmax(np.diff(fund_v[wavetracker_mask]))]
rise_time = times[idx_v[rise_start_idx]]
time_frequency_bboxes.at[index, 'event_time'] = rise_time
ax.plot(rise_time, fund_v[rise_start_idx], 'ok')
# embed()
# quit()
# time_frequency_bboxes['id'] = rise_id
# embed()
# plt.show()
plt.close()
return time_frequency_bboxes
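The onset rule above, first crossing of 37% of the bbox frequency range (roughly 1/e, presumably the "tau freq" of the commit messages) with the steepest slope as fallback, in isolation; a minimal sketch on synthetic data, all values illustrative:

import numpy as np

t = np.arange(0, 10, 0.1)
f = 700 + 60 / (1 + np.exp(-(t - 5)))   # sigmoid-shaped rise starting at ~700 Hz
f0, f1 = 700.0, 780.0                   # bbox frequency range

th = f0 + (f1 - f0) * 0.37              # rise start threshold
above = np.where(f > th)[0]
if len(above) > 0:
    onset_idx = above[0]                # first threshold crossing
else:
    onset_idx = np.argmax(np.diff(f))   # fallback: steepest slope
print(t[onset_idx], f[onset_idx])       # -> ~5.0 s, ~730 Hz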
def find_overlapping_bboxes(df_collect):
file_names = np.array(df_collect)[:, 0]
bboxes = np.array(df_collect)[:, 1:].astype(float)
overlap_bbox_idxs = []
for file_name in tqdm(np.unique(file_names)):
file_bbox_idxs = np.arange(len(file_names))[file_names == file_name]
for ind0, ind1 in itertools.combinations(file_bbox_idxs, r=2):
bb0 = bboxes[ind0]
bb1 = bboxes[ind1]
t0_0, f0_0, t0_1, f0_1 = bb0[:-1]
t1_0, f1_0, t1_1, f1_1 = bb1[:-1]
bb_times = np.array([t0_0, t0_1, t1_0, t1_1])
bb_time_associate = np.array([0, 0, 1, 1])
time_helper = bb_time_associate[np.argsort(bb_times)]
if time_helper[0] == time_helper[1]:
# no temporal overlap
continue
# check freq overlap
bb_freqs = np.array([f0_0, f0_1, f1_0, f1_1])
bb_freq_associate = np.array([0, 0, 1, 1])
freq_helper = bb_freq_associate[np.argsort(bb_freqs)]
if freq_helper[0] == freq_helper[1]:
continue
overlap_bbox_idxs.append((ind0, ind1))
return np.asarray(overlap_bbox_idxs)
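The temporal and spectral checks above use an edge-sorting trick rather than explicit interval arithmetic; a small self-contained illustration with hypothetical values:

import numpy as np

# label each interval edge with its box and sort edges by value; if the two
# smallest edges belong to the same box, that box ends before the other begins
owner = np.array([0, 0, 1, 1])

edges = np.array([10.0, 20.0, 15.0, 30.0])  # box 0: 10-20 s, box 1: 15-30 s
print(owner[np.argsort(edges)])             # [0 1 0 1] -> overlap

edges = np.array([10.0, 20.0, 25.0, 30.0])  # box 0: 10-20 s, box 1: 25-30 s
print(owner[np.argsort(edges)])             # [0 0 1 1] -> no overlap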
def main(args):
img_paths = sorted(list(pathlib.Path(args.annotations).absolute().rglob('*.png')))
df_collect = []
for img_path in img_paths:
# convert to time_frequency
file_name_str, t_min, t_max, f_min, f_max = extract_time_freq_range_from_filename(img_path)
boxes, scores = bbox_to_data(img_path, t_min, t_max, f_min, f_max ) # t0, t1, f0, f1
# store values in df
if not len(boxes) == 0:
for (t0, f0, t1, f1), s in zip(boxes, scores):
df_collect.append([file_name_str, t0, f0, t1, f1, s])
df_collect = np.array(df_collect)
overlap_bbox_idxs = find_overlapping_bboxes(df_collect)
bbox_groups = delete_double_boxes(overlap_bbox_idxs, df_collect)
time_frequency_bboxes = pd.DataFrame(data= np.array(df_collect), columns=['file_name', 't0', 'f0', 't1', 'f1', 'score'])
time_frequency_bboxes['id'] = np.full(len(time_frequency_bboxes), np.nan)
time_frequency_bboxes['event_time'] = np.full(len(time_frequency_bboxes), np.nan)
###########################################
# for file_name in time_frequency_bboxes['file_name'].unique():
# fig, ax = plt.subplots()
#
# mask = time_frequency_bboxes['file_name'] == file_name
# for index, bbox in time_frequency_bboxes[mask].iterrows():
# name, t0, f0, t1, f1 = (bbox[0], *bbox[1:-1].astype(float))
# if bbox_groups[index] == 0:
# color = 'tab:green'
# elif bbox_groups[index] > 0:
# color = 'tab:red'
# else:
# color = 'k'
#
# ax.add_patch(
# Rectangle((t0, f0),
# (t1 - t0),
# (f1 - f0),
# fill=False, color=color, linestyle='--', linewidth=2, zorder=10)
# )
# # ax.set_xlim(float(time_frequency_bboxes[mask]['t0'].min()), float(time_frequency_bboxes[mask]['t1'].max()))
# ax.set_xlim(0, float(time_frequency_bboxes[mask]['t1'].max()))
# # ax.set_ylim(float(time_frequency_bboxes[mask]['f0'].min()), float(time_frequency_bboxes[mask]['f1'].max()))
# ax.set_ylim(400, 1200)
# plt.show()
# exit()
###########################################
if args.tracking_data_path:
file_paths = sorted(list(pathlib.Path(args.tracking_data_path).absolute().rglob('*.raw')))
for raw_path in file_paths:
if not raw_path.parent.name in time_frequency_bboxes['file_name'].to_list():
continue
time_frequency_bboxes = assign_rises_to_ids(raw_path, time_frequency_bboxes, bbox_groups)
for raw_path in file_paths:
# mask = (time_frequency_bboxes['file_name'] == raw_path.parent.name)
mask = ((time_frequency_bboxes['file_name'] == raw_path.parent.name) & (~np.isnan(time_frequency_bboxes['id'])))
save_df = pd.DataFrame(time_frequency_bboxes[mask][['t0', 't1', 'f0', 'f1', 'score', 'id', 'event_time']].values, columns=['t0', 't1', 'f0', 'f1', 'score', 'id', 'event_time'])
save_df['label'] = np.ones(len(save_df), dtype=int)
save_df.to_csv(raw_path.parent / 'risedetector_bboxes.csv', sep = ',', index = False)
quit()
def delete_double_boxes(overlap_bbox_idxs, df_collect, overlap_th = 0.2):
def get_connected(non_regarded_bbox_idx, overlap_bbox_idxs):
mask = np.array((np.array(overlap_bbox_idxs) == non_regarded_bbox_idx).sum(1), dtype=bool)
affected_bbox_idxs = np.unique(overlap_bbox_idxs[mask])
return affected_bbox_idxs
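# Outline of the loop below: (1) grow the connected component of mutually
# overlapping bboxes around each unhandled index, (2) enumerate removal
# combinations ordered by summed score (cheapest first), (3) keep the first
# combination whose remaining boxes all overlap by less than overlap_th in
# time or frequency; removed boxes are marked with a negated group id.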
handled_bbox_idxs = []
bbox_groups = np.zeros(len(df_collect))
# delete_bbox_idxs = []
for Coverlapping_bbox_idx in tqdm(np.unique(overlap_bbox_idxs)):
if Coverlapping_bbox_idx in handled_bbox_idxs:
continue
regarded_bbox_idxs = [Coverlapping_bbox_idx]
mask = np.array((np.array(overlap_bbox_idxs) == Coverlapping_bbox_idx).sum(1), dtype=bool)
affected_bbox_idxs = np.unique(overlap_bbox_idxs[mask])
non_regarded_bbox_idxs = list(set(affected_bbox_idxs) - set(regarded_bbox_idxs))
# non_regarded_bbox_idxs = list(set(non_regarded_bbox_idxs) - set(handled_bbox_idxs))
while len(non_regarded_bbox_idxs) > 0:
non_regarded_bbox_idxs_cp = np.copy(non_regarded_bbox_idxs)
for non_regarded_bbox_idx in non_regarded_bbox_idxs_cp:
Caffected_bbox_idxs = get_connected(non_regarded_bbox_idx, overlap_bbox_idxs)
affected_bbox_idxs = np.unique(np.append(affected_bbox_idxs, Caffected_bbox_idxs))
regarded_bbox_idxs.append(non_regarded_bbox_idx)
non_regarded_bbox_idxs = list(set(affected_bbox_idxs) - set(regarded_bbox_idxs))
bbox_idx_group = np.array(regarded_bbox_idxs)
bbox_scores = df_collect[bbox_idx_group][:, -1].astype(float)
bbox_groups[bbox_idx_group] = np.max(bbox_groups) + 1
remove_idx_combinations = [()]
remove_idx_combinations_scores = [0]
for r in range(1, len(bbox_idx_group)):
remove_idx_combinations.extend(list(itertools.combinations(bbox_idx_group, r=r)))
remove_idx_combinations_scores.extend(list(itertools.combinations(bbox_scores, r=r)))
for enu, combi_score in enumerate(remove_idx_combinations_scores):
remove_idx_combinations_scores[enu] = np.sum(combi_score)
if len(bbox_idx_group) > 1:
remove_idx_combinations = [remove_idx_combinations[ind] for ind in np.argsort(remove_idx_combinations_scores)]
remove_idx_combinations_scores = [remove_idx_combinations_scores[ind] for ind in np.argsort(remove_idx_combinations_scores)]
for remove_idx in remove_idx_combinations:
select_bbox_idx_group = list(set(bbox_idx_group) - set(remove_idx))
time_overlap_pct, freq_overlap_pct = (
compute_time_frequency_overlap_for_bbox_group(select_bbox_idx_group,df_collect))
if np.all(np.min([time_overlap_pct, freq_overlap_pct], axis=0) < overlap_th):
break
if len(remove_idx) > 0:
bbox_groups[np.array(remove_idx)] *= -1
handled_bbox_idxs.extend(bbox_idx_group)
return bbox_groups
def compute_time_frequency_overlap_for_bbox_group(bbox_idx_group, df_collect):
time_overlap_pct = np.zeros((len(bbox_idx_group), len(bbox_idx_group)))
freq_overlap_pct = np.zeros((len(bbox_idx_group), len(bbox_idx_group)))
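# entry [i, j] of these matrices is the fraction of box i's span covered by
# box j, derived case by case from the pattern of sorted interval edges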
for i, j in itertools.product(range(len(bbox_idx_group)), repeat=2):
if i == j:
continue
bb0_idx = bbox_idx_group[i]
bb1_idx = bbox_idx_group[j]
bb0_t0, bb0_t1 = df_collect[bb0_idx][1].astype(float), df_collect[bb0_idx][3].astype(float)
bb1_t0, bb1_t1 = df_collect[bb1_idx][1].astype(float), df_collect[bb1_idx][3].astype(float)
bb0_f0, bb0_f1 = df_collect[bb0_idx][2].astype(float), df_collect[bb0_idx][4].astype(float)
bb1_f0, bb1_f1 = df_collect[bb1_idx][2].astype(float), df_collect[bb1_idx][4].astype(float)
bb_times_idx = np.array([0, 0, 1, 1])
bb_times = np.array([bb0_t0, bb0_t1, bb1_t0, bb1_t1])
sorted_bb_times_idx = bb_times_idx[bb_times.argsort()]
if sorted_bb_times_idx[0] == sorted_bb_times_idx[1]:
time_overlap_pct[i, j] = 0
elif sorted_bb_times_idx[1] == sorted_bb_times_idx[2] == 0:
time_overlap_pct[i, j] = 1
elif sorted_bb_times_idx[1] == sorted_bb_times_idx[2] == 1:
time_overlap_pct[i, j] = (bb1_t1 - bb1_t0) / (bb0_t1 - bb0_t0)
else:
time_overlap_pct[i, j] = np.diff(sorted(bb_times)[1:3])[0] / ((bb0_t1 - bb0_t0))
bb_freqs_idx = np.array([0, 0, 1, 1])
bb_freqs = np.array([bb0_f0, bb0_f1, bb1_f0, bb1_f1])
sorted_bb_freqs_idx = bb_freqs_idx[bb_freqs.argsort()]
if sorted_bb_freqs_idx[0] == sorted_bb_freqs_idx[1]:
freq_overlap_pct[i, j] = 0
elif sorted_bb_freqs_idx[1] == sorted_bb_freqs_idx[2] == 0:
freq_overlap_pct[i, j] = 1
elif sorted_bb_freqs_idx[1] == sorted_bb_freqs_idx[2] == 1:
freq_overlap_pct[i, j] = (bb1_f1 - bb1_f0) / (bb0_f1 - bb0_f0)
else:
freq_overlap_pct[i, j] = np.diff(sorted(bb_freqs)[1:3])[0] / ((bb0_f1 - bb0_f0))
return time_overlap_pct, freq_overlap_pct
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Extract time, frequency and identity association of bboxes')
parser.add_argument('annotations', nargs='?', type=str, help='path to annotations')
parser.add_argument('-t', '--tracking_data_path', type=str, help='path to tracking data')
args = parser.parse_args()
main(args)

View File

@@ -97,7 +97,7 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1,
return fig_title
def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,t0, t1, f0, f1):
def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,t0, t1, f0, f1, label_save_folder):
times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
@@ -185,13 +185,16 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
all_height
]).T
np.savetxt(LABEL_DIR/ Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
np.savetxt(label_save_folder / Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
return bbox_yolo_style
def main(args):
folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
pic_save_folder = DATA_DIR if not args.inference else (Path('data') / Path(args.folder).name)
pic_save_folder = Path('data') / Path(args.folder).name / 'images'
label_save_folder = Path('data') / Path(args.folder).name / 'labels'
# embed()
# quit()
if len(folders) == 0:
print('no datasets containing fill_times.npy found')
@@ -202,8 +205,10 @@ def main(args):
else:
print('generate inference dataset ... only image output')
if not (Path('data') / Path(args.folder).name).exists():
(Path('data') / Path(args.folder).name).mkdir(parents=True, exist_ok=True)
if not pic_save_folder.exists():
pic_save_folder.mkdir(parents=True, exist_ok=True)
if not label_save_folder.exists():
label_save_folder.mkdir(parents=True, exist_ok=True)
for enu, folder in enumerate(folders):
print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
@@ -253,52 +258,7 @@ def main(args):
if not args.inference:
bbox_yolo_style = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
fish_baseline_freq_time, fish_baseline_freq,
pic_save_str,t0, t1, f0, f1)
#######################################################################
# if False:
# if bbox_yolo_style.shape[0] >= 1:
# f_res, t_res = freq[1] - freq[0], times[1] - times[0]
#
# fig_title = (
# f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(
# ' ', '0')
# fig = plt.figure(figsize=IMG_SIZE, num=fig_title)
# gs = gridspec.GridSpec(1, 1, bottom=0.1, left=0.1, right=0.95, top=0.95) #
# ax = fig.add_subplot(gs[0, 0])
# ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
# extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res))
# # ax.invert_yaxis()
# # ax.axis(False)
#
# for i in range(len(bbox_df)):
# # Cbbox = np.array(bbox_df.loc[i, ['x0', 'y0', 'x1', 'y1']].values, dtype=np.float32)
# Cbbox = bbox_df.loc[i, ['t0', 'f0', 't1', 'f1']]
# ax.add_patch(
# Rectangle((float(Cbbox['t0']) / 3600, float(Cbbox['f0'])),
# float(Cbbox['t1']) / 3600 - float(Cbbox['t0']) / 3600,
# float(Cbbox['f1']) - float(Cbbox['f0']),
# fill=False, color="white", linestyle='-', linewidth=2, zorder=10)
# )
#
# # print(bbox_yolo_style.T)
#
# for bbox in bbox_yolo_style:
# x0 = bbox[1] - bbox[3]/2 # x_center - width/2
# y0 = 1 - (bbox[2] + bbox[4]/2) # x_center - width/2
# w = bbox[3]
# h = bbox[4]
# ax.add_patch(
# Rectangle((x0, y0), w, h,
# fill=False, color="k", linestyle='--', linewidth=2, zorder=10,
# transform=ax.transAxes)
# )
# plt.show()
#######################################################################
# if not args.inference:
# print('save bboxes')
# bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',')
pic_save_str,t0, t1, f0, f1, label_save_folder)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluate electrode array recordings with multiple fish.')
@@ -306,4 +266,6 @@ if __name__ == '__main__':
parser.add_argument('-i', "--inference", action="store_true")
args = parser.parse_args()
# ToDo: put "images" in image folder
main(args)

View File

@@ -19,22 +19,28 @@ from matplotlib.patches import Rectangle
def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name):
# embed()
# quit()
fig = plt.figure(figsize=IMG_SIZE, num=img_name)
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
ax = fig.add_subplot(gs[0, 0])
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
# ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto', cmap='afmhot')
ax.imshow(img_tensor.cpu().squeeze()[0], aspect='auto', cmap='afmhot', vmin=.2)
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
# embed()
# quit()
if score < detection_threshold:
continue
# print(x0, y0, x1, y1, l)
ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white')
# print(score)
ax.text(x0 + (x1 - x0) / 2, y0, f'{score:.2f}', ha='center', va='bottom', fontsize=12, color='tab:gray', rotation=90)
ax.add_patch(
Rectangle((x0, y0),
(x1 - x0),
(y1 - y0),
fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10)
fill=False, color="tab:gray", linestyle='-', linewidth=1, zorder=10, alpha=0.8)
)
ax.set_axis_off()
@@ -42,7 +48,7 @@ def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_na
plt.close()
# plt.show()
def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
def infere_model(inference_loader, model, dataset_name, detection_th=0.8, figures_only=False):
print(f'Inference on dataset: {dataset_name}')
@@ -56,9 +62,33 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
outputs = model(images)
for image, img_name, output in zip(images, img_names, outputs):
# x0, y0, x1, y1
yolo_labels = []
# for (x0, y0, x1, y1) in output['boxes'].cpu().numpy():
for Cbbox, score in zip(output['boxes'].cpu().numpy(), output['scores'].cpu().numpy()):
if score < detection_th:
continue
rel_x0 = Cbbox[0] / image.shape[-2]
rel_y0 = Cbbox[1] / image.shape[-2]
rel_x1 = Cbbox[2] / image.shape[-2]
rel_y1 = Cbbox[3] / image.shape[-2]
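# note: both axes are normalized by image.shape[-2] (tensor height), which
# implicitly assumes square spectrogram images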
rel_x_center = rel_x1 - (rel_x1 - rel_x0) / 2
rel_y_center = rel_y1 - (rel_y1 - rel_y0) / 2
rel_width = rel_x1 - rel_x0
rel_height = rel_y1 - rel_y0
yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height, score])
if not figures_only:
label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt')
np.savetxt(label_path, yolo_labels)
plot_inference(image, img_name, output, detection_th, dataset_name)
def main(args):
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
@@ -73,7 +103,13 @@ def main(args):
if not (Path(INFERENCE_OUTDIR)/dataset_name).exists():
Path(Path(INFERENCE_OUTDIR)/dataset_name).mkdir(parents=True, exist_ok=True)
infere_model(inference_loader, model, dataset_name)
infere_model(inference_loader, model, dataset_name, figures_only=args.figures_only)
if not args.figures_only:
if (Path('data').absolute() / dataset_name / 'file_dict.csv').exists():
(Path('data').absolute() / dataset_name / 'file_dict.csv').unlink()
# detection_threshold = 0.8
# frame_count = 0
@@ -96,6 +132,7 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluate electrode array recordings with multiple fish.')
parser.add_argument('folder', type=str, help='folder of pictures to run inference on', default='')
parser.add_argument('-f', '--figures_only', action='store_true', help='only generate figures; keep possibly existing labels')
args = parser.parse_args()
main(args)

View File

@@ -121,11 +121,6 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold):
# plt.show()
if __name__ == '__main__':
# train_data = create_train_or_test_dataset(DATA_DIR)
# test_data = create_train_or_test_dataset(DATA_DIR, train=False)
#
# train_loader = create_train_loader(train_data)
# test_loader = create_valid_loader(test_data)
train_data, test_data = custom_train_test_split()