From 9a7ac66f4a2890c0fb755a1d7365f65a659972d7 Mon Sep 17 00:00:00 2001
From: weygoldt <88969563+weygoldt@users.noreply.github.com>
Date: Sun, 22 Jan 2023 00:52:22 +0100
Subject: [PATCH] plot looks better now

---
 code/chirpdetection.py    | 188 ++++++++++++++++++++++++++------------
 code/modules/plotstyle.py |  28 ++++--
 2 files changed, 149 insertions(+), 67 deletions(-)

diff --git a/code/chirpdetection.py b/code/chirpdetection.py
index cb3831f..348851a 100755
--- a/code/chirpdetection.py
+++ b/code/chirpdetection.py
@@ -43,10 +43,12 @@ class PlotBuffer:
 
     time: np.ndarray
     baseline: np.ndarray
+    baseline_envelope_unfiltered: np.ndarray
     baseline_envelope: np.ndarray
     baseline_peaks: np.ndarray
     search_frequency: float
     search: np.ndarray
+    search_envelope_unfiltered: np.ndarray
     search_envelope: np.ndarray
     search_peaks: np.ndarray
 
@@ -92,92 +94,144 @@ class PlotBuffer:
 
         self.time = self.time - self.t0
         self.frequency_time = self.frequency_time - self.t0
-        chirps = np.ararray(chirps) - self.t0
+        chirps = np.asarray(chirps) - self.t0
+        self.t0_old = self.t0
         self.t0 = 0
 
         fig = plt.figure(
-            figsize=(16 / 2.54, 24 / 2.54), constrained_layout=True
+            figsize=(16 / 2.54, 20 / 2.54)
         )
 
