nonlinearbaseline2025/regimes.py

427 lines
16 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import linregress
from numba import jit
from spectral import rate
from plotstyle import plot_style, lighter, darker
# Model cell and figure parameters. Two configurations follow; because the
# second one reassigns every name, only the second cell
# ('2018-05-08-ab-invivo-1') is actually used. Comment out one of the two
# blocks to switch between cells.
model_cell = '2018-05-08-ad-invivo-1' # 228Hz, CV=0.67
alphas = [0.002, 0.01, 0.03, 0.06]
rmax = 500
amax = 60
cthresh1 = 1.2
cthresh2 = 3.5
model_cell = '2018-05-08-ab-invivo-1' # 116, CV=0.68
alphas = [0.002, 0.008, 0.025, 0.05]  # contrasts (fractions) of the four example columns
rmax = 400       # upper y-limit of the firing-rate panels in Hz
amax = 50        # upper y-limit of the amplitude-vs-contrast panel in Hz
cthresh1 = 1.2   # contrast in % where the linear regime ends
cthresh2 = 3.0   # contrast in % where the weakly nonlinear regime ends
data_path = Path('data')
sims_path = data_path / 'simulations'
trials = 1000  # number of simulated trials per example contrast
spec_trials = 100 # set to zero to only recompute firing rates (spectra are then loaded from file)
sigma = 0.002  # smoothing width in seconds passed to rate() for the firing rates
nfft = 2**18   # FFT size; spectra are computed from nfft bins of 0.1 ms
recompute = False  # if True, ignore cached .npz files and recompute everything
def load_data(file_path):
    """ Load the precomputed contrast dependence of the response peaks.

    Parameters
    ----------
    file_path: str or Path
        Path of the .npz file written by the simulation script.

    Returns
    -------
    ratebase: float
        Baseline firing rate in Hz.
    cvbase: float
        CV of the baseline interspike intervals.
    beatf1: float
        First beat frequency in Hz.
    beatf2: float
        Second beat frequency in Hz.
    contrasts: ndarray
        Stimulus contrasts.
    powerf1, powerf2, powerfsum, powerfdiff: ndarray
        Arrays stored under the same keys (presumably peak powers at
        beatf1, beatf2, their sum, and their difference).
    """
    # The context manager closes the underlying file handle of the NpzFile;
    # indexing an NpzFile returns fully loaded arrays, so they stay valid.
    with np.load(file_path) as data:
        ratebase = float(data['ratebase'])
        cvbase = float(data['cvbase'])
        beatf1 = float(data['beatf1'])
        beatf2 = float(data['beatf2'])
        contrasts = data['contrasts']
        powerf1 = data['powerf1']
        powerf2 = data['powerf2']
        powerfsum = data['powerfsum']
        powerfdiff = data['powerfdiff']
    return (ratebase, cvbase, beatf1, beatf2,
            contrasts, powerf1, powerf2, powerfsum, powerfdiff)
def load_models(file):
    """ Load model parameter from csv file.

    The first line is the header with the parameter names. The first
    column of each row (the cell name) is kept as a string, all further
    columns are converted to float. Blank lines are skipped.

    Parameters
    ----------
    file: string
        Name of file with model parameters.

    Returns
    -------
    parameters: list of dict
        For each cell a dictionary with model parameters.

    Raises
    ------
    ValueError
        If a row has fewer fields than the header or a value is not a number.
    """
    parameters = []
    with open(file, 'r') as sf:
        keys = sf.readline().strip().split(',')
        for line in sf:
            line = line.strip()
            if not line:
                # tolerate blank lines, e.g. a trailing newline at the end
                continue
            values = line.split(',')
            if len(values) < len(keys):
                raise ValueError(f'row {line!r} has {len(values)} fields, '
                                 f'expected {len(keys)}')
            parameter = {keys[0]: values[0]}
            for key, value in zip(keys[1:], values[1:]):
                parameter[key] = float(value)
            parameters.append(parameter)
    return parameters
def cell_parameters(parameters, cell_name):
    """ Return the model parameters of a single cell.

    Parameters
    ----------
    parameters: list of dict
        Model parameters of all cells, each with a 'cell' key.
    cell_name: str
        Name of the requested cell.

    Returns
    -------
    params: dict
        The first dictionary whose 'cell' entry equals `cell_name`.

    Raises
    ------
    SystemExit
        With exit status 1 if no cell matches `cell_name`.
    """
    for params in parameters:
        if params['cell'] == cell_name:
            return params
    print('cell', cell_name, 'not found!')
    # `exit()` is only a helper injected by `site` and exits with status 0;
    # raise SystemExit(1) so a failed lookup is reported as an error.
    raise SystemExit(1)
