a lot of work. dataset and loader not created

Till Raab 2023-10-23 14:43:01 +02:00
parent 5ca527e39f
commit 3f66b4a39a
4 changed files with 303 additions and 103 deletions

custom_utils.py Normal file (6 additions)

@ -0,0 +1,6 @@
def collate_fn(batch):
    """
    To handle the data loading, as different images may have different numbers
    of objects, and to handle varying-size target tensors as well.
    """
    return tuple(zip(*batch))
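
Since each spectrogram image can contain a different number of rise boxes, the default collate would try to stack the targets and fail; a tiny illustration of what collate_fn produces (hypothetical values, not part of the commit):

# Hypothetical batch of two samples with 1 and 2 boxes, respectively.
batch = [('img_a', {'boxes': [[0, 0, 10, 10]], 'labels': [1]}),
         ('img_b', {'boxes': [[5, 5, 20, 20], [1, 2, 3, 4]], 'labels': [1, 1]})]
images, targets = collate_fn(batch)
# images  -> ('img_a', 'img_b')
# targets -> ({'boxes': [[0, 0, 10, 10]], ...}, {'boxes': [[5, 5, 20, 20], [1, 2, 3, 4]], ...})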


@ -9,6 +9,7 @@ import torchvision.transforms as T
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
@ -55,7 +56,9 @@ def load_data(folder):
    return fill_freqs, fill_times, fill_spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time

def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res):
    fig_title = (f'{Path(folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
    size = (7, 7)
    dpi = 256
    fig_title = (f'{Path(folder).name}__{times[t_idx0]:.0f}s-{times[t_idx1]:.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0')
    fig = plt.figure(figsize=(7, 7), num=fig_title)
    gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0)  # , bottom=0, left=0, right=1, top=1
    gs2 = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)
@ -64,9 +67,72 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1,
                   extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res))
    ax.axis(False)
    plt.savefig(os.path.join('train', fig_title + '.png'), dpi=256)
    plt.savefig(os.path.join('train', fig_title), dpi=256)
    plt.close()

    return fig_title, (size[0]*dpi, size[1]*dpi)
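
With size = (7, 7) inches and dpi = 256, the saved spectrogram is 7 * 256 = 1792 px on each side, so save_spec_pic returns the image file name together with (1792, 1792); these pixel dimensions are what bboxes_from_file uses below to convert time/frequency bounds into box coordinates.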

def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,
                     bbox_df, cols, width, height, t0, t1, f0, f1):
    times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))

    for id_idx in range(len(fish_freq)):
        # rises within the current time window that exceed the 10 Hz size threshold
        rise_idx_oi = np.array(rise_idx[id_idx][
                                   (rise_idx[id_idx] >= times_v_idx0) &
                                   (rise_idx[id_idx] <= times_v_idx1) &
                                   (rise_size[id_idx] >= 10)], dtype=int)
        rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
                                         (rise_idx[id_idx] <= times_v_idx1) &
                                         (rise_size[id_idx] >= 10)]
        if len(rise_idx_oi) == 0:
            continue

        closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
        closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]

        upper_freq_bound = closest_baseline_freq + rise_size_oi
        lower_freq_bound = closest_baseline_freq
        left_time_bound = times_v[rise_idx_oi]
        right_time_bound = np.zeros_like(left_time_bound)

        # a rise ends once the frequency trace has dropped back below baseline + 37 % of the rise size
        for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
            Crise_size = rise_size_oi[enu]
            Cblf = closest_baseline_freq[enu]
            rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
            if len(rise_end_t) == 0:
                right_time_bound[enu] = np.nan
            else:
                right_time_bound[enu] = rise_end_t[0]

        # pad the boxes by 10 % in time and frequency
        dt_bbox = right_time_bound - left_time_bound
        df_bbox = upper_freq_bound - lower_freq_bound
        left_time_bound -= dt_bbox * 0.1
        right_time_bound += dt_bbox * 0.1
        lower_freq_bound -= df_bbox * 0.1
        upper_freq_bound += df_bbox * 0.1

        # convert time/frequency bounds to pixel coordinates (image rows run top-down)
        x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int)
        x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int)
        y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int)
        y1 = np.array((1 - (lower_freq_bound - f0) / (f1 - f0)) * height, dtype=int)

        bbox = np.array([[pic_save_str for i in range(len(left_time_bound))],
                         left_time_bound,
                         right_time_bound,
                         lower_freq_bound,
                         upper_freq_bound,
                         x0, x1, y0, y1])

        tmp_df = pd.DataFrame(data=bbox.T, columns=cols)
        bbox_df = pd.concat([bbox_df, tmp_df], ignore_index=True)

    return bbox_df
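
To make the mapping above concrete, a worked example with made-up numbers (only the formulas are taken from the code):

# Made-up window and image size, only to illustrate the pixel mapping in bboxes_from_file.
t0, t1, f0, f1 = 0.0, 1200.0, 800.0, 1100.0     # 20 min time window, 300 Hz frequency band
width, height = 1792, 1792                      # pixel size returned by save_spec_pic
left_time_bound, upper_freq_bound = 300.0, 950.0
x0 = int((left_time_bound - t0) / (t1 - t0) * width)          # 448: a quarter into the window
y0 = int((1 - (upper_freq_bound - f0) / (f1 - f0)) * height)  # 896: higher frequency -> smaller y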

def main(args):
    min_freq = 200
@ -76,109 +142,129 @@ def main(args):
    d_time = 60*15
    time_overlap = 60*5

    freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
        load_data(args.folder))
    f_res, t_res = freq[1] - freq[0], times[1] - times[0]

    unique_ids = np.unique(ident_v[~np.isnan(ident_v)])

    pic_base = tqdm(itertools.product(
        np.arange(0, times[-1], d_time),
        np.arange(min_freq, max_freq, d_freq)
    ),
        total=int(((max_freq-min_freq)//d_freq) * (times[-1] // d_time))
    )

    for t0, f0 in pic_base:
        t1 = t0 + d_time + time_overlap
        f1 = f0 + d_freq + freq_overlap

        present_freqs = EODf_v[(~np.isnan(ident_v)) &
                               (t0 <= times_v[idx_v]) &
                               (times_v[idx_v] <= t1) &
                               (EODf_v >= f0) &
                               (EODf_v <= f1)]
        if len(present_freqs) == 0:
            continue

        f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
        t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))

        s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
        log_s = torch.log10(s)
        transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
        s_trans = transformed(log_s.unsqueeze(0))
    if not os.path.exists(os.path.join('train', 'bbox_dataset.csv')):
        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'x1', 'y0', 'y1']
        bbox_df = pd.DataFrame(columns=cols)
    else:
        bbox_df = pd.read_csv(os.path.join('train', 'bbox_dataset.csv'), sep=',', index_col=0)
        cols = list(bbox_df.keys())

        eval_files = []
        # ToDo: make sure not same file twice
        for f in pd.unique(bbox_df['image']):
            eval_files.append(f.split('__')[0])

    folders = [args.folders]
    for enu, folder in enumerate(folders):
        print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
        freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
            load_data(folder))
        f_res, t_res = freq[1] - freq[0], times[1] - times[0]

        pic_base = tqdm(itertools.product(
            np.arange(0, times[-1], d_time),
            np.arange(min_freq, max_freq, d_freq)
        ),
            total=int(((max_freq-min_freq)//d_freq) * (times[-1] // d_time))
        )

        for t0, f0 in pic_base:
            t1 = t0 + d_time + time_overlap
            f1 = f0 + d_freq + freq_overlap

            present_freqs = EODf_v[(~np.isnan(ident_v)) &
                                   (t0 <= times_v[idx_v]) &
                                   (times_v[idx_v] <= t1) &
                                   (EODf_v >= f0) &
                                   (EODf_v <= f1)]
            if len(present_freqs) == 0:
                continue

            f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
            t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))

            s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
            log_s = torch.log10(s)
            transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
            s_trans = transformed(log_s.unsqueeze(0))

            if not args.dev:
                pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res)

                bbox_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
                                           fish_baseline_freq_time, fish_baseline_freq,
                                           pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1)
            else:
                fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
                fig = plt.figure(figsize=(10, 7), num=fig_title)
                gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9, top=0.95)  # , bottom=0, left=0, right=1, top=1
                ax = fig.add_subplot(gs[0, 0])
                cax = fig.add_subplot(gs[0, 1])
                im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
                               extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res))
                fig.colorbar(im, cax=cax, orientation='vertical')

                times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
                for id_idx in range(len(fish_freq)):
                    ax.plot(times_v[times_v_idx0:times_v_idx1], fish_freq[id_idx][times_v_idx0:times_v_idx1], marker='.', color='k', markersize=4)
                    rise_idx_oi = np.array(rise_idx[id_idx][
                                               (rise_idx[id_idx] >= times_v_idx0) &
                                               (rise_idx[id_idx] <= times_v_idx1) &
                                               (rise_size[id_idx] >= 10)], dtype=int)
                    rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
                                                     (rise_idx[id_idx] <= times_v_idx1) &
                                                     (rise_size[id_idx] >= 10)]
                    ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red')

                    if len(rise_idx_oi) > 0:
                        closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
                        closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]

                        upper_freq_bound = closest_baseline_freq + rise_size_oi
                        lower_freq_bound = closest_baseline_freq
                        left_time_bound = times_v[rise_idx_oi]
                        right_time_bound = np.zeros_like(left_time_bound)

                        for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
                            Crise_size = rise_size_oi[enu]
                            Cblf = closest_baseline_freq[enu]
                            rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
                            if len(rise_end_t) == 0:
                                right_time_bound[enu] = np.nan
                            else:
                                right_time_bound[enu] = rise_end_t[0]

                        dt_bbox = right_time_bound - left_time_bound
                        df_bbox = upper_freq_bound - lower_freq_bound
                        left_time_bound -= dt_bbox*0.1
                        right_time_bound += dt_bbox*0.1
                        lower_freq_bound -= df_bbox*0.1
                        upper_freq_bound += df_bbox*0.1

                        print(f'f0: {lower_freq_bound}')
                        print(f'f1: {upper_freq_bound}')
                        print(f't0: {left_time_bound}')
                        print(f't1: {right_time_bound}')

                        for enu in range(len(left_time_bound)):
                            if np.isnan(right_time_bound[enu]):
                                continue
                            ax.add_patch(
                                Rectangle((left_time_bound[enu], lower_freq_bound[enu]),
                                          (right_time_bound[enu] - left_time_bound[enu]),
                                          (upper_freq_bound[enu] - lower_freq_bound[enu]),
                                          fill=False, color="white", linewidth=2, zorder=10)
                            )

                plt.show()
        if not args.dev:
            save_spec_pic(args.folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res)
        else:
            fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
            fig = plt.figure(figsize=(10, 7), num=fig_title)
            gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9, top=0.95)  # , bottom=0, left=0, right=1, top=1
            ax = fig.add_subplot(gs[0, 0])
            cax = fig.add_subplot(gs[0, 1])
            im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
                           extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res))
            fig.colorbar(im, cax=cax, orientation='vertical')

            times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
            for id_idx in range(len(fish_freq)):
                ax.plot(times_v[times_v_idx0:times_v_idx1], fish_freq[id_idx][times_v_idx0:times_v_idx1], marker='.', color='k', markersize=4)
                rise_idx_oi = np.array(rise_idx[id_idx][
                                           (rise_idx[id_idx] >= times_v_idx0) &
                                           (rise_idx[id_idx] <= times_v_idx1) &
                                           (rise_size[id_idx] >= 10)], dtype=int)
                rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
                                                 (rise_idx[id_idx] <= times_v_idx1) &
                                                 (rise_size[id_idx] >= 10)]
                ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red')

                if len(rise_idx_oi) > 0:
                    closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
                    closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]

                    upper_freq_bound = closest_baseline_freq + rise_size_oi
                    lower_freq_bound = closest_baseline_freq
                    left_time_bound = times_v[rise_idx_oi]
                    right_time_bound = np.zeros_like(left_time_bound)

                    for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
                        Crise_size = rise_size_oi[enu]
                        Cblf = closest_baseline_freq[enu]
                        rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
                        if len(rise_end_t) == 0:
                            right_time_bound[enu] = np.nan
                        else:
                            right_time_bound[enu] = rise_end_t[0]

                    dt_bbox = right_time_bound - left_time_bound
                    df_bbox = upper_freq_bound - lower_freq_bound
                    left_time_bound -= dt_bbox*0.1
                    right_time_bound += dt_bbox*0.1
                    lower_freq_bound -= df_bbox*0.1
                    upper_freq_bound += df_bbox*0.1

                    print(f'f0: {lower_freq_bound}')
                    print(f'f1: {upper_freq_bound}')
                    print(f't0: {left_time_bound}')
                    print(f't1: {right_time_bound}')

                    for enu in range(len(left_time_bound)):
                        if np.isnan(right_time_bound[enu]):
                            continue
                        ax.add_patch(
                            Rectangle((left_time_bound[enu], lower_freq_bound[enu]),
                                      (right_time_bound[enu] - left_time_bound[enu]),
                                      (upper_freq_bound[enu] - lower_freq_bound[enu]),
                                      fill=False, color="white", linewidth=2, zorder=10)
                        )

            plt.show()

    bbox_df.to_csv(os.path.join('train', 'bbox_dataset.csv'), columns=cols, sep=',')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')

datasets.py Normal file (105 additions)

@ -0,0 +1,105 @@
import os
import glob
import torch
import torchvision
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import Rectangle
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
from confic import (CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE)
from custom_utils import collate_fn
from IPython import embed

class CustomDataset(Dataset):
    def __init__(self, dir_path, use_idxs=None):
        self.dir_path = dir_path
        self.image_paths = glob.glob(f'{self.dir_path}/*.png')
        self.all_images = [img_path.split(os.path.sep)[-1] for img_path in self.image_paths]
        self.all_images = np.array(sorted(self.all_images), dtype=str)
        if hasattr(use_idxs, '__len__'):
            # restrict the dataset to a subset of images, e.g. a train or test split
            self.all_images = self.all_images[use_idxs]

        self.bbox_df = pd.read_csv(os.path.join(dir_path, 'bbox_dataset.csv'), sep=',', index_col=0)

    def __getitem__(self, idx):
        image_name = self.all_images[idx]
        image_path = os.path.join(self.dir_path, image_name)

        img = Image.open(image_path)
        img_tensor = F.to_tensor(img.convert('RGB'))

        # all bounding boxes annotated for this image; every box gets foreground label 1
        Cbbox = self.bbox_df[self.bbox_df['image'] == image_name]
        labels = np.ones(len(Cbbox), dtype=int)
        boxes = Cbbox.loc[:, ['x0', 'x1', 'y0', 'y1']].values

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels

        return img_tensor, target

    def __len__(self):
        return len(self.all_images)
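
Note that __getitem__ currently returns the raw numpy values in the CSV's (x0, x1, y0, y1) column order, while torchvision's detection models expect boxes as float tensors in (x_min, y_min, x_max, y_max) order and labels as int64 tensors. A conversion roughly like the following sketch (an assumption about later use, not part of this commit; boxes and labels refer to the variables built in __getitem__) would be needed before feeding the targets to such a model:

# Sketch: reorder (x0, x1, y0, y1) -> (x_min, y_min, x_max, y_max) and convert to tensors.
boxes_xyxy = torch.as_tensor(boxes[:, [0, 2, 1, 3]], dtype=torch.float32)
target = {'boxes': boxes_xyxy,
          'labels': torch.as_tensor(labels, dtype=torch.int64)}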

def create_train_test_dataset(path, test_size=0.2):
    files = glob.glob(os.path.join(path, '*.png'))

    train_test_idx = np.arange(len(files), dtype=int)
    np.random.shuffle(train_test_idx)

    train_idx = train_test_idx[int(test_size*len(train_test_idx)):]
    test_idx = train_test_idx[:int(test_size*len(train_test_idx))]

    train_data = CustomDataset(path, use_idxs=train_idx)
    test_data = CustomDataset(path, use_idxs=test_idx)

    return train_data, test_data


def create_train_loader(train_dataset, num_workers=0):
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return train_loader


def create_valid_loader(valid_dataset, num_workers=0):
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return valid_loader

if __name__ == '__main__':
    train_data, test_data = create_train_test_dataset(TRAIN_DIR)
    train_loader = create_train_loader(train_data)
    test_loader = create_valid_loader(test_data)

    for samples, targets in train_loader:
        for s, t in zip(samples, targets):
            fig, ax = plt.subplots()
            ax.imshow(s.permute(1, 2, 0), aspect='auto')
            for (x0, x1, y0, y1), l in zip(t['boxes'], t['labels']):
                print(x0, x1, y0, y1, l)
                ax.add_patch(
                    Rectangle((x0, y0),
                              (x1 - x0),
                              (y1 - y0),
                              fill=False, color="white", linewidth=2, zorder=10)
                )
            plt.show()

train.py Normal file (3 additions)

@ -0,0 +1,3 @@
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS)
from model import create_model
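
train.py is still only a stub at this commit, consistent with the commit message ("dataset and loader not created"). Purely as an assumption about where this is heading, here is a minimal sketch of a training loop that would connect create_model, the confic constants, and the loaders from datasets.py; it assumes create_model(num_classes) returns a torchvision-style detection model that takes a list of images plus target dicts and returns a loss dict in training mode, and that the targets are already tensors in the format such models expect. None of this is part of the commit.

# Sketch only -- assumes create_model() yields a torchvision detection model.
import torch

from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, NUM_WORKERS, TRAIN_DIR)
from model import create_model
from datasets import create_train_test_dataset, create_train_loader

train_data, test_data = create_train_test_dataset(TRAIN_DIR)
train_loader = create_train_loader(train_data, num_workers=NUM_WORKERS)

model = create_model(NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    model.train()
    for images, targets in train_loader:
        images = [img.to(DEVICE) for img in images]
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)   # torchvision detection models return a dict of losses in train mode
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch + 1}/{NUM_EPOCHS}: last batch loss {loss.item():.3f}')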