-        grid = gr.GridSpec(
-            9, 1, figure=fig, height_ratios=[4, 0.5, 1, 1, 1, 0.5, 1, 1, 1]
+        gs0 = gr.GridSpec(
+            6, 1, figure=fig, height_ratios=[1, 0.05, 1, 0.05, 1, 0.05]
         )
+        gs1 = gs0[0].subgridspec(1, 1)
+        gs2 = gs0[2].subgridspec(3, 1)
+        gs3 = gs0[4].subgridspec(3, 1)
+        gs4 = gs0[5].subgridspec(1, 1)
 
-        ax0 = fig.add_subplot(grid[0, 0])
-        ax1 = fig.add_subplot(grid[2, 0], sharex=ax0)
-        ax2 = fig.add_subplot(grid[3, 0], sharex=ax0)
-        ax3 = fig.add_subplot(grid[4, 0], sharex=ax0)
-        ax4 = fig.add_subplot(grid[6, 0], sharex=ax0)
-        ax5 = fig.add_subplot(grid[7, 0], sharex=ax0)
-        ax6 = fig.add_subplot(grid[8, 0], sharex=ax0)
+        ax0 = fig.add_subplot(gs1[0, 0])
+        ax1 = fig.add_subplot(gs2[0, 0], sharex=ax0)
+        ax2 = fig.add_subplot(gs2[1, 0], sharex=ax0)
+        ax3 = fig.add_subplot(gs2[2, 0], sharex=ax0)
+        ax4 = fig.add_subplot(gs3[0, 0], sharex=ax0)
+        ax5 = fig.add_subplot(gs3[1, 0], sharex=ax0)
+        ax6 = fig.add_subplot(gs3[2, 0], sharex=ax0)
+        ax7 = fig.add_subplot(gs4[0, 0], sharex=ax0)
+
+        # ax_leg = fig.add_subplot(gs0[1, 0])
+
+        waveform_scaler = 1000
 
         # plot spectrogram
-        plot_spectrogram(ax0, data_oi, self.data.raw_rate, self.t0)
-
-        ax0.fill_between(
-            np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
-            q50 - self.config.minimal_bandwidth / 2,
-            q50 + self.config.minimal_bandwidth / 2,
-            color=ps.black,
-            lw=0,
-            alpha=0.2,
+        _ = plot_spectrogram(
+            ax0,
+            data_oi,
+            self.data.raw_rate,
+            self.t0,
+            [np.max(self.frequency) - 200, np.max(self.frequency) + 200]
         )
 
-        ax0.fill_between(
-            np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
-            search_lower,
-            search_upper,
-            color=ps.black,
-            lw=0,
-            alpha=0.2,
-        )
+        # ax0.fill_between(
+        #     np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
+        #     q50 - self.config.minimal_bandwidth / 2,
+        #     q50 + self.config.minimal_bandwidth / 2,
+        #     color=ps.black,
+        #     lw=1,
+        #     ls="dashed",
+        #     alpha=0.5,
+        # )
+
+        # ax0.fill_between(
+        #     np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
+        #     search_lower,
+        #     search_upper,
+        #     color=ps.black,
+        #     lw=1,
+        #     ls="dashed",
+        #     alpha=0.5,
+        # )
+        # ax0.axhline(q50, spec_times[0], spec_times[-1],
+        #             color=ps.gblue1, lw=2, ls="dashed")
+        # ax0.axhline(q50 + self.search_frequency,
+        #             spec_times[0], spec_times[-1],
+        #             color=ps.gblue2, lw=2, ls="dashed")
 
         for chirp in chirps:
             ax0.scatter(
-                chirp, np.median(self.frequency), c=ps.black, marker="x"
+                chirp, np.median(self.frequency) + 150, c=ps.black, marker="v"
             )
 
         # plot waveform of filtered signal
-        ax1.plot(self.time, norm(self.baseline))
+        ax1.plot(self.time, self.baseline * waveform_scaler,
+                 c=ps.gray, lw=2, alpha=0.5)
+        ax1.plot(self.time, self.baseline_envelope_unfiltered *
+                 waveform_scaler, c=ps.gblue1, lw=2, label="baseline envelope")
 
         # plot waveform of filtered search signal
-        ax2.plot(self.time, norm(self.search))
+        ax2.plot(self.time, self.search * waveform_scaler,
+                 c=ps.gray, lw=2, alpha=0.5)
+        ax2.plot(self.time, self.search_envelope_unfiltered *
+                 waveform_scaler, c=ps.gblue2, lw=2, label="search envelope")
 
         # plot baseline instantaneous frequency
-        ax3.plot(self.frequency_time, self.frequency)
+        ax3.plot(self.frequency_time, self.frequency,
+                 c=ps.gblue3, lw=2, label="baseline inst. freq.")
 
         # plot filtered and rectified envelope
-        ax4.plot(self.time, self.baseline_envelope)
+        ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=2)
         ax4.scatter(
             (self.time)[self.baseline_peaks],
             self.baseline_envelope[self.baseline_peaks],
             c=ps.red,
+            zorder=10,
         )
 
         # plot envelope of search signal
-        ax5.plot(self.time, self.search_envelope)
+        ax5.plot(self.time, self.search_envelope, c=ps.gblue2, lw=2)
         ax5.scatter(
             (self.time)[self.search_peaks],
             self.search_envelope[self.search_peaks],
             c=ps.red,
+            zorder=10,
         )
 
         # plot filtered instantaneous frequency
-        ax6.plot(self.frequency_time, self.frequency_filtered)
+        ax6.plot(self.frequency_time,
+                 self.frequency_filtered, c=ps.gblue3, lw=2)
         ax6.scatter(
             self.frequency_time[self.frequency_peaks],
             self.frequency_filtered[self.frequency_peaks],
             c=ps.red,
+            zorder=10,
         )
-        ax0.set_ylim(
-            np.max(self.frequency) - 200, top=np.max(self.frequency) + 400
-        )
-        ax0.set_title("Spectrogram of raw data")
-        ax1.set_title("Extracted features")
-        ax4.set_title("Filtered and rectified features")
-        ax6.set_xlabel("Time [s]")
 
-        ax0.set_xlim(0, 5)
+        ax0.set_ylabel("frequency [Hz]")
+        ax1.set_ylabel("a.u.")
+        ax2.set_ylabel("a.u.")
+        ax3.set_ylabel("Hz")
+        ax5.set_ylabel("a.u.")
+        ax7.set_xlabel("time [s]")
+
+        ps.hide_xax(ax0)
+        ps.hide_xax(ax1)
+        ps.hide_xax(ax2)
+        ps.hide_xax(ax3)
+        ps.hide_xax(ax4)
+        ps.hide_xax(ax5)
+        ps.hide_xax(ax6)
+        ps.hide_yax(ax7)
+
+        ps.letter_subplots([ax0, ax1, ax4], xoffset=-0.21)
+
+        ax7.set_xticks(np.arange(0, 5.5, 1))
+        ax7.spines.bottom.set_bounds((0, 5))
+
+        ax0.set_ymargin(0)
+        plt.subplots_adjust(left=0.19, right=0.99,
+                            top=0.98, bottom=0.08, hspace=0.15)
+        fig.align_labels()
+        ax0.autoscale(enable=True)
 
         if plot == "show":
             plt.show()
@@ -187,13 +241,17 @@ class PlotBuffer:
                 self.config.outputdir + self.data.datapath.split("/")[-2] + "/"
             )
 
