changed TRAIN_DIR to DATA_DIR in all files since it is more representative in this code. test and train datasets are defined by csv files
This commit is contained in:
parent
16bca48cbe
commit
030dcfab43
@ -15,7 +15,7 @@ CLASSES = ['__backgroud__', '1']
|
||||
|
||||
NUM_CLASSES = len(CLASSES)
|
||||
|
||||
TRAIN_DIR = 'data/train'
|
||||
DATA_DIR = 'data/dataset'
|
||||
OUTDIR = 'model_outputs'
|
||||
INFERENCE_OUTDIR = 'inference_outputs'
|
||||
|
||||
|
@ -67,7 +67,7 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1,
|
||||
extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res))
|
||||
ax.axis(False)
|
||||
|
||||
plt.savefig(os.path.join('train', fig_title), dpi=256)
|
||||
plt.savefig(os.path.join('dataset', fig_title), dpi=256)
|
||||
plt.close()
|
||||
|
||||
return fig_title, (size[0]*dpi, size[1]*dpi)
|
||||
@ -197,13 +197,13 @@ def main(args):
|
||||
|
||||
# init dataframe if not existent so far
|
||||
eval_files = []
|
||||
if not os.path.exists(os.path.join('train', 'bbox_dataset.csv')):
|
||||
if not os.path.exists(os.path.join('dataset', 'bbox_dataset.csv')):
|
||||
cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
|
||||
bbox_df = pd.DataFrame(columns=cols)
|
||||
|
||||
# else load datafile ... and check for already regarded files (eval_files)
|
||||
else:
|
||||
bbox_df = pd.read_csv(os.path.join('train', 'bbox_dataset.csv'), sep=',', index_col=0)
|
||||
bbox_df = pd.read_csv(os.path.join('dataset', 'bbox_dataset.csv'), sep=',', index_col=0)
|
||||
cols = list(bbox_df.keys())
|
||||
# ToDo: make sure not same file twice
|
||||
for f in pd.unique(bbox_df['image']):
|
||||
@ -266,7 +266,7 @@ def main(args):
|
||||
|
||||
if not args.dev:
|
||||
print('save')
|
||||
bbox_df.to_csv(os.path.join('train', 'bbox_dataset.csv'), columns=cols, sep=',')
|
||||
bbox_df.to_csv(os.path.join('dataset', 'bbox_dataset.csv'), columns=cols, sep=',')
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
|
||||
|
@ -41,4 +41,4 @@ def main(path):
|
||||
test_bbox.to_csv(path/'bbox_test.csv', columns=cols, sep=',')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(Path('./train'))
|
||||
main(Path('./dataset'))
|
@ -14,7 +14,7 @@ from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from confic import (CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE)
|
||||
from confic import (CLASSES, RESIZE_TO, DATA_DIR, BATCH_SIZE)
|
||||
from custom_utils import collate_fn
|
||||
|
||||
from IPython import embed
|
||||
@ -96,8 +96,8 @@ def create_valid_loader(valid_dataset, num_workers=0):
|
||||
if __name__ == '__main__':
|
||||
|
||||
# train_data, test_data = create_train_test_dataset(TRAIN_DIR)
|
||||
train_data = create_train_or_test_dataset(TRAIN_DIR)
|
||||
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
||||
train_data = create_train_or_test_dataset(DATA_DIR)
|
||||
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
|
||||
|
||||
train_loader = create_train_loader(train_data)
|
||||
test_loader = create_valid_loader(test_data)
|
||||
|
@ -6,7 +6,7 @@ import os
|
||||
from PIL import Image
|
||||
|
||||
from model import create_model
|
||||
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
|
||||
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, DATA_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
|
||||
from datasets import create_train_or_test_dataset, create_valid_loader
|
||||
|
||||
from IPython import embed
|
||||
@ -72,7 +72,7 @@ if __name__ == '__main__':
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.to(DEVICE).eval()
|
||||
|
||||
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
||||
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
|
||||
test_loader = create_valid_loader(test_data)
|
||||
|
||||
infere_model(test_loader, model)
|
||||
|
6
train.py
6
train.py
@ -1,4 +1,4 @@
|
||||
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, TRAIN_DIR)
|
||||
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR)
|
||||
from model import create_model
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
@ -65,8 +65,8 @@ def validate(test_loader, model, val_loss):
|
||||
return val_loss
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_data = create_train_or_test_dataset(TRAIN_DIR)
|
||||
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
||||
train_data = create_train_or_test_dataset(DATA_DIR)
|
||||
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
|
||||
|
||||
train_loader = create_train_loader(train_data)
|
||||
test_loader = create_valid_loader(test_data)
|
||||
|
Loading…
Reference in New Issue
Block a user