import numpy as np import matplotlib.pyplot as plt from fish_signal import chirps, wavefish_eods from filters import bandpass_filter, instantaneous_frequency, inst_freq from IPython import embed def switch_test(test, defaultparams, testparams): if test == 'width': defaultparams['chirp_width'] = testparams['chirp_width'] key = 'chirp_width' elif test == 'size': defaultparams['chirp_size'] = testparams['chirp_size'] key = 'chirp_size' elif test == 'kurtosis': defaultparams['chirp_kurtosis'] = testparams['chirp_kurtosis'] key = 'chirp_kurtosis' elif test == 'contrast': defaultparams['chirp_contrast'] = testparams['chirp_contrast'] key = 'chirp_contrast' else: raise ValueError("Test not recognized") return key, defaultparams def extract_dict(dict, index): return {key: value[index] for key, value in dict.items()} def main(test1, test2, resolution=10): assert test1 in ['width', 'size', 'kurtosis', 'contrast'], "Test1 not recognized" assert test2 in ['width', 'size', 'kurtosis', 'contrast'], "Test2 not recognized" # Define the parameters for the chirp simulations ntest = resolution defaultparams = dict( chirp_size = np.ones(ntest) * 100, chirp_width = np.ones(ntest) * 0.1, chirp_kurtosis = np.ones(ntest) * 1.0, chirp_contrast = np.ones(ntest) * 0.5, ) testparams = dict( chirp_width = np.linspace(0.01, 0.2, ntest), chirp_size = np.linspace(50, 300, ntest), chirp_kurtosis = np.linspace(0.5, 1.5, ntest), chirp_contrast = np.linspace(0.01, 1.0, ntest), ) key1, chirp_params = switch_test(test1, defaultparams, testparams) key2, chirp_params = switch_test(test2, chirp_params, testparams) # make the chirp trace eodf = 500 samplerate = 20000 duration = 2 chirp_times = [0.5, 1, 1.5] wide_cutoffs = 200 tight_cutoffs = 10 distances = np.full((ntest, ntest), np.nan) fig, axs = plt.subplots(ntest, ntest, figsize = (10, 10), sharex = True, sharey = True) axs = axs.flatten() iter0 = 0 for iter1, test1_param in enumerate(chirp_params[key1]): for iter2, test2_param in enumerate(chirp_params[key2]): # get the chirp parameters for the current test inner_chirp_params = extract_dict(chirp_params, iter2) inner_chirp_params[key1] = test1_param inner_chirp_params[key2] = test2_param # make the chirp trace for the current chirp parameters sizes = np.ones(len(chirp_times)) * inner_chirp_params['chirp_size'] widths = np.ones(len(chirp_times)) * inner_chirp_params['chirp_width'] kurtosis = np.ones(len(chirp_times)) * inner_chirp_params['chirp_kurtosis'] contrast = np.ones(len(chirp_times)) * inner_chirp_params['chirp_contrast'] # make the chirp trace chirp_trace, ampmod = chirps(eodf, samplerate, duration, chirp_times, sizes, widths, kurtosis, contrast) signal = wavefish_eods( fish="Alepto", frequency=chirp_trace, samplerate=samplerate, duration=duration, phase0=0.0, noise_std=0.05 ) signal = signal * ampmod # apply broadband filter wide_signal = bandpass_filter(signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs) tight_signal = bandpass_filter(signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs) # get the instantaneous frequency wide_frequency = inst_freq(wide_signal, samplerate) tight_frequency = inst_freq(tight_signal, samplerate) bool_mask = wide_frequency != 0 axs[iter0].plot(wide_frequency[bool_mask]) axs[iter0].plot(tight_frequency[bool_mask]) fig.supylabel(key1) fig.supxlabel(key2) iter0 += 1 fig, ax = plt.subplots() ax.imshow(distances, cmap = 'jet') plt.show() if __name__ == "__main__": main('width', 'size')