nonlinearbaseline2025/lifsuscept.py

156 lines
5.6 KiB
Python

import numpy as np
import mpmath as mp
import matplotlib.pyplot as plt
from pathlib import Path
from plotstyle import plot_style
# Directory where computed LIF susceptibilities are cached; relative path,
# so it is resolved against the current working directory when used.
sims_path = Path('data', 'simulations')
"""
LIF code from Maria Schlungbaum, Lindner lab, 2024
LIF model in dimensionless units: dv/dt = -v + mu + sqrt(2D)*xi
v: membrane voltage
mu: mean input voltage
D: noise intensity
xi: white Gaussian noise
tau_mem = 1 (membrane time constant; omitted because time is measured in units of tau_mem)
tau_ref: refractory period
vT: threshold voltage
vR: reset voltage
"""
def firingrate(mu, D, tau_ref, vR, vT):
    """Stationary firing rate of the LIF neuron driven by white noise.

    Evaluates

        r0 = 1 / (tau_ref + sqrt(pi) * int_{x_T}^{x_R} exp(x^2) erfc(x) dx)

    with x_T = (mu - vT)/sqrt(2D) and x_R = (mu - vR)/sqrt(2D).

    Parameters
    ----------
    mu: float
        Mean input.
    D: float
        Noise intensity.
    tau_ref: float
        Refractory period.
    vR: float
        Reset voltage.
    vT: float
        Threshold voltage.

    Returns
    -------
    r0: float
        Stationary firing rate (in units of 1/tau_mem).
    """
    x_start = (mu - vT)/mp.sqrt(2.0*D)
    x_end = (mu - vR)/mp.sqrt(2.0*D)
    # Adaptive mpmath quadrature instead of a fixed-step (dx=1e-4) left
    # Riemann sum: the latter is only first-order accurate, drops the last
    # partial interval, and needs ~1e5 mpmath evaluations for small D.
    # mpmath numbers have an unbounded exponent range, so exp(x**2) cannot
    # overflow even for strongly negative x (subthreshold, weak noise).
    integral = mp.quad(lambda x: mp.exp(x**2)*mp.erfc(x), [x_start, x_end])
    r0 = 1.0/(tau_ref + mp.sqrt(mp.pi)*integral)
    return float(r0)
def susceptibility1(omega, r0, mu, D, tau_ref, vR, vT):
    """Linear susceptibility chi1(omega) of the LIF neuron.

    Computes

        r0 i*omega / (sqrt(D) (i*omega - 1))
        * [D_{iw-1}(z_T) - e^delta D_{iw-1}(z_R)]
        / [D_{iw}(z_T) - e^delta e^{i omega tau_ref} D_{iw}(z_R)]

    where D_nu is the parabolic cylinder function (mp.pcfd),
    z_T = (mu - vT)/sqrt(D), z_R = (mu - vR)/sqrt(D), and
    delta = (vR^2 - vT^2 + 2 mu (vT - vR)) / (4 D).

    Parameters
    ----------
    omega: float
        Angular frequency (2 pi f, in units of 1/tau_mem).
    r0: float
        Stationary firing rate, e.g. from firingrate().
    mu, D, tau_ref, vR, vT: float
        Mean input, noise intensity, refractory period, reset and
        threshold voltage.

    Returns
    -------
    chi1: mpmath complex
        Linear susceptibility at `omega`.
    """
    iw = omega*1.0j
    sqrt_d = mp.sqrt(D)
    z_thresh = (mu - vT)/sqrt_d
    z_reset = (mu - vR)/sqrt_d
    exp_delta = mp.exp((vR**2 - vT**2 + 2.0*mu*(vT - vR))/(4.0*D))
    prefactor = (r0*omega*1.0j)/(sqrt_d*(iw - 1.0))
    numerator = (mp.pcfd(iw - 1.0, z_thresh)
                 - exp_delta*mp.pcfd(iw - 1.0, z_reset))
    denominator = (mp.pcfd(iw, z_thresh)
                   - exp_delta*mp.exp(iw*tau_ref)*mp.pcfd(iw, z_reset))
    return prefactor*numerator/denominator