@jit(nopython=True)
def simulate(stimulus, deltat=0.00005, v_zero=0.0, a_zero=2.0,
             threshold=1.0, v_base=0.0, delta_a=0.08, tau_a=0.1,
             v_offset=-10.0, mem_tau=0.015, noise_strength=0.05,
             input_scaling=60.0, dend_tau=0.001, ref_period=0.001):
    """ Simulate a P-unit.

    Leaky integrate-and-fire neuron with a dendritic low-pass filter,
    an adaptation current, additive white noise, and a refractory period,
    integrated with the forward Euler method.

    Parameters
    ----------
    stimulus: 1-D array
        Stimulus sampled with time step `deltat`. Negative values are set
        to zero (rectification) on a copy; the caller's array is not
        modified.
    deltat: float
        Integration time step in seconds.
    v_zero: float
        Initial membrane voltage.
    a_zero: float
        Initial value of the adaptation variable.
    threshold: float
        Spike threshold of the membrane voltage.
    v_base: float
        Voltage the membrane is reset to after a spike and held at during
        the refractory period.
    delta_a: float
        Adaptation increment per spike (added as `delta_a/tau_a`).
    tau_a: float
        Adaptation time constant in seconds.
    v_offset: float
        Constant offset of the membrane input.
    mem_tau: float
        Membrane time constant in seconds.
    noise_strength: float
        Noise strength; per time step the noise is scaled by
        `noise_strength/sqrt(deltat)`.
    input_scaling: float
        Gain of the dendritic voltage onto the membrane.
    dend_tau: float
        Time constant of the dendritic low-pass filter in seconds.
    ref_period: float
        Refractory period after each spike in seconds.

    Returns
    -------
    spike_times: 1-D array
        Simulated spike times in seconds, measured from the first
        stimulus sample (a spike at step i has time i*deltat).
    """
    # initial conditions:
    # NOTE(review): v_dend starts from the unrectified first sample,
    # rectification happens below - confirm this is intended.
    v_dend = stimulus[0]
    v_mem = v_zero
    adapt = a_zero
    # prepare noise:
    # Gaussian white noise scaled by 1/sqrt(deltat) (Euler-Maruyama convention):
    noise = np.random.randn(len(stimulus))
    noise *= noise_strength / np.sqrt(deltat)
    # rectify stimulus array:
    stimulus = stimulus.copy()
    stimulus[stimulus < 0.0] = 0.0
    # integrate:
    spike_times = []
    for i in range(len(stimulus)):
        # dendrite low-pass filters the rectified stimulus:
        v_dend += (-v_dend + stimulus[i]) / dend_tau * deltat
        # leaky membrane driven by the dendrite, reduced by adaptation:
        v_mem += (v_base - v_mem + v_offset + (
            v_dend * input_scaling) - adapt + noise[i]) / mem_tau * deltat
        # adaptation decays exponentially between spikes:
        adapt += -adapt / tau_a * deltat
        # refractory period:
        # hold v_mem at v_base for ref_period after the last spike
        # (the deltat/2 margin presumably guards against round-off):
        if len(spike_times) > 0 and (deltat * i) - spike_times[-1] < ref_period + deltat/2:
            v_mem = v_base
        # threshold crossing:
        # reset, record the spike time, and increment adaptation:
        if v_mem > threshold:
            v_mem = v_base
            spike_times.append(i * deltat)
            adapt += delta_a / tau_a
    return np.array(spike_times)
