diff --git a/code/chirpdetection.py b/code/chirpdetection.py index cb3831f..348851a 100755 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -43,10 +43,12 @@ class PlotBuffer: time: np.ndarray baseline: np.ndarray + baseline_envelope_unfiltered: np.ndarray baseline_envelope: np.ndarray baseline_peaks: np.ndarray search_frequency: float search: np.ndarray + search_envelope_unfiltered: np.ndarray search_envelope: np.ndarray search_peaks: np.ndarray @@ -92,92 +94,144 @@ class PlotBuffer: self.time = self.time - self.t0 self.frequency_time = self.frequency_time - self.t0 - chirps = np.ararray(chirps) - self.t0 + chirps = np.asarray(chirps) - self.t0 + self.t0_old = self.t0 self.t0 = 0 fig = plt.figure( - figsize=(16 / 2.54, 24 / 2.54), constrained_layout=True + figsize=(16 / 2.54, 20 / 2.54) ) - grid = gr.GridSpec( - 9, 1, figure=fig, height_ratios=[4, 0.5, 1, 1, 1, 0.5, 1, 1, 1] + gs0 = gr.GridSpec( + 6, 1, figure=fig, height_ratios=[1, 0.05, 1, 0.05, 1, 0.05] ) + gs1 = gs0[0].subgridspec(1, 1) + gs2 = gs0[2].subgridspec(3, 1) + gs3 = gs0[4].subgridspec(3, 1) + gs4 = gs0[5].subgridspec(1, 1) - ax0 = fig.add_subplot(grid[0, 0]) - ax1 = fig.add_subplot(grid[2, 0], sharex=ax0) - ax2 = fig.add_subplot(grid[3, 0], sharex=ax0) - ax3 = fig.add_subplot(grid[4, 0], sharex=ax0) - ax4 = fig.add_subplot(grid[6, 0], sharex=ax0) - ax5 = fig.add_subplot(grid[7, 0], sharex=ax0) - ax6 = fig.add_subplot(grid[8, 0], sharex=ax0) + ax0 = fig.add_subplot(gs1[0, 0]) + ax1 = fig.add_subplot(gs2[0, 0], sharex=ax0) + ax2 = fig.add_subplot(gs2[1, 0], sharex=ax0) + ax3 = fig.add_subplot(gs2[2, 0], sharex=ax0) + ax4 = fig.add_subplot(gs3[0, 0], sharex=ax0) + ax5 = fig.add_subplot(gs3[1, 0], sharex=ax0) + ax6 = fig.add_subplot(gs3[2, 0], sharex=ax0) + ax7 = fig.add_subplot(gs4[0, 0], sharex=ax0) + + # ax_leg = fig.add_subplot(gs0[1, 0]) + + waveform_scaler = 1000 # plot spectrogram - plot_spectrogram(ax0, data_oi, self.data.raw_rate, self.t0) - - ax0.fill_between( - np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate), - q50 - self.config.minimal_bandwidth / 2, - q50 + self.config.minimal_bandwidth / 2, - color=ps.black, - lw=0, - alpha=0.2, + _ = plot_spectrogram( + ax0, + data_oi, + self.data.raw_rate, + self.t0, + [np.max(self.frequency) - 200, np.max(self.frequency) + 200] ) - ax0.fill_between( - np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate), - search_lower, - search_upper, - color=ps.black, - lw=0, - alpha=0.2, - ) + # ax0.fill_between( + # np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate), + # q50 - self.config.minimal_bandwidth / 2, + # q50 + self.config.minimal_bandwidth / 2, + # color=ps.black, + # lw=1, + # ls="dashed", + # alpha=0.5, + # ) + + # ax0.fill_between( + # np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate), + # search_lower, + # search_upper, + # color=ps.black, + # lw=1, + # ls="dashed", + # alpha=0.5, + # ) + # ax0.axhline(q50, spec_times[0], spec_times[-1], + # color=ps.gblue1, lw=2, ls="dashed") + # ax0.axhline(q50 + self.search_frequency, + # spec_times[0], spec_times[-1], + # color=ps.gblue2, lw=2, ls="dashed") for chirp in chirps: ax0.scatter( - chirp, np.median(self.frequency), c=ps.black, marker="x" + chirp, np.median(self.frequency) + 150, c=ps.black, marker="v" ) # plot waveform of filtered signal - ax1.plot(self.time, norm(self.baseline)) + ax1.plot(self.time, self.baseline * waveform_scaler, + c=ps.gray, lw=2, alpha=0.5) + ax1.plot(self.time, self.baseline_envelope_unfiltered * + waveform_scaler, c=ps.gblue1, lw=2, label="baseline envelope") # plot waveform of filtered search signal - ax2.plot(self.time, norm(self.search)) + ax2.plot(self.time, self.search * waveform_scaler, + c=ps.gray, lw=2, alpha=0.5) + ax2.plot(self.time, self.search_envelope_unfiltered * + waveform_scaler, c=ps.gblue2, lw=2, label="search envelope") # plot baseline instantaneous frequency - ax3.plot(self.frequency_time, self.frequency) + ax3.plot(self.frequency_time, self.frequency, + c=ps.gblue3, lw=2, label="baseline inst. freq.") # plot filtered and rectified envelope - ax4.plot(self.time, self.baseline_envelope) + ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=2) ax4.scatter( (self.time)[self.baseline_peaks], self.baseline_envelope[self.baseline_peaks], c=ps.red, + zorder=10, ) # plot envelope of search signal - ax5.plot(self.time, self.search_envelope) + ax5.plot(self.time, self.search_envelope, c=ps.gblue2, lw=2) ax5.scatter( (self.time)[self.search_peaks], self.search_envelope[self.search_peaks], c=ps.red, + zorder=10, ) # plot filtered instantaneous frequency - ax6.plot(self.frequency_time, self.frequency_filtered) + ax6.plot(self.frequency_time, + self.frequency_filtered, c=ps.gblue3, lw=2) ax6.scatter( self.frequency_time[self.frequency_peaks], self.frequency_filtered[self.frequency_peaks], c=ps.red, + zorder=10, ) - ax0.set_ylim( - np.max(self.frequency) - 200, top=np.max(self.frequency) + 400 - ) - ax0.set_title("Spectrogram of raw data") - ax1.set_title("Extracted features") - ax4.set_title("Filtered and rectified features") - ax6.set_xlabel("Time [s]") - ax0.set_xlim(0, 5) + ax0.set_ylabel("frequency [Hz]") + ax1.set_ylabel("a.u.") + ax2.set_ylabel("a.u.") + ax3.set_ylabel("Hz") + ax5.set_ylabel("a.u.") + ax7.set_xlabel("time [s]") + + ps.hide_xax(ax0) + ps.hide_xax(ax1) + ps.hide_xax(ax2) + ps.hide_xax(ax3) + ps.hide_xax(ax4) + ps.hide_xax(ax5) + ps.hide_xax(ax6) + ps.hide_yax(ax7) + + ps.letter_subplots([ax0, ax1, ax4], xoffset=-0.21) + + ax7.set_xticks(np.arange(0, 5.5, 1)) + ax7.spines.bottom.set_bounds((0, 5)) + + ax0.set_ymargin(0) + plt.subplots_adjust(left=0.19, right=0.99, + top=0.98, bottom=0.08, hspace=0.15) + fig.align_labels() + ax0.autoscale(enable=True) if plot == "show": plt.show() @@ -187,13 +241,17 @@ class PlotBuffer: self.config.outputdir + self.data.datapath.split("/")[-2] + "/" ) - plt.savefig(f"{out}{self.track_id}_{self.t0}.pdf") + plt.savefig(f"{out}{self.track_id}_{self.t0_old}.pdf") plt.close() def plot_spectrogram( - axis, signal: np.ndarray, samplerate: float, window_start_seconds: float -) -> None: + axis, + signal: np.ndarray, + samplerate: float, + window_start_seconds: float, + ylims: list[float] +) -> np.ndarray: """ Plot a spectrogram of a signal. @@ -219,18 +277,24 @@ def plot_spectrogram( overlap_frac=0.5, ) + fmask = np.zeros(spec_freqs.shape, dtype=bool) + fmask[(spec_freqs > ylims[0]) & (spec_freqs < ylims[1])] = True + axis.imshow( - decibel(spec_power), + decibel(spec_power[fmask, :]), extent=[ spec_times[0] + window_start_seconds, spec_times[-1] + window_start_seconds, - spec_freqs[0], - spec_freqs[-1], + spec_freqs[fmask][0], + spec_freqs[fmask][-1], ], aspect="auto", origin="lower", interpolation="gaussian", + alpha=1, ) + axis.use_sticky_edges = False + return spec_times def extract_frequency_bands( @@ -477,16 +541,16 @@ def main(datapath: str, plot: str) -> None: raw_time = np.arange(data.raw.shape[0]) / data.raw_rate # good chirp times for data: 2022-06-02-10_00 - # window_start_seconds = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate - # window_duration_seconds = 60 * data.raw_rate + window_start_index = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate + window_duration_index = 60 * data.raw_rate # t0 = 0 # dt = data.raw.shape[0] # window_start_seconds = (23495 + ((28336-23495)/3)) * data.raw_rate # window_duration_seconds = (28336 - 23495) * data.raw_rate - window_start_index = 0 - window_duration_index = data.raw.shape[0] + # window_start_index = 0 + # window_duration_index = data.raw.shape[0] # generate starting points of rolling window window_start_indices = np.arange( @@ -648,11 +712,12 @@ def main(datapath: str, plot: str) -> None: # band envelope correspond to troughs in the baseline envelope # during chirps - search_envelope = envelope( + search_envelope_unfiltered = envelope( signal=searchband, samplerate=data.raw_rate, cutoff_frequency=config.search_envelope_cutoff, ) + search_envelope = search_envelope_unfiltered # compute instantaneous frequency of the baseline band to find # anomalies during a chirp, i.e. a frequency jump upwards or @@ -706,8 +771,10 @@ def main(datapath: str, plot: str) -> None: ) current_raw_time = current_raw_time[no_edges] baselineband = baselineband[no_edges] + baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges] searchband = searchband[no_edges] baseline_envelope = baseline_envelope[no_edges] + search_envelope_unfiltered = search_envelope_unfiltered[no_edges] search_envelope = search_envelope[no_edges] # get instantaneous frequency withoup edges @@ -822,11 +889,13 @@ def main(datapath: str, plot: str) -> None: track_id=track_id, data=data, time=current_raw_time, + baseline_envelope_unfiltered=baseline_envelope_unfiltered, baseline=baselineband, baseline_envelope=baseline_envelope, baseline_peaks=baseline_peak_indices, search_frequency=search_frequency, search=searchband, + search_envelope_unfiltered=search_envelope_unfiltered, search_envelope=search_envelope, search_peaks=search_peak_indices, frequency_time=baseline_frequency_time, @@ -864,9 +933,8 @@ def main(datapath: str, plot: str) -> None: multiwindow_ids.append(track_id) logger.info( - "Found %d chirps for fish %d" - % len(multielectrode_chirps_validated), - track_id, + f"Found {len(multielectrode_chirps_validated)}" + f" chirps for fish {track_id} in this window!" ) # if chirps are detected and the plot flag is set, plot the # chirps, otheswise try to delete the buffer if it exists @@ -930,7 +998,7 @@ def main(datapath: str, plot: str) -> None: if __name__ == "__main__": # datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-05-13-10_00/" - # datapath = "../data/2022-06-02-10_00/" + datapath = "../data/2022-06-02-10_00/" # datapath = "/home/weygoldt/Data/uni/efishdata/2016-colombia/fishgrid/2016-04-09-22_25/" - datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/" - main(datapath, plot="show") + # datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/" + main(datapath, plot="save") diff --git a/code/modules/plotstyle.py b/code/modules/plotstyle.py index 9e382a7..2325f62 100644 --- a/code/modules/plotstyle.py +++ b/code/modules/plotstyle.py @@ -30,10 +30,14 @@ def PlotStyle() -> None: purple = "#cba6f7" pink = "#f5c2e7" lavender = "#b4befe" + gblue1 = "#8cb8ff" + gblue2 = "#7cdcdc" + gblue3 = "#82e896" @classmethod def lims(cls, track1, track2): - """Helper function to get frequency y axis limits from two fundamental frequency tracks. + """Helper function to get frequency y axis limits from two + fundamental frequency tracks. Args: track1 (array): First track @@ -91,6 +95,16 @@ def PlotStyle() -> None: ax.tick_params(left=False, labelleft=False) ax.patch.set_visible(False) + @classmethod + def hide_xax(cls, ax): + ax.xaxis.set_visible(False) + ax.spines["bottom"].set_visible(False) + + @classmethod + def hide_yax(cls, ax): + ax.yaxis.set_visible(False) + ax.spines["left"].set_visible(False) + @classmethod def set_boxplot_color(cls, bp, color): plt.setp(bp["boxes"], color=color) @@ -216,8 +230,8 @@ def PlotStyle() -> None: plt.rc("figure", titlesize=BIGGER_SIZE) # fontsize of the figure title plt.rcParams["image.cmap"] = 'cmo.haline' - # plt.rcParams["axes.xmargin"] = 0.1 - # plt.rcParams["axes.ymargin"] = 0.15 + plt.rcParams["axes.xmargin"] = 0.05 + plt.rcParams["axes.ymargin"] = 0.1 plt.rcParams["axes.titlelocation"] = "left" plt.rcParams["axes.titlesize"] = BIGGER_SIZE # plt.rcParams["axes.titlepad"] = -10 @@ -230,9 +244,9 @@ def PlotStyle() -> None: plt.rcParams["legend.borderaxespad"] = 0.5 plt.rcParams["legend.fancybox"] = False - # specify the custom font to use - plt.rcParams["font.family"] = "sans-serif" - plt.rcParams["font.sans-serif"] = "Helvetica Now Text" + # # specify the custom font to use + # plt.rcParams["font.family"] = "sans-serif" + # plt.rcParams["font.sans-serif"] = "Helvetica Now Text" # dark mode modifications plt.rcParams["boxplot.flierprops.color"] = white @@ -271,7 +285,7 @@ def PlotStyle() -> None: plt.rcParams["ytick.color"] = gray # color of the ticks plt.rcParams["grid.color"] = dark_gray # grid color plt.rcParams["figure.facecolor"] = black # figure face color - plt.rcParams["figure.edgecolor"] = "#555169" # figure edge color + plt.rcParams["figure.edgecolor"] = black # figure edge color plt.rcParams["savefig.facecolor"] = black # figure face color when saving return style