def susceptibility2(omega1, omega2, chi1_1, chi1_2, r0, mu, D, tau_ref, vR, vT):
    """Second-order susceptibility chi2(omega1, omega2) of the LIF neuron.

    The result is the sum of two terms sharing the denominator
    D_{iw}(z_T) - e^delta e^{i (omega1 + omega2) tau_ref} D_{iw}(z_R),
    with iw = i omega1 + i omega2 and D_nu the parabolic cylinder function:

    - a term with parabolic cylinder functions of order iw - 2, and
    - a term combining the linear susceptibilities chi1_1 and chi1_2
      with parabolic cylinder functions of order iw - 1.

    Parameters
    ----------
    omega1, omega2: float
        Angular frequencies of the two input components.
    chi1_1, chi1_2: complex
        Linear susceptibility at omega1 and omega2, respectively.
    r0: float
        Stationary firing rate.
    mu, D, tau_ref, vR, vT: float
        Mean input, noise intensity, refractory period, reset and
        threshold voltage.

    Returns
    -------
    chi2: mpmath complex
        Second-order susceptibility at (omega1, omega2).
    """
    iw1 = omega1*1.0j
    iw2 = omega2*1.0j
    iw_sum = iw1 + iw2
    sqrt_d = mp.sqrt(D)
    z_thresh = (mu - vT)/sqrt_d
    z_reset = (mu - vR)/sqrt_d
    exp_delta = mp.exp((vR**2 - vT**2 + 2.0*mu*(vT - vR))/(4.0*D))
    # prefactors of the two contributions:
    pre_2 = r0*(1.0 - iw1 - iw2)*iw_sum/(2.0*D*(iw1 - 1.0)*(iw2 - 1.0))
    pre_1 = iw_sum/(2.0*sqrt_d)
    chi1_mix = chi1_1/(iw2 - 1.0) + chi1_2/(iw1 - 1.0)
    # parabolic cylinder function terms:
    num_2 = (mp.pcfd(iw_sum - 2.0, z_thresh)
             - exp_delta*mp.pcfd(iw_sum - 2.0, z_reset))
    num_1_thresh = mp.pcfd(iw_sum - 1.0, z_thresh)
    num_1_reset = exp_delta*mp.pcfd(iw_sum - 1.0, z_reset)
    denom = (mp.pcfd(iw_sum, z_thresh)
             - exp_delta*mp.exp(1.0j*(omega1 + omega2)*tau_ref)
             * mp.pcfd(iw_sum, z_reset))
    return (pre_2*num_2/denom
            + pre_1*chi1_mix*num_1_thresh/denom
            - pre_1*chi1_mix*num_1_reset/denom)
def susceptibilities(frange1, frange2, mu, D, tau_ref, vR, vT):
    """Compute firing rate, chi1, and chi2 of the LIF neuron.

    Parameters
    ----------
    frange1: array of floats
        Frequencies (in units of 1/tau_mem) at which chi1 is evaluated.
    frange2: array of floats
        Frequencies spanning the square grid on which chi2 is evaluated.
    mu, D, tau_ref, vR, vT: float
        Mean input, noise intensity, refractory period, reset and
        threshold voltage.

    Returns
    -------
    r0: float
        Stationary firing rate.
    chi1_data: ndarray of complex, shape (len(frange1),)
        Linear susceptibility at frange1.
    chi2_data: ndarray of complex, shape (len(frange2), len(frange2))
        chi2_data[i, j] is chi2 for f1 = frange2[j] and f2 = frange2[i].
    """
    print(f'compute LIF susceptibilities for mu={mu:4.1f} and D={D:g}:')
    print(' mean firing rate ...')
    r0 = firingrate(mu, D, tau_ref, vR, vT)

    def chi1_at(freq):
        # linear susceptibility at frequency `freq` (not angular frequency)
        return susceptibility1(2.0*np.pi*freq, r0, mu, D, tau_ref, vR, vT)

    # chi1:
    print(' chi1 ...')
    chi1_data = np.zeros(len(frange1), dtype=complex)
    for k, freq in enumerate(frange1):
        chi1_data[k] = chi1_at(freq)
    # chi2:
    # chi1 on frange2 is needed for both frequency arguments of chi2.
    # Compute it once per frequency instead of once per grid point.
    # Keep the mpmath values, exactly as they were passed before.
    chi1_grid = [chi1_at(freq) for freq in frange2]
    n2 = len(frange2)
    chi2_data = np.zeros((n2, n2), dtype=complex)
    for i2, (f2, chi1_2) in enumerate(zip(frange2, chi1_grid)):
        print(f' chi2 step {i2 + 1:4d} of {n2:4d}')
        omega2 = 2.0*np.pi*f2
        for i1, (f1, chi1_1) in enumerate(zip(frange2, chi1_grid)):
            omega1 = 2.0*np.pi*f1
            chi2_data[i2, i1] = susceptibility2(omega1, omega2,
                                                chi1_1, chi1_2, r0,
                                                mu, D, tau_ref, vR, vT)
    return r0, chi1_data, chi2_data