def punit_spikes(parameter, alpha, beatf1, beatf2, tmax, trials):
    """ Simulate P-unit spike trains in response to a two-beat stimulus.

    The stimulus is a sine wave at the cell's EOD frequency plus two
    sine waves of relative amplitude `alpha` at EODf + beatf1 and
    EODf + beatf2. Each trial starts with an initial period of 0.2 s
    that is simulated but discarded.

    Parameters
    ----------
    parameter: dict
        Model parameters of a cell (as returned by `cell_parameters()`),
        including the keys 'cell', 'EODf', 'deltat', and 'a_zero'.
    alpha: float
        Amplitude of each of the two beats relative to the EOD amplitude.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    tmax: float
        Duration of the returned spike trains in seconds.
    trials: int
        Number of trials to simulate.

    Returns
    -------
    spikes: list of 1-D arrays
        Spike times in seconds of each trial, relative to the end of the
        discarded initial period.
    """
    tini = 0.2
    model_params = dict(parameter)
    model_params.pop('cell')        # not an argument of simulate()
    eodf0 = model_params.pop('EODf')
    a_zero = model_params['a_zero']
    time = np.arange(-tini, tmax, model_params['deltat'])
    stimulus = np.sin(2*np.pi*eodf0*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf1)*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf2)*time)
    spikes = []
    for _ in range(trials):
        # Randomize the initial conditions of each trial independently.
        # a_zero is jittered by 2% around the fitted value; accumulating
        # the jitter (`+=`) would let it drift as a random walk over trials.
        model_params['v_zero'] = np.random.rand()
        model_params['a_zero'] = a_zero + 0.02*a_zero*np.random.randn()
        spiket = simulate(stimulus, **model_params)
        spikes.append(spiket[spiket > tini] - tini)
    return spikes
def plot_am(ax, s, alpha, beatf1, beatf2, tmin, tmax):
    """ Plot the amplitude modulation of the two-beat stimulus.

    The AM is the sum of two sine waves of amplitude `alpha` at the beat
    frequencies. It is drawn inverted and in percent over the window
    from `tmin` to `tmax`, with milliseconds relative to `tmin` on the
    x-axis. The contrast is annotated above the panel.
    """
    t = np.arange(tmin, tmax, 0.0001)
    first_beat = alpha*np.sin(2*np.pi*beatf1*t)
    second_beat = alpha*np.sin(2*np.pi*beatf2*t)
    envelope = first_beat + second_beat
    window_ms = 1000*(tmax - tmin)
    ax.show_spines('l')
    ax.plot(1000*(t - tmin), -100*envelope, **s.lsAM)
    ax.set_xlim(0, window_ms)
    ax.set_ylim(-13, 13)
    ax.set_yticks_delta(10)
    ax.set_ylabel('AM', r'\%')
    ax.text(1, 1.2, f'Contrast = {100*alpha:g}\\,\\%',
            transform=ax.transAxes, ha='right')
def plot_raster(ax, s, spikes, tmin, tmax, max_trials=16):
    """ Plot a spike raster of the first trials.

    Parameters
    ----------
    ax: plotstyle axes
        Axes to plot into.
    s: plot style
        Provides the `lsRaster` style.
    spikes: list of 1-D arrays
        Spike times in seconds of each trial.
    tmin, tmax: float
        Time window in seconds; the x-axis shows milliseconds relative
        to `tmin`.
    max_trials: int
        Maximum number of trials shown (default 16).
    """
    # The loop variable is named `trial` so that it does not shadow the
    # style argument `s` used below.
    spikes_ms = [1000*(trial[(trial >= tmin) & (trial <= tmax)] - tmin)
                 for trial in spikes[:max_trials]]
    ax.show_spines('')
    ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster)
    ax.set_xlim(0, 1000*(tmax - tmin))
def plot_rate(ax, s, path, spikes, tmin, tmax, sigma=0.002):
    """ Plot the trial-averaged firing rate between `tmin` and `tmax`.

    The firing rate is computed with `rate()` and cached in `path`. If the
    cache file exists (and `recompute` is False), the rate is loaded from
    there and `spikes` is not used.

    Parameters
    ----------
    ax: plotstyle axes
        Axes to plot into.
    s: plot style
        Provides the `lsRate` style.
    path: Path
        Cache file (.npz) for the firing rate.
    spikes: list of 1-D arrays
        Spike times in seconds of each trial.
    tmin, tmax: float
        Time window in seconds; the x-axis shows milliseconds relative
        to `tmin`.
    sigma: float
        Smoothing width in seconds passed to `rate()`.
    """
    if recompute or not path.is_file():
        print(' compute firing rate')
        # evaluated from 0 to tmin + tmax and cached as a whole
        time = np.arange(0, tmin + tmax, sigma/4)
        r, rsd = rate(time, spikes, sigma)
        np.savez(path, time=time, rate=r, ratesd=rsd,
                 sigma=sigma, trials=len(spikes))
    else:
        print(f' load firing rate from {path}')
        # the context manager closes the npz file handle after reading
        with np.load(path) as data:
            time = data['time']
            r = data['rate']
    mask = (time >= tmin) & (time <= tmax)
    time = time[mask] - tmin
    r = r[mask]
    ax.show_spines('l')
    ax.plot(1000*time, r, clip_on=False, **s.lsRate)
    ax.set_xlim(0, 1000*(tmax - tmin))
    ax.set_ylim(0, rmax)
    ax.set_ylabel('Rate', 'Hz')
    ax.set_yticks_delta(200)
