# GP2023_chirp_detection/chirp_instantaneous_freq/test_parameters.py
# 2023-04-13 15:14:10 +02:00
#
# 151 lines
# 4.4 KiB
# Python

import matplotlib.pyplot as plt
import numpy as np
from filters import bandpass_filter, inst_freq, instantaneous_frequency
from fish_signal import chirps, wavefish_eods
from IPython import embed
def switch_test(test, defaultparams, testparams):
if test == "width":
defaultparams["chirp_width"] = testparams["chirp_width"]
key = "chirp_width"
elif test == "size":
defaultparams["chirp_size"] = testparams["chirp_size"]
key = "chirp_size"
elif test == "kurtosis":
defaultparams["chirp_kurtosis"] = testparams["chirp_kurtosis"]
key = "chirp_kurtosis"
elif test == "contrast":
defaultparams["chirp_contrast"] = testparams["chirp_contrast"]
key = "chirp_contrast"
else:
raise ValueError("Test not recognized")
return key, defaultparams
def extract_dict(dict, index):
    """Return a new dict holding element ``index`` of every value in ``dict``.

    Each value must support indexing with ``index`` (e.g. a list or array).
    The parameter is named ``dict`` for caller compatibility, so it shadows
    the builtin inside this function.
    """
    mapping = dict
    picked = {}
    for name in mapping:
        picked[name] = mapping[name][index]
    return picked
# Parameter names accepted by ``test`` (and by ``switch_test``).
_TEST_NAMES = ("width", "size", "kurtosis", "contrast")


def _simulate_chirp_signal(params, eodf, samplerate, duration, chirp_times):
    """Simulate an amplitude-modulated wave-fish EOD with identical chirps.

    Every chirp at ``chirp_times`` uses the scalar values stored under
    ``chirp_size``, ``chirp_width``, ``chirp_kurtosis`` and
    ``chirp_contrast`` in ``params``.
    """
    nchirps = len(chirp_times)
    chirp_trace, ampmod = chirps(
        eodf,
        samplerate,
        duration,
        chirp_times,
        np.ones(nchirps) * params["chirp_size"],
        np.ones(nchirps) * params["chirp_width"],
        np.ones(nchirps) * params["chirp_kurtosis"],
        np.ones(nchirps) * params["chirp_contrast"],
    )
    signal = wavefish_eods(
        fish="Alepto",
        frequency=chirp_trace,
        samplerate=samplerate,
        duration=duration,
        phase0=0.0,
        noise_std=0.05,
    )
    # chirps() also returns an amplitude modulation; apply it to the EOD.
    return signal * ampmod


def _bandpass_inst_freq(signal, samplerate, eodf, half_width):
    """Band-pass ``signal`` to eodf +/- half_width and return its inst. frequency."""
    filtered = bandpass_filter(
        signal, samplerate, eodf - half_width, eodf + half_width
    )
    return inst_freq(filtered, samplerate)


def test(test1, test2, resolution=10):
    """Sweep two chirp parameters against each other and plot the results.

    For every combination of ``test1`` and ``test2`` values, a chirping
    wave-fish EOD is simulated. It is band-pass filtered with a wide
    (eodf +/- 200) and a tight (eodf +/- 10) band. The instantaneous
    frequency of both filtered signals is plotted in one panel of a
    ``resolution`` x ``resolution`` grid: rows follow ``test1`` values and
    columns follow ``test2`` values. Parameters that are not swept keep
    their default value.

    NOTE(review): because the function is named ``test`` and the file is
    named ``test_parameters.py``, pytest will try to collect it as a test.
    Confirm that this is not run under pytest.

    Parameters
    ----------
    test1, test2 : str
        One of "width", "size", "kurtosis", "contrast".
    resolution : int
        Number of values tested per swept parameter (grid side length).

    Raises
    ------
    ValueError
        If a test name is not recognised or ``resolution`` is below 1.
    """
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    for name, label in ((test1, "Test1"), (test2, "Test2")):
        if name not in _TEST_NAMES:
            raise ValueError(f"{label} not recognized")
    if resolution < 1:
        raise ValueError("resolution must be at least 1")

    # Define the parameters for the chirp simulations
    ntest = resolution
    defaultparams = dict(
        chirp_size=np.ones(ntest) * 100,
        chirp_width=np.ones(ntest) * 0.1,
        chirp_kurtosis=np.ones(ntest) * 1.0,
        chirp_contrast=np.ones(ntest) * 0.5,
    )
    testparams = dict(
        chirp_width=np.linspace(0.01, 0.2, ntest),
        chirp_size=np.linspace(50, 300, ntest),
        chirp_kurtosis=np.linspace(0.5, 1.5, ntest),
        chirp_contrast=np.linspace(0.01, 1.0, ntest),
    )
    key1, chirp_params = switch_test(test1, defaultparams, testparams)
    key2, chirp_params = switch_test(test2, chirp_params, testparams)

    # Simulation and filter settings
    eodf = 500
    samplerate = 20000
    duration = 2
    chirp_times = [0.5, 1, 1.5]
    wide_cutoffs = 200
    tight_cutoffs = 10

    # squeeze=False keeps ``axs`` a 2-D array even when ntest == 1.
    fig, axs = plt.subplots(
        ntest,
        ntest,
        figsize=(10, 10),
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    for row, value1 in enumerate(chirp_params[key1]):
        for col, value2 in enumerate(chirp_params[key2]):
            # Non-swept arrays are constant, so any index yields the
            # defaults; the two swept keys are then overwritten.
            params = extract_dict(chirp_params, col)
            params[key1] = value1
            params[key2] = value2

            signal = _simulate_chirp_signal(
                params, eodf, samplerate, duration, chirp_times
            )
            wide_frequency = _bandpass_inst_freq(
                signal, samplerate, eodf, wide_cutoffs
            )
            tight_frequency = _bandpass_inst_freq(
                signal, samplerate, eodf, tight_cutoffs
            )

            # Drop samples where the wide-band estimate is exactly zero.
            # Presumably these are undefined/padded samples from inst_freq
            # (TODO confirm). Assumes both outputs have the same length.
            bool_mask = wide_frequency != 0
            ax = axs[row, col]
            ax.plot(wide_frequency[bool_mask])
            ax.plot(tight_frequency[bool_mask])

    # Figure-level labels are loop-invariant; set them once.
    fig.supylabel(key1)
    fig.supxlabel(key2)
    plt.show()
def main():
    """Entry point: sweep chirp contrast against chirp kurtosis."""
    sweep = ("contrast", "kurtosis")
    test(*sweep)


if __name__ == "__main__":
    main()