refactoring finished for now
This commit is contained in:
@@ -1,5 +1,59 @@
|
||||
import numpy as np
|
||||
from typing import List, Any
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
|
||||
def instantaneous_frequency(
|
||||
signal: np.ndarray,
|
||||
samplerate: int,
|
||||
smoothing_window: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Compute the instantaneous frequency of a signal that is approximately
|
||||
sinusoidal and symmetric around 0.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : np.ndarray
|
||||
Signal to compute the instantaneous frequency from.
|
||||
samplerate : int
|
||||
Samplerate of the signal.
|
||||
smoothing_window : int
|
||||
Window size for the gaussian filter.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[np.ndarray, np.ndarray]
|
||||
|
||||
"""
|
||||
# calculate instantaneous frequency with zero crossings
|
||||
roll_signal = np.roll(signal, shift=1)
|
||||
time_signal = np.arange(len(signal)) / samplerate
|
||||
period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][
|
||||
1:-1
|
||||
]
|
||||
|
||||
upper_bound = np.abs(signal[period_index])
|
||||
lower_bound = np.abs(signal[period_index - 1])
|
||||
upper_time = np.abs(time_signal[period_index])
|
||||
lower_time = np.abs(time_signal[period_index - 1])
|
||||
|
||||
# create ratio
|
||||
lower_ratio = lower_bound / (lower_bound + upper_bound)
|
||||
|
||||
# appy to time delta
|
||||
time_delta = upper_time - lower_time
|
||||
true_zero = lower_time + lower_ratio * time_delta
|
||||
|
||||
# create new time array
|
||||
instantaneous_frequency_time = true_zero[:-1] + 0.5 * np.diff(true_zero)
|
||||
|
||||
# compute frequency
|
||||
instantaneous_frequency = gaussian_filter1d(
|
||||
1 / np.diff(true_zero), smoothing_window
|
||||
)
|
||||
|
||||
return instantaneous_frequency_time, instantaneous_frequency
|
||||
|
||||
|
||||
def purge_duplicates(
|
||||
@@ -64,7 +118,7 @@ def purge_duplicates(
|
||||
|
||||
|
||||
def group_timestamps(
|
||||
sublists: List[List[float]], n: int, threshold: float
|
||||
sublists: List[List[float]], at_least_in: int, difference_threshold: float
|
||||
) -> List[float]:
|
||||
"""
|
||||
Groups timestamps that are less than `threshold` milliseconds apart from
|
||||
@@ -100,7 +154,7 @@ def group_timestamps(
|
||||
|
||||
# Group timestamps that are less than threshold milliseconds apart
|
||||
for i in range(1, len(timestamps)):
|
||||
if timestamps[i] - timestamps[i - 1] < threshold:
|
||||
if timestamps[i] - timestamps[i - 1] < difference_threshold:
|
||||
current_group.append(timestamps[i])
|
||||
else:
|
||||
groups.append(current_group)
|
||||
@@ -111,7 +165,7 @@ def group_timestamps(
|
||||
# Retain only groups that contain at least n timestamps
|
||||
final_groups = []
|
||||
for group in groups:
|
||||
if len(group) >= n:
|
||||
if len(group) >= at_least_in:
|
||||
final_groups.append(group)
|
||||
|
||||
# Calculate the mean of each group
|
||||
|
||||
@@ -3,8 +3,8 @@ import numpy as np
|
||||
|
||||
|
||||
def bandpass_filter(
|
||||
data: np.ndarray,
|
||||
rate: float,
|
||||
signal: np.ndarray,
|
||||
samplerate: float,
|
||||
lowf: float,
|
||||
highf: float,
|
||||
) -> np.ndarray:
|
||||
@@ -12,7 +12,7 @@ def bandpass_filter(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray
|
||||
signal : np.ndarray
|
||||
The data to be filtered
|
||||
rate : float
|
||||
The sampling rate
|
||||
@@ -26,21 +26,22 @@ def bandpass_filter(
|
||||
np.ndarray
|
||||
The filtered data
|
||||
"""
|
||||
sos = butter(2, (lowf, highf), "bandpass", fs=rate, output="sos")
|
||||
fdata = sosfiltfilt(sos, data)
|
||||
return fdata
|
||||
sos = butter(2, (lowf, highf), "bandpass", fs=samplerate, output="sos")
|
||||
filtered_signal = sosfiltfilt(sos, signal)
|
||||
|
||||
return filtered_signal
|
||||
|
||||
|
||||
def highpass_filter(
|
||||
data: np.ndarray,
|
||||
rate: float,
|
||||
signal: np.ndarray,
|
||||
samplerate: float,
|
||||
cutoff: float,
|
||||
) -> np.ndarray:
|
||||
"""Highpass filter a signal.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray
|
||||
signal : np.ndarray
|
||||
The data to be filtered
|
||||
rate : float
|
||||
The sampling rate
|
||||
@@ -52,14 +53,15 @@ def highpass_filter(
|
||||
np.ndarray
|
||||
The filtered data
|
||||
"""
|
||||
sos = butter(2, cutoff, "highpass", fs=rate, output="sos")
|
||||
fdata = sosfiltfilt(sos, data)
|
||||
return fdata
|
||||
sos = butter(2, cutoff, "highpass", fs=samplerate, output="sos")
|
||||
filtered_signal = sosfiltfilt(sos, signal)
|
||||
|
||||
return filtered_signal
|
||||
|
||||
|
||||
def lowpass_filter(
|
||||
data: np.ndarray,
|
||||
rate: float,
|
||||
signal: np.ndarray,
|
||||
samplerate: float,
|
||||
cutoff: float
|
||||
) -> np.ndarray:
|
||||
"""Lowpass filter a signal.
|
||||
@@ -78,21 +80,25 @@ def lowpass_filter(
|
||||
np.ndarray
|
||||
The filtered data
|
||||
"""
|
||||
sos = butter(2, cutoff, "lowpass", fs=rate, output="sos")
|
||||
fdata = sosfiltfilt(sos, data)
|
||||
return fdata
|
||||
sos = butter(2, cutoff, "lowpass", fs=samplerate, output="sos")
|
||||
filtered_signal = sosfiltfilt(sos, signal)
|
||||
|
||||
return filtered_signal
|
||||
|
||||
|
||||
def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
|
||||
def envelope(signal: np.ndarray,
|
||||
samplerate: float,
|
||||
cutoff_frequency: float
|
||||
) -> np.ndarray:
|
||||
"""Calculate the envelope of a signal using a lowpass filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray
|
||||
signal : np.ndarray
|
||||
The signal to calculate the envelope of
|
||||
rate : float
|
||||
samplingrate : float
|
||||
The sampling rate of the signal
|
||||
freq : float
|
||||
cutoff_frequency : float
|
||||
The cutoff frequency of the lowpass filter
|
||||
|
||||
Returns
|
||||
@@ -100,6 +106,7 @@ def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
|
||||
np.ndarray
|
||||
The envelope of the signal
|
||||
"""
|
||||
sos = butter(2, freq, "lowpass", fs=rate, output="sos")
|
||||
envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(data))
|
||||
sos = butter(2, cutoff_frequency, "lowpass", fs=samplerate, output="sos")
|
||||
envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(signal))
|
||||
|
||||
return envelope
|
||||
|
||||
Reference in New Issue
Block a user