def compute_power(path, spikes, nfft, dt, max_trials=None, force=None):
    """ Compute or load the trial-averaged power spectrum of spike trains.

    Each spike train is binned into `nfft` bins of width `dt`, converted
    to a rate, and its mean is removed. The averaged squared magnitudes
    of the real FFTs are scaled by `dt/nfft`. The result is cached in
    `path`.

    Parameters
    ----------
    path: Path
        Cache file (.npz) for the spectrum.
    spikes: list of 1-D arrays
        Spike times in seconds of each trial.
    nfft: int
        Number of bins and FFT size.
    dt: float
        Bin width in seconds.
    max_trials: int or None
        Maximum number of trials averaged. Defaults to the module-level
        `spec_trials`. If zero, the spectrum is always loaded from `path`.
    force: bool or None
        Recompute even if `path` exists. Defaults to the module-level
        `recompute`.

    Returns
    -------
    freqs: 1-D array
        Frequencies in Hz.
    prr: 1-D array
        Averaged power spectrum.
    """
    if max_trials is None:
        max_trials = spec_trials
    if force is None:
        force = recompute
    if max_trials > 0 and (force or not path.is_file()):
        print(' compute power spectrum')
        bins = np.arange(nfft + 1)*dt
        psds = []
        for trial in spikes[:max_trials]:
            b, _ = np.histogram(trial, bins)
            b = b / dt
            fourier = np.fft.rfft(b - np.mean(b))
            psds.append(np.abs(fourier)**2)
        freqs = np.fft.rfftfreq(nfft, dt)
        prr = np.mean(psds, 0)*dt/nfft
        # record the number of trials actually averaged, not len(spikes)
        ntrials = len(psds)
        np.savez(path, nfft=nfft, deltat=dt, nsegs=ntrials,
                 freqs=freqs, prr=prr, trials=ntrials)
    else:
        print(f' load power spectrum from {path}')
        # the context manager closes the npz file handle after reading
        with np.load(path) as data:
            freqs = data['freqs']
            prr = data['prr']
    return freqs, prr
def decibel(x):
    """ Convert power to decibel relative to a reference power of 1e8.

    A floor of 1e-12 is added to the normalized power so that zero power
    maps to -120 dB instead of minus infinity.
    """
    reference = 1e8
    floor = 1e-12
    relative = x/reference + floor
    return 10*np.log10(relative)
def peak_ampl(freqs, psd, f, df=2):
    """ Return the maximum of `psd` within the open interval (f-df, f+df).

    Negative frequencies `f` are replaced by 5 Hz. Raises ValueError
    (from np.max) if no frequency falls into the interval.
    """
    center = 5 if f < 0 else f
    inside = (freqs > center - df) & (freqs < center + df)
    return np.max(psd[inside])
