diff --git a/data/generate_dataset.py b/data/generate_dataset.py
index 1dc3cbd..d8fbb26 100644
--- a/data/generate_dataset.py
+++ b/data/generate_dataset.py
@@ -117,7 +117,7 @@ def main(args):
             times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
             for id_idx in range(len(fish_freq)):
                 ax.plot(times_v[times_v_idx0:times_v_idx1], fish_freq[id_idx][times_v_idx0:times_v_idx1], marker='.', color='k')
-                rise_idx_oi = rise_idx[id_idx][(rise_idx[id_idx] >= times_v_idx0) & (rise_idx[id_idx] <= times_v_idx1)]
+                rise_idx_oi = np.array(rise_idx[id_idx][(rise_idx[id_idx] >= times_v_idx0) & (rise_idx[id_idx] <= times_v_idx1)], dtype=int)
                 ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], marker='o', color='tab:red')
 
             plt.show()