diff --git a/data/generate_dataset.py b/data/generate_dataset.py index 6f5cb0c..401c151 100644 --- a/data/generate_dataset.py +++ b/data/generate_dataset.py @@ -1,3 +1,5 @@ +import time + import numpy as np import argparse import torch @@ -123,7 +125,9 @@ def main(args): (rise_size[id_idx] >= 10)], dtype=int) ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red') - plt.show() + plt.show(block=False) + time.sleep(2) + plt.close() if __name__ == '__main__':