def plot_psd(ax, s, path, contrast, spikes, nfft, dt, beatf1, beatf2, eodf):
    """ Plot the power spectrum of the spike trains and mark its peaks.

    Markers are drawn above the spectral peaks at the EOD frequency, the
    beat frequencies, and their combinations. Further groups of
    higher-order peaks are marked once `contrast` reaches `alphas[1]`,
    `alphas[2]`, and `alphas[3]`, respectively.

    Parameters
    ----------
    ax: plotstyle axes
        Axes to plot into.
    s: plot style
        Provides the line and marker styles.
    path: Path
        Cache file for the power spectrum, see `compute_power()`.
    contrast: float
        Stimulus contrast, compared against the module-level `alphas`.
    spikes: list of 1-D arrays
        Spike times in seconds of each trial.
    nfft: int
        FFT size.
    dt: float
        Bin width in seconds.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    eodf: float
        EOD frequency in Hz.
    """
    offs = 5    # dB offset of the markers of the main peaks
    offsm = 3   # dB offset of the markers of the minor peaks
    freqs, psd = compute_power(path, spikes, nfft, dt)
    psd /= freqs[1]
    ax.plot(freqs, decibel(psd), **s.lsPower)

    def mark(freq, style, offset=offsm, zorder=40, **kwargs):
        # Plot marker `style` at `freq`, `offset` dB above the maximum of
        # the spectrum around that same frequency.
        power = decibel(peak_ampl(freqs, psd, freq))
        ax.plot(freq, power + offset, zorder=zorder, **kwargs, **style)

    # main peaks (not clipped at the axes):
    main = dict(offset=offs, zorder=50, clip_on=False)
    mark(eodf, s.psFEOD, label=r'$f_{EOD}$', **main)
    mark(beatf2, s.psF0, label=r'$r$', **main)
    mark(beatf1, s.psF01, label=r'$\Delta f_1$', **main)
    # plotted at the same frequency as the r marker, so stacked above it:
    mark(beatf2, s.psF02, offset=2*offs + 2, zorder=50, clip_on=False,
         label=r'$\Delta f_2$')
    mark(beatf2 - beatf1, s.psF01_2, label=r'$\Delta f_2 - \Delta f_1$', **main)
    mark(beatf1 + beatf2, s.psF012, label=r'$\Delta f_1 + \Delta f_2$', **main)
    # minor peaks (all at zorder 40, including eodf - beatf1):
    mark(eodf + beatf1, s.psFEODm, label=r'$f_{EOD} \pm k \Delta f_1$')
    mark(eodf - beatf1, s.psFEODm)
    if contrast >= alphas[1]:
        mark(eodf - beatf2, s.psF0m, label=r'$f_{EOD} - k \Delta f_2$')
    if contrast >= alphas[2]:
        mark(2*beatf1, s.psF01m, label=r'$k\Delta f_1$')
        mark(eodf + 2*beatf1, s.psFEODm)
        mark(eodf - beatf2 + beatf1, s.psF02m,
             label=r'$f_{EOD} - \Delta f_2 \pm k\Delta f_1$')
        mark(eodf - beatf2 - beatf1, s.psF02m)
    if contrast >= alphas[3]:
        mark(beatf2 + 2*beatf1, s.psF012m, label=r'$\Delta f_2 \pm k\Delta f_1$')
        for f in (beatf2 + 3*beatf1, beatf2 - 2*beatf1, beatf2 - 3*beatf1):
            mark(f, s.psF012m)
        for k in (3, 4):
            mark(k*beatf1, s.psF01m)
        for k in (2, 3, 4):
            mark(eodf - k*beatf1, s.psFEODm)
        mark(eodf - 2*beatf2, s.psF0m)
        # each marker's peak is looked up at its own frequency
        # (eodf - beatf2 + 3*beatf1 previously used the +2*beatf1 peak)
        for k in (2, 3):
            mark(eodf - beatf2 + k*beatf1, s.psF02m)
    ax.set_xlim(0, 750)
    ax.set_ylim(-60, 0)
    ax.set_xticks_delta(200)
    ax.set_yticks_delta(20)
    ax.set_xlabel('Frequency', 'Hz')
    ax.set_ylabel('Power [dB]')
def plot_example(axs, axr, axf, axp, s, path, cell, alpha, beatf1, beatf2, eodf,
                 nfft, trials):
    """ Simulate and plot the response of the model cell to one contrast.

    Fills four panels: stimulus AM (`axs`), spike raster (`axr`), firing
    rate (`axf`), and power spectrum (`axp`). If both the spectrum and the
    rate are cached (and `recompute` is False), only 20 short trials
    covering the plotted window are simulated for the AM and raster.
    Otherwise `trials` trials of nfft*dt seconds are simulated.
    """
    tag = f'{1000*alpha:03.0f}'
    spec_path = path.with_name(path.stem + f'-contrastspectrum-{tag}.npz')
    rate_path = path.with_name(path.stem + f'-contrastrates-{tag}.npz')
    dt = 0.0001
    t0 = 0.112
    t1 = 0.212
    cached = not recompute and spec_path.is_file() and rate_path.is_file()
    if cached:
        tmax, trials = t0 + t1, 20
    else:
        tmax = nfft*dt
        print(' compute spike response')
    spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials)
    plot_am(axs, s, alpha, beatf1, beatf2, t0, t1)
    plot_raster(axr, s, spikes, t0, t1)
    plot_rate(axf, s, rate_path, spikes, t0, t1, sigma)
    plot_psd(axp, s, spec_path, alpha, spikes, nfft, dt, beatf1, beatf2, eodf)
