Changed TRAIN_DIR to DATA_DIR in all files, since it is more representative in this code; test and train datasets are defined by CSV files.

This commit is contained in:
Till Raab 2023-10-26 12:47:29 +02:00
parent 16bca48cbe
commit 030dcfab43
6 changed files with 14 additions and 14 deletions

View File

@@ -15,7 +15,7 @@ CLASSES = ['__backgroud__', '1']
NUM_CLASSES = len(CLASSES)
TRAIN_DIR = 'data/train'
DATA_DIR = 'data/dataset'
OUTDIR = 'model_outputs'
INFERENCE_OUTDIR = 'inference_outputs'

View File

@@ -67,7 +67,7 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1,
extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res))
ax.axis(False)
plt.savefig(os.path.join('train', fig_title), dpi=256)
plt.savefig(os.path.join('dataset', fig_title), dpi=256)
plt.close()
return fig_title, (size[0]*dpi, size[1]*dpi)
@@ -197,13 +197,13 @@ def main(args):
# init dataframe if not existent so far
eval_files = []
if not os.path.exists(os.path.join('train', 'bbox_dataset.csv')):
if not os.path.exists(os.path.join('dataset', 'bbox_dataset.csv')):
cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
bbox_df = pd.DataFrame(columns=cols)
# else load datafile ... and check for already regarded files (eval_files)
else:
bbox_df = pd.read_csv(os.path.join('train', 'bbox_dataset.csv'), sep=',', index_col=0)
bbox_df = pd.read_csv(os.path.join('dataset', 'bbox_dataset.csv'), sep=',', index_col=0)
cols = list(bbox_df.keys())
# ToDo: make sure not same file twice
for f in pd.unique(bbox_df['image']):
@@ -266,7 +266,7 @@ def main(args):
if not args.dev:
print('save')
bbox_df.to_csv(os.path.join('train', 'bbox_dataset.csv'), columns=cols, sep=',')
bbox_df.to_csv(os.path.join('dataset', 'bbox_dataset.csv'), columns=cols, sep=',')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')

View File

@@ -41,4 +41,4 @@ def main(path):
test_bbox.to_csv(path/'bbox_test.csv', columns=cols, sep=',')
if __name__ == '__main__':
main(Path('./train'))
main(Path('./dataset'))

View File

@@ -14,7 +14,7 @@ from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
from confic import (CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE)
from confic import (CLASSES, RESIZE_TO, DATA_DIR, BATCH_SIZE)
from custom_utils import collate_fn
from IPython import embed
@@ -96,8 +96,8 @@ def create_valid_loader(valid_dataset, num_workers=0):
if __name__ == '__main__':
# train_data, test_data = create_train_test_dataset(TRAIN_DIR)
train_data = create_train_or_test_dataset(TRAIN_DIR)
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
train_data = create_train_or_test_dataset(DATA_DIR)
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
train_loader = create_train_loader(train_data)
test_loader = create_valid_loader(test_data)

View File

@@ -6,7 +6,7 @@ import os
from PIL import Image
from model import create_model
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, DATA_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
from datasets import create_train_or_test_dataset, create_valid_loader
from IPython import embed
@@ -72,7 +72,7 @@ if __name__ == '__main__':
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
test_loader = create_valid_loader(test_data)
infere_model(test_loader, model)

View File

@@ -1,4 +1,4 @@
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, TRAIN_DIR)
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR)
from model import create_model
from tqdm.auto import tqdm
@@ -65,8 +65,8 @@ def validate(test_loader, model, val_loss):
return val_loss
if __name__ == '__main__':
train_data = create_train_or_test_dataset(TRAIN_DIR)
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
train_data = create_train_or_test_dataset(DATA_DIR)
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
train_loader = create_train_loader(train_data)
test_loader = create_valid_loader(test_data)