changed TRAIN_DIR to DATA_DIR in all files since it is more representative in this code. test and train datasets are defined by csv files
This commit is contained in:
parent
16bca48cbe
commit
030dcfab43
@ -15,7 +15,7 @@ CLASSES = ['__backgroud__', '1']
|
|||||||
|
|
||||||
NUM_CLASSES = len(CLASSES)
|
NUM_CLASSES = len(CLASSES)
|
||||||
|
|
||||||
TRAIN_DIR = 'data/train'
|
DATA_DIR = 'data/dataset'
|
||||||
OUTDIR = 'model_outputs'
|
OUTDIR = 'model_outputs'
|
||||||
INFERENCE_OUTDIR = 'inference_outputs'
|
INFERENCE_OUTDIR = 'inference_outputs'
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1,
|
|||||||
extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res))
|
extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res))
|
||||||
ax.axis(False)
|
ax.axis(False)
|
||||||
|
|
||||||
plt.savefig(os.path.join('train', fig_title), dpi=256)
|
plt.savefig(os.path.join('dataset', fig_title), dpi=256)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
return fig_title, (size[0]*dpi, size[1]*dpi)
|
return fig_title, (size[0]*dpi, size[1]*dpi)
|
||||||
@ -197,13 +197,13 @@ def main(args):
|
|||||||
|
|
||||||
# init dataframe if not existent so far
|
# init dataframe if not existent so far
|
||||||
eval_files = []
|
eval_files = []
|
||||||
if not os.path.exists(os.path.join('train', 'bbox_dataset.csv')):
|
if not os.path.exists(os.path.join('dataset', 'bbox_dataset.csv')):
|
||||||
cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
|
cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
|
||||||
bbox_df = pd.DataFrame(columns=cols)
|
bbox_df = pd.DataFrame(columns=cols)
|
||||||
|
|
||||||
# else load datafile ... and check for already regarded files (eval_files)
|
# else load datafile ... and check for already regarded files (eval_files)
|
||||||
else:
|
else:
|
||||||
bbox_df = pd.read_csv(os.path.join('train', 'bbox_dataset.csv'), sep=',', index_col=0)
|
bbox_df = pd.read_csv(os.path.join('dataset', 'bbox_dataset.csv'), sep=',', index_col=0)
|
||||||
cols = list(bbox_df.keys())
|
cols = list(bbox_df.keys())
|
||||||
# ToDo: make sure not same file twice
|
# ToDo: make sure not same file twice
|
||||||
for f in pd.unique(bbox_df['image']):
|
for f in pd.unique(bbox_df['image']):
|
||||||
@ -266,7 +266,7 @@ def main(args):
|
|||||||
|
|
||||||
if not args.dev:
|
if not args.dev:
|
||||||
print('save')
|
print('save')
|
||||||
bbox_df.to_csv(os.path.join('train', 'bbox_dataset.csv'), columns=cols, sep=',')
|
bbox_df.to_csv(os.path.join('dataset', 'bbox_dataset.csv'), columns=cols, sep=',')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
|
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
|
||||||
|
@ -41,4 +41,4 @@ def main(path):
|
|||||||
test_bbox.to_csv(path/'bbox_test.csv', columns=cols, sep=',')
|
test_bbox.to_csv(path/'bbox_test.csv', columns=cols, sep=',')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main(Path('./train'))
|
main(Path('./dataset'))
|
@ -14,7 +14,7 @@ from pathlib import Path
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from confic import (CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE)
|
from confic import (CLASSES, RESIZE_TO, DATA_DIR, BATCH_SIZE)
|
||||||
from custom_utils import collate_fn
|
from custom_utils import collate_fn
|
||||||
|
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
@ -96,8 +96,8 @@ def create_valid_loader(valid_dataset, num_workers=0):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
# train_data, test_data = create_train_test_dataset(TRAIN_DIR)
|
# train_data, test_data = create_train_test_dataset(TRAIN_DIR)
|
||||||
train_data = create_train_or_test_dataset(TRAIN_DIR)
|
train_data = create_train_or_test_dataset(DATA_DIR)
|
||||||
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
|
||||||
|
|
||||||
train_loader = create_train_loader(train_data)
|
train_loader = create_train_loader(train_data)
|
||||||
test_loader = create_valid_loader(test_data)
|
test_loader = create_valid_loader(test_data)
|
||||||
|
@ -6,7 +6,7 @@ import os
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from model import create_model
|
from model import create_model
|
||||||
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
|
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, DATA_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
|
||||||
from datasets import create_train_or_test_dataset, create_valid_loader
|
from datasets import create_train_or_test_dataset, create_valid_loader
|
||||||
|
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
@ -72,7 +72,7 @@ if __name__ == '__main__':
|
|||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
model.to(DEVICE).eval()
|
model.to(DEVICE).eval()
|
||||||
|
|
||||||
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
|
||||||
test_loader = create_valid_loader(test_data)
|
test_loader = create_valid_loader(test_data)
|
||||||
|
|
||||||
infere_model(test_loader, model)
|
infere_model(test_loader, model)
|
||||||
|
6
train.py
6
train.py
@ -1,4 +1,4 @@
|
|||||||
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, TRAIN_DIR)
|
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR)
|
||||||
from model import create_model
|
from model import create_model
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
@ -65,8 +65,8 @@ def validate(test_loader, model, val_loss):
|
|||||||
return val_loss
|
return val_loss
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train_data = create_train_or_test_dataset(TRAIN_DIR)
|
train_data = create_train_or_test_dataset(DATA_DIR)
|
||||||
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
|
||||||
|
|
||||||
train_loader = create_train_loader(train_data)
|
train_loader = create_train_loader(train_data)
|
||||||
test_loader = create_valid_loader(test_data)
|
test_loader = create_valid_loader(test_data)
|
||||||
|
Loading…
Reference in New Issue
Block a user