def amplitude(power):
    """ Convert peak power to amplitude relative to the first entry.

    The power of the first entry is subtracted as a baseline, negative
    differences are clipped to zero, and the square root is returned.
    The input is not modified (previously it was changed in place).

    Parameters
    ----------
    power: array-like
        Power values, e.g. for increasing contrasts.

    Returns
    -------
    ampl: ndarray
        Square root of the baseline-corrected power.
    """
    excess = np.asarray(power) - power[0]
    return np.sqrt(np.maximum(excess, 0))
def amplitude_linearfit(contrast, power, max_contrast):
    """ Fit a line to amplitude versus contrast at small contrasts.

    The amplitude is the square root of the power minus the power of the
    first entry (clipped at zero). A line is fitted to all amplitudes with
    `contrast <= max_contrast` and evaluated at all contrasts. `power` is
    not modified.

    Parameters
    ----------
    contrast: ndarray
        Contrasts.
    power: array-like
        Power for each contrast.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: ndarray
        Fitted amplitudes for each contrast.
    """
    ampl = np.sqrt(np.maximum(np.asarray(power) - power[0], 0))
    small = contrast <= max_contrast
    r = linregress(contrast[small], ampl[small])
    return r.intercept + r.slope*contrast
def amplitude_squarefit(contrast, power, max_contrast):
    """ Fit a quadratic dependence of amplitude on contrast.

    The amplitude is the square root of the power minus the power of the
    first entry (clipped at zero). A line is fitted to the square root of
    the amplitudes with `contrast <= max_contrast`, and its square is
    returned for all contrasts. `power` is not modified.

    Parameters
    ----------
    contrast: ndarray
        Contrasts.
    power: array-like
        Power for each contrast.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: ndarray
        Fitted amplitudes for each contrast.
    """
    ampl = np.sqrt(np.maximum(np.asarray(power) - power[0], 0))
    small = contrast <= max_contrast
    r = linregress(contrast[small], np.sqrt(ampl[small]))
    return (r.intercept + r.slope*contrast)**2
def plot_peaks(ax, s, alphas, contrasts, powerf1, powerf2, powerfsum,
               powerfdiff):
    """ Plot response amplitudes versus contrast and annotate the regimes.

    Amplitudes derived from `powerf1` and `powerf2` are shown with linear
    fits up to `cthresh1`. Amplitudes from `powerfsum` and `powerfdiff` are
    shown with quadratic fits up to `cthresh2`. The example contrasts
    `alphas` are marked with triangles and tags A-D, and the linear,
    weakly nonlinear, and strongly nonlinear regimes are labeled.

    Parameters
    ----------
    ax: plotstyle axes
        Axes to plot into.
    s: plot style
        Provides the line styles.
    alphas: list of float
        Example contrasts as fractions (marked at 100*alpha %).
    contrasts: array-like
        Contrasts as fractions; not modified.
    powerf1, powerf2, powerfsum, powerfdiff: 1-D arrays
        Peak power for each contrast (presumably at Delta f1, Delta f2,
        their sum, and their difference, as the names suggest).
    """
    cmax = 10
    # scale a copy; `contrasts *= 100` used to modify the caller's array
    contrasts = 100*np.asarray(contrasts)
    ax.plot(contrasts, amplitude_linearfit(contrasts, powerf1, cthresh1),
            **s.lsF01m)
    ax.plot(contrasts, amplitude_linearfit(contrasts, powerf2, cthresh1),
            **s.lsF02m)
    ax.plot(contrasts, amplitude_squarefit(contrasts, powerfsum, cthresh2),
            **s.lsF012m)
    ax.plot(contrasts, amplitude_squarefit(contrasts, powerfdiff, cthresh2),
            **s.lsF01_2m)
    ax.plot(contrasts, amplitude(powerf1), **s.lsF01)
    ax.plot(contrasts, amplitude(powerf2), **s.lsF02)
    mask = contrasts < cmax
    ax.plot(contrasts[mask], amplitude(powerfsum)[mask],
            clip_on=False, **s.lsF012)
    ax.plot(contrasts[mask], amplitude(powerfdiff)[mask],
            clip_on=False, **s.lsF01_2)
    for alpha, tag in zip(alphas, ['A', 'B', 'C', 'D']):
        ax.plot(100*alpha, 1.05*amax, 'vk', ms=4, clip_on=False)
        ax.text(100*alpha, 1.1*amax, tag, ha='center')
    print(f'Linear regime ends at a contrast of about {cthresh1:4.1f}%')
    print(f'Weakly non-linear regime ends at a contrast of about {cthresh2:4.1f}%')
    ax.axvline(cthresh1, **s.lsLine)
    ax.axvline(cthresh2, **s.lsLine)
    # label positions are tuned to the configured y-range (amax)
    yoffs = 35 if amax == 60 else 31
    ax.text(cthresh1/2, yoffs, 'linear\nregime',
            ha='center', va='center')
    ax.text((cthresh1 + cthresh2)/2, yoffs, 'weakly\nnonlinear\nregime',
            ha='center', va='center')
    if amax == 60:
        ax.text(5.5, yoffs, 'strongly\nnonlinear\nregime',
                ha='center', va='center')
    else:
        ax.text(5.5, 6, 'strongly\nnonlinear\nregime',
                ha='center', va='bottom')
    ax.set_xlim(0, cmax)
    ax.set_ylim(0, amax)
    ax.set_xticks_delta(2)
    ax.set_yticks_delta(20)
    ax.set_xlabel('Contrast', r'\%')
    ax.set_ylabel('Amplitude', 'Hz')