-            plt.savefig(f"{out}{self.track_id}_{self.t0}.pdf")
+            plt.savefig(f"{out}{self.track_id}_{self.t0_old}.pdf")
             plt.close()
 
 
 def plot_spectrogram(
-    axis, signal: np.ndarray, samplerate: float, window_start_seconds: float
-) -> None:
+    axis,
+    signal: np.ndarray,
+    samplerate: float,
+    window_start_seconds: float,
+    ylims: list[float]
+) -> np.ndarray:
     """
     Plot a spectrogram of a signal.
 
@@ -219,18 +277,24 @@ def plot_spectrogram(
         overlap_frac=0.5,
     )
 
+    fmask = np.zeros(spec_freqs.shape, dtype=bool)
+    fmask[(spec_freqs > ylims[0]) & (spec_freqs < ylims[1])] = True
+
     axis.imshow(
-        decibel(spec_power),
+        decibel(spec_power[fmask, :]),
         extent=[
             spec_times[0] + window_start_seconds,
             spec_times[-1] + window_start_seconds,
-            spec_freqs[0],
-            spec_freqs[-1],
+            spec_freqs[fmask][0],
+            spec_freqs[fmask][-1],
         ],
         aspect="auto",
         origin="lower",
         interpolation="gaussian",
+        alpha=1,
     )
+    axis.use_sticky_edges = False
+    return spec_times
 
 
 def extract_frequency_bands(
@@ -477,16 +541,16 @@ def main(datapath: str, plot: str) -> None:
     raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
 
     # good chirp times for data: 2022-06-02-10_00
-    # window_start_seconds = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
-    # window_duration_seconds = 60 * data.raw_rate
+    window_start_index = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
+    window_duration_index = 60 * data.raw_rate
 
     #     t0 = 0
     #     dt = data.raw.shape[0]
     # window_start_seconds = (23495 + ((28336-23495)/3)) * data.raw_rate
     # window_duration_seconds = (28336 - 23495) * data.raw_rate
 
-    window_start_index = 0
-    window_duration_index = data.raw.shape[0]
+    # window_start_index = 0
+    # window_duration_index = data.raw.shape[0]
 
     # generate starting points of rolling window
     window_start_indices = np.arange(
@@ -648,11 +712,12 @@ def main(datapath: str, plot: str) -> None:
                 # band envelope correspond to troughs in the baseline envelope
                 # during chirps
 
-                search_envelope = envelope(
+                search_envelope_unfiltered = envelope(
                     signal=searchband,
                     samplerate=data.raw_rate,
                     cutoff_frequency=config.search_envelope_cutoff,
                 )
+                search_envelope = search_envelope_unfiltered
 
                 # compute instantaneous frequency of the baseline band to find
                 # anomalies during a chirp, i.e. a frequency jump upwards or
@@ -706,8 +771,10 @@ def main(datapath: str, plot: str) -> None:
                 )
                 current_raw_time = current_raw_time[no_edges]
                 baselineband = baselineband[no_edges]
+                baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges]
                 searchband = searchband[no_edges]
                 baseline_envelope = baseline_envelope[no_edges]
+                search_envelope_unfiltered = search_envelope_unfiltered[no_edges]
                 search_envelope = search_envelope[no_edges]
 
                 # get instantaneous frequency withoup edges
@@ -822,11 +889,13 @@ def main(datapath: str, plot: str) -> None:
                         track_id=track_id,
                         data=data,
                         time=current_raw_time,
+                        baseline_envelope_unfiltered=baseline_envelope_unfiltered,
                         baseline=baselineband,
                         baseline_envelope=baseline_envelope,
                         baseline_peaks=baseline_peak_indices,
                         search_frequency=search_frequency,
                         search=searchband,
+                        search_envelope_unfiltered=search_envelope_unfiltered,
                         search_envelope=search_envelope,
                         search_peaks=search_peak_indices,
                         frequency_time=baseline_frequency_time,
@@ -864,9 +933,8 @@ def main(datapath: str, plot: str) -> None:
             multiwindow_ids.append(track_id)
 
             logger.info(
-                "Found %d chirps for fish %d"
-                % len(multielectrode_chirps_validated),
-                track_id,
+                f"Found {len(multielectrode_chirps_validated)}"
+                f" chirps for fish {track_id} in this window!"
             )
             # if chirps are detected and the plot flag is set, plot the
             # chirps, otheswise try to delete the buffer if it exists
@@ -930,7 +998,7 @@ def main(datapath: str, plot: str) -> None:
 
 if __name__ == "__main__":
     # datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-05-13-10_00/"
-    # datapath = "../data/2022-06-02-10_00/"
+    datapath = "../data/2022-06-02-10_00/"
     # datapath = "/home/weygoldt/Data/uni/efishdata/2016-colombia/fishgrid/2016-04-09-22_25/"
-    datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/"
-    main(datapath, plot="show")
+    # datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/"
+    main(datapath, plot="save")
diff --git a/code/modules/plotstyle.py b/code/modules/plotstyle.py
index 9e382a7..2325f62 100644
--- a/code/modules/plotstyle.py
+++ b/code/modules/plotstyle.py
@@ -30,10 +30,14 @@ def PlotStyle() -> None:
         purple = "#cba6f7"
         pink = "#f5c2e7"
         lavender = "#b4befe"
+        gblue1 = "#8cb8ff"
+        gblue2 = "#7cdcdc"
+        gblue3 = "#82e896"
 
         @classmethod
         def lims(cls, track1, track2):
-            """Helper function to get frequency y axis limits from two fundamental frequency tracks.
+            """Helper function to get frequency y axis limits from two
+            fundamental frequency tracks.
 
             Args:
                 track1 (array): First track
@@ -91,6 +95,16 @@ def PlotStyle() -> None:
             ax.tick_params(left=False, labelleft=False)
             ax.patch.set_visible(False)
 
+        @classmethod
+        def hide_xax(cls, ax):
+            ax.xaxis.set_visible(False)
+            ax.spines["bottom"].set_visible(False)
+
+        @classmethod
+        def hide_yax(cls, ax):
+            ax.yaxis.set_visible(False)
+            ax.spines["left"].set_visible(False)
+
         @classmethod
         def set_boxplot_color(cls, bp, color):
             plt.setp(bp["boxes"], color=color)
@@ -216,8 +230,8 @@ def PlotStyle() -> None:
     plt.rc("figure", titlesize=BIGGER_SIZE)  # fontsize of the figure title
 
     plt.rcParams["image.cmap"] = 'cmo.haline'
-    # plt.rcParams["axes.xmargin"] = 0.1
-    # plt.rcParams["axes.ymargin"] = 0.15
+    plt.rcParams["axes.xmargin"] = 0.05
+    plt.rcParams["axes.ymargin"] = 0.1
     plt.rcParams["axes.titlelocation"] = "left"
     plt.rcParams["axes.titlesize"] = BIGGER_SIZE
     # plt.rcParams["axes.titlepad"] = -10
@@ -230,9 +244,9 @@ def PlotStyle() -> None:
     plt.rcParams["legend.borderaxespad"] = 0.5
     plt.rcParams["legend.fancybox"] = False
 
-    # specify the custom font to use
-    plt.rcParams["font.family"] = "sans-serif"
-    plt.rcParams["font.sans-serif"] = "Helvetica Now Text"
+    # # specify the custom font to use
+    # plt.rcParams["font.family"] = "sans-serif"
+    # plt.rcParams["font.sans-serif"] = "Helvetica Now Text"
 
     # dark mode modifications
     plt.rcParams["boxplot.flierprops.color"] = white
@@ -271,7 +285,7 @@ def PlotStyle() -> None:
     plt.rcParams["ytick.color"] = gray  # color of the ticks
     plt.rcParams["grid.color"] = dark_gray  # grid color
     plt.rcParams["figure.facecolor"] = black    # figure face color
-    plt.rcParams["figure.edgecolor"] = "#555169"   # figure edge color
+    plt.rcParams["figure.edgecolor"] = black   # figure edge color
     plt.rcParams["savefig.facecolor"] = black  # figure face color when saving
 
     return style