def load_lifdata(mu, D, vT=1, vR=0, tau_ref=0):
    """Load LIF susceptibilities from the cache, computing them if missing.

    The cache file in `sims_path` is named after mu and D. If it does not
    exist, firing rate, chi1 (500 frequencies) and chi2 (200 x 200 grid)
    are computed for frequencies between 0.01 and 1 and stored there.

    Parameters
    ----------
    mu: float
        Mean input.
    D: float
        Noise intensity.
    vT: float
        Threshold voltage.
    vR: float
        Reset voltage.
    tau_ref: float
        Refractory period.

    Returns
    -------
    r0: float
        Stationary firing rate.
    freqs1: ndarray
        Frequencies of chi1.
    chi1: ndarray of complex
        Linear susceptibility.
    freqs2: ndarray
        Frequencies of both axes of chi2.
    chi2: 2D ndarray of complex
        Second-order susceptibility.

    Raises
    ------
    ValueError
        If an existing cache file was computed for different parameters.
    """
    file_path = sims_path / f'lif-mu{10*mu:03.0f}-D{10000*D:04.0f}-chi2.npz'
    if not file_path.exists():
        # np.savez does not create missing directories; create them before
        # the lengthy computation so that an unwritable path fails early.
        file_path.parent.mkdir(parents=True, exist_ok=True)
        freqs1 = np.linspace(0.01, 1.0, 500)
        freqs2 = np.linspace(0.01, 1.0, 200)
        r0, chi1, chi2 = susceptibilities(freqs1, freqs2, mu, D,
                                          tau_ref, vR, vT)
        np.savez(file_path, mu=mu, D=D, vT=vT, vR=vR,
                 tau_mem=1, tau_ref=tau_ref, r0=r0,
                 freqs1=freqs1, chi1=chi1, freqs2=freqs2, chi2=chi2)
    # NpzFile keeps the file open until closed, hence the context manager.
    with np.load(file_path) as data:
        # The file name encodes mu and D only coarsely (e.g. mu=1.1 and
        # mu=1.14 map to the same file) and vT, vR, tau_ref not at all.
        # Refuse to return data computed for other parameters.
        requested = dict(mu=mu, D=D, vT=vT, vR=vR, tau_ref=tau_ref)
        for name, value in requested.items():
            if name not in data.files:
                continue
            stored = float(data[name])
            if not np.isclose(stored, value, rtol=1e-9, atol=1e-12):
                raise ValueError(f'{file_path} holds LIF data for '
                                 f'{name}={stored:g}, but {name}={value:g} '
                                 f'was requested')
        r0 = float(data['r0'])
        freqs1 = data['freqs1']
        chi1 = data['chi1']
        freqs2 = data['freqs2']
        chi2 = data['chi2']
    print(f'LIF with mu={mu:4.1f} and D={D:g}')
    return r0, freqs1, chi1, freqs2, chi2