if __name__ == '__main__':
    # Load the precomputed contrast dependence of the peak powers of the
    # configured model cell:
    ratebase, cvbase, beatf1, beatf2, \
    contrasts, powerf1, powerf2, powerfsum, powerfdiff = \
        load_data(sims_path / f'{model_cell}-contrastpeaks.npz')
    # model parameters of the cell; its EOD frequency sets the carrier:
    parameters = load_models(data_path / 'punitmodels.csv')
    cell = cell_parameters(parameters, model_cell)
    eodf = cell['EODf']
    print(f'Loaded data for cell {model_cell}: ')
    print(f' baseline rate = {ratebase:.0f}Hz, CV = {cvbase:.2f}')
    print(f' f1 = {beatf1:.0f}Hz, f2 = {beatf2:.0f}Hz')
    print()
    s = plot_style()
    # top: grid of example responses, bottom: contrast dependence
    fig, (axes, axa) = plt.subplots(2, 1, height_ratios=[5, 3],
                                    cmsize=(s.plot_width, 0.7*s.plot_width))
    fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5, hspace=0.6)
    # Rows: AM, spike raster, firing rate, power spectrum; one column per
    # example contrast in `alphas`.
    # NOTE(review): five height ratios are given for four rows - confirm
    # that plotstyle's subplots() accepts this (matplotlib's GridSpec
    # expects exactly one ratio per row).
    axe = axes.subplots(4, 4, wspace=0.4, hspace=0.2,
                        height_ratios=[1, 2, 2, 0.6, 3])
    fig.show_spines('lb')
    # example power spectra:
    for c, alpha in enumerate(alphas):
        # base path for the cached rate and spectrum files of this cell:
        path = sims_path / f'{model_cell}'
        print(f'Example response for contrast {100*alpha:4.1f}%:')
        plot_example(axe[0, c], axe[1, c], axe[2, c], axe[3, c], s, path,
                     cell, alpha, beatf1, beatf2, eodf, nfft, trials)
        print()
    axe[2, 0].xscalebar(1, -0.1, 20, 'ms', ha='right')
    axe[3, 3].legend(loc='center right', bbox_to_anchor=(1.05, -0.9),
                     ncol=11, columnspacing=0.6, handletextpad=0)
    fig.common_yspines(axe[0, :])
    fig.common_yticks(axe[2, :])
    fig.common_yticks(axe[3, :])
    fig.tag(axe[0, :], xoffs=-3, yoffs=1.6)
    # contrast dependence:
    plot_peaks(axa, s, alphas, contrasts, powerf1, powerf2,
               powerfsum, powerfdiff)
    fig.tag(axa, yoffs=2)
    fig.savefig()
    print()