diff --git a/data/generate_dataset.py b/data/generate_dataset.py index b5d58a5..a361141 100644 --- a/data/generate_dataset.py +++ b/data/generate_dataset.py @@ -18,6 +18,9 @@ import os from IPython import embed +from matplotlib.patches import Rectangle +# from matplotlib.collections import PatchCollection + def load_data(folder): fill_freqs, fill_times, fill_spec = [], [], [] @@ -127,7 +130,7 @@ def main(args): (rise_idx[id_idx] >= times_v_idx0) & (rise_idx[id_idx] <= times_v_idx1) & (rise_size[id_idx] >= 10)], dtype=int) - rise_size_oi = rise_idx[id_idx][(rise_idx[id_idx] >= times_v_idx0) & + rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) & (rise_idx[id_idx] <= times_v_idx1) & (rise_size[id_idx] >= 10)] @@ -152,10 +155,13 @@ def main(args): right_time_bound[enu] = np.nan else: right_time_bound[enu] = rise_end_t[0] + # Create patch collection with specified colour/alpha + for enu in range(len(left_time_bound)): + if ~np.isnan(right_time_bound): + continue + ax.add_patch.Rectangle((left_time_bound, right_time_bound), (right_time_bound - left_time_bound), (upper_freq_bound - lower_freq_bound), + fill=False, color="k", linewidth=2) - - embed() - quit() plt.show()