added docstrings

This commit is contained in:
weygoldt 2023-01-11 22:38:11 +01:00
parent e174979e9c
commit 21ff35c6ac

View File

@ -9,14 +9,56 @@ from scipy.ndimage import gaussian_filter1d
from thunderfish.dataloader import DataLoader from thunderfish.dataloader import DataLoader
from thunderfish.powerspectrum import spectrogram, decibel from thunderfish.powerspectrum import spectrogram, decibel
from modules.filters import bandpass_filter, envelope, highpass_filter from modules.filters import bandpass_filter, envelope, highpass_filter
class LoadData:
"""
Attributes
----------
data : DataLoader object containing raw data
samplerate : sampling rate of raw data
time : array of time for tracked fundamental frequency
freq : array of fundamental frequency
idx : array of indices to access time array
ident : array of identifiers for each tracked fundamental frequency
ids : array of unique identifiers exluding NaNs
"""
def __init__(self, datapath: str) -> None:
# load raw data
file = os.path.join(datapath, "traces-grid1.raw")
self.data = DataLoader(file, 60.0, 0, channel=-1)
self.samplerate = self.data.samplerate
# load wavetracker files
self.time = np.load(datapath + "times.npy", allow_pickle=True)
self.freq = np.load(datapath + "fund_v.npy", allow_pickle=True)
self.idx = np.load(datapath + "idx_v.npy", allow_pickle=True)
self.ident = np.load(datapath + "ident_v.npy", allow_pickle=True)
self.ids = np.unique(self.ident[~np.isnan(self.ident)])
def instantaneos_frequency( def instantaneos_frequency(
signal: np.ndarray, samplerate: int signal: np.ndarray, samplerate: int
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the instantaneous frequency of a signal.
Parameters
----------
signal : np.ndarray
Signal to compute the instantaneous frequency from.
samplerate : int
Samplerate of the signal.
Returns
-------
tuple[np.ndarray, np.ndarray]
"""
# calculate instantaneos frequency with zero crossings # calculate instantaneos frequency with zero crossings
roll_signal = np.roll(signal, shift=1) roll_signal = np.roll(signal, shift=1)
time_signal = np.arange(len(signal)) / samplerate time_signal = np.arange(len(signal)) / samplerate
@ -44,7 +86,19 @@ def instantaneos_frequency(
def plot_spectrogram(axis, signal: np.ndarray, samplerate: float) -> None: def plot_spectrogram(axis, signal: np.ndarray, samplerate: float) -> None:
"""
Plot a spectrogram of a signal.
Parameters
----------
axis : matplotlib axis
Axis to plot the spectrogram on.
signal : np.ndarray
Signal to plot the spectrogram from.
samplerate : float
Samplerate of the signal.
"""
# compute spectrogram # compute spectrogram
spec_power, spec_freqs, spec_times = spectrogram( spec_power, spec_freqs, spec_times = spectrogram(
signal, signal,
@ -65,7 +119,26 @@ def plot_spectrogram(axis, signal: np.ndarray, samplerate: float) -> None:
def double_bandpass( def double_bandpass(
data: DataLoader, samplerate: int, freqs: np.ndarray, search_freq: float data: DataLoader, samplerate: int, freqs: np.ndarray, search_freq: float
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
"""
Apply a bandpass filter to the baseline of a signal and a second bandpass
filter above or below the baseline.
Parameters
----------
data : DataLoader
Data to apply the filter to.
samplerate : int
Samplerate of the signal.
freqs : np.ndarray
Tracked fundamental frequencies of the signal.
search_freq : float
Frequency to search for above or below the baseline.
Returns
-------
tuple[np.ndarray, np.ndarray]
"""
# compute boundaries to filter baseline # compute boundaries to filter baseline
q25, q75 = np.percentile(freqs, [25, 75]) q25, q75 = np.percentile(freqs, [25, 75])
@ -98,6 +171,9 @@ def main(datapath: str) -> None:
ident = np.load(datapath + "ident_v.npy", allow_pickle=True) ident = np.load(datapath + "ident_v.npy", allow_pickle=True)
# set time window # <------------------------ Iterate through windows here # set time window # <------------------------ Iterate through windows here
window_duration = 60 * data.samplerate
window_overlap = 0.3
t0 = 3 * 60 * 60 + 6 * 60 + 43.5 t0 = 3 * 60 * 60 + 6 * 60 + 43.5
dt = 60 dt = 60
start_index = t0 * data.samplerate start_index = t0 * data.samplerate
@ -216,13 +292,15 @@ def main(datapath: str) -> None:
np.arange(len(baseline)) / data.samplerate, baseline_envelope np.arange(len(baseline)) / data.samplerate, baseline_envelope
) )
axs[5].plot(np.arange(len(baseline)) / data.samplerate, search_envelope) axs[5].plot(np.arange(len(baseline)) /
data.samplerate, search_envelope)
axs[6].plot(baseline_freq_time, np.abs(inst_freq_filtered)) axs[6].plot(baseline_freq_time, np.abs(inst_freq_filtered))
# detect peaks baseline_enelope # detect peaks baseline_enelope
prominence = iqr(baseline_envelope) prominence = iqr(baseline_envelope)
baseline_peaks, _ = find_peaks(baseline_envelope, prominence=prominence) baseline_peaks, _ = find_peaks(
baseline_envelope, prominence=prominence)
axs[4].scatter( axs[4].scatter(
(np.arange(len(baseline)) / data.samplerate)[baseline_peaks], (np.arange(len(baseline)) / data.samplerate)[baseline_peaks],
baseline_envelope[baseline_peaks], baseline_envelope[baseline_peaks],
@ -245,8 +323,6 @@ def main(datapath: str) -> None:
c="red", c="red",
) )
#
axs[0].set_title("Spectrogram") axs[0].set_title("Spectrogram")
axs[1].set_title("Fitered baseline instanenous frequency") axs[1].set_title("Fitered baseline instanenous frequency")
axs[2].set_title("Fitered baseline") axs[2].set_title("Fitered baseline")