def plot_gain(ax, s, r0, freqs, chi1):
    """Plot the gain |chi1(f)| of the linear response into `ax`.

    Vertical grid lines, labeled above the axes, mark the firing rate r
    and twice the firing rate.

    Parameters
    ----------
    ax: axes
        Axes to plot into; uses set_xticks_delta()/set_yticks_delta(),
        presumably added by the plotstyle package.
    s: plot style
        Style object returned by plot_style() (uses lsGrid and lsM1).
    r0: float
        Firing rate.
    freqs: array
        Frequencies of chi1.
    chi1: array of complex
        Linear susceptibility.
    """
    markers = ((r0, '$r$'), (2*r0, '$2r$'))
    for xpos, _ in markers:
        ax.axvline(xpos, **s.lsGrid)
    ax.plot(freqs, np.abs(chi1), **s.lsM1)
    ax.set_xlabel('$f$')
    ax.set_ylabel('$|\\chi_1(f)|$', labelpad=6)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 14)
    ax.set_xticks_delta(0.2)
    ax.set_yticks_delta(3)
    # labels just above the upper y-limit of 14:
    for xpos, label in markers:
        ax.text(xpos, 14.2, label, ha='center')
def plot_chi2(ax, s, r0, freqs, chi2, vmax=300):
    """Plot |chi2(f1, f2)| as a color map with a colorbar into `ax`.

    Parameters
    ----------
    ax: axes
        Axes to plot into. A colorbar is placed in an inset axes to the
        right. Uses set_xticks_delta() etc., presumably added by plotstyle.
    s: plot style
        Not used here; kept for a signature parallel to plot_gain().
    r0: float
        Not used here; kept for a signature parallel to plot_gain().
    freqs: array
        Frequencies of both axes (f1 on x, f2 on y).
    chi2: 2D array of complex
        Second-order susceptibility; chi2[i, j] is shown at
        f1 = freqs[j], f2 = freqs[i].
    vmax: float or None
        Upper limit of the color scale. If None, the 99.6% quantile
        of |chi2| is used.
    """
    chi2 = np.abs(chi2)
    if vmax is None:
        vmax = np.quantile(chi2, 0.996)
    pc = ax.pcolormesh(freqs, freqs, chi2, vmin=0, vmax=vmax,
                       rasterized=True)
    ax.set_aspect('equal')
    ax.set_xlabel('$f_1$')
    ax.set_ylabel('$f_2$', labelpad=6)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xticks_delta(0.2)
    ax.set_yticks_delta(0.2)
    cax = ax.inset_axes([1.04, 0, 0.05, 1])
    cax.set_spines_outward('lrbt', 0)
    # Use the figure the axes belongs to. A global `fig` only exists
    # when this file is run as a script.
    cb = ax.figure.colorbar(pc, cax=cax)
    cb.outline.set_color('none')
    cb.outline.set_linewidth(0)
    cax.set_ylabel('$|\\chi_2(f_1, f_2)|$')
    cax.set_yticks_delta(100)
if __name__ == '__main__':
    # LIF working point: mean input and noise intensity.
    mu, D = 1.1, 0.001
    r0, freqs1, chi1, freqs2, chi2 = load_lifdata(mu, D)
    # Figure with the gain |chi1| (left) and the |chi2| map (right).
    # cmsize, leftm, set_align, tag, and the argument-free savefig are not
    # plain matplotlib; presumably provided by plotstyle (confirm).
    s = plot_style()
    plt.rcParams['axes.labelpad'] = 2
    width = s.plot_width
    fig, axs = plt.subplots(1, 2, cmsize=(width, 0.38*width))
    ax_gain, ax_chi2 = axs
    fig.subplots_adjust(leftm=8, rightm=8.5, topm=1.5, bottomm=3.5,
                        wspace=0.4)
    fig.set_align(autox=False)
    plot_gain(ax_gain, s, r0, freqs1, chi1)
    plot_chi2(ax_chi2, s, r0, freqs2, chi2)
    fig.tag()
    fig.savefig()
    print()