nonlinearbaseline2025/regimes.py
2025-05-19 10:09:01 +02:00

325 lines
11 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import linregress
from numba import jit
from thunderlab.tabledata import TableData
from plotstyle import plot_style, lighter, darker
# Root directory with model parameters and simulation results.
data_path = Path('data')
# Sub-directory with simulated contrast peaks and cached power spectra.
sims_path = data_path.joinpath('simulations')
def load_data(file_path):
    """ Load simulated responses as a function of stimulus contrast.

    Parameters
    ----------
    file_path: str or Path
        Path of the npz file with the simulation results.

    Returns
    -------
    ratebase: float
        Baseline firing rate stored in the file.
    cvbase: float
        Baseline CV stored in the file.
    beatf1: float
        First beat frequency (Delta f_1).
    beatf2: float
        Second beat frequency (Delta f_2).
    contrasts: 1-D array
        Stimulus contrasts (as fractions, not percent).
    powerf1, powerf2, powerfsum, powerfdiff: 1-D arrays
        Power values for the responses at Delta f_1, Delta f_2, their sum
        and their difference, one per contrast (presumably peak powers;
        see the simulation script that wrote the file).
    """
    # NpzFile keeps the underlying file open until closed; the context
    # manager releases it. Item access reads each array fully into memory,
    # so the returned arrays stay valid after the file is closed.
    with np.load(file_path) as data:
        ratebase = float(data['ratebase'])
        cvbase = float(data['cvbase'])
        beatf1 = float(data['beatf1'])
        beatf2 = float(data['beatf2'])
        contrasts = data['contrasts']
        powerf1 = data['powerf1']
        powerf2 = data['powerf2']
        powerfsum = data['powerfsum']
        powerfdiff = data['powerfdiff']
    return (ratebase, cvbase, beatf1, beatf2,
            contrasts, powerf1, powerf2, powerfsum, powerfdiff)
def load_models(file):
    """ Load model parameters from csv file.

    The first line is a comma-separated header. In each following line the
    first column is kept as a string (the cell name), all other columns are
    converted to float. Blank lines are skipped.

    Parameters
    ----------
    file: string
        Name of file with model parameters.

    Returns
    -------
    parameters: list of dict
        For each cell a dictionary with model parameters.

    Raises
    ------
    ValueError
        If a data line has fewer columns than the header or a value
        cannot be converted to float.
    """
    parameters = []
    # Use a distinct name for the handle so the `file` argument is not shadowed.
    with open(file, 'r') as sf:
        keys = sf.readline().strip().split(',')
        for line in sf:
            line = line.strip()
            if not line:
                # e.g. an empty line at the end of the file
                continue
            values = line.split(',')
            if len(values) < len(keys):
                raise ValueError(f'{file}: line has {len(values)} values '
                                 f'but the header has {len(keys)} columns')
            parameter = {keys[0]: values[0]}
            for key, value in zip(keys[1:], values[1:]):
                parameter[key] = float(value)
            parameters.append(parameter)
    return parameters
def cell_parameters(parameters, cell_name):
    """ Select the model parameters of a single cell.

    Parameters
    ----------
    parameters: list of dict
        Model parameters as returned by `load_models()`; every dict must
        have a 'cell' entry.
    cell_name: str
        Name of the cell to look up.

    Returns
    -------
    params: dict
        The first dictionary whose 'cell' entry equals `cell_name`.

    Raises
    ------
    SystemExit
        If no cell with that name exists. The message is printed to stderr
        and the interpreter exits with a non-zero status.
    """
    match = next((params for params in parameters
                  if params['cell'] == cell_name), None)
    if match is None:
        # Terminate the script with an error status instead of exit(),
        # which would report success (status 0).
        raise SystemExit(f'cell {cell_name} not found!')
    return match
@jit(nopython=True)
def simulate(stimulus, deltat=0.00005, v_zero=0.0, a_zero=2.0,
             threshold=1.0, v_base=0.0, delta_a=0.08, tau_a=0.1,
             v_offset=-10.0, mem_tau=0.015, noise_strength=0.05,
             input_scaling=60.0, dend_tau=0.001, ref_period=0.001):
    """ Simulate a P-unit.

    Adaptive leaky integrate-and-fire model driven by the rectified and
    low-pass filtered stimulus, integrated with the forward Euler method.

    Parameters
    ----------
    stimulus: 1-D array
        Stimulus sampled with time step `deltat`. Negative values are set
        to zero on a copy before integration; the input is not modified.
    deltat: float
        Integration time step in seconds.
    v_zero: float
        Initial membrane voltage.
    a_zero: float
        Initial value of the adaptation variable.
    threshold: float
        A spike is emitted when the membrane voltage exceeds this value.
    v_base: float
        Reset voltage after a spike and during the refractory period;
        also enters the leak term `v_base - v_mem`.
    delta_a: float
        Adaptation strength; the adaptation variable is incremented by
        `delta_a/tau_a` at each spike.
    tau_a: float
        Decay time constant of the adaptation variable in seconds.
    v_offset: float
        Constant added to the membrane input.
    mem_tau: float
        Membrane time constant in seconds.
    noise_strength: float
        Gaussian white noise with standard deviation
        `noise_strength/sqrt(deltat)` is added to the membrane input.
    input_scaling: float
        Gain applied to the dendritic voltage in the membrane input.
    dend_tau: float
        Time constant in seconds of the dendritic low-pass filter.
    ref_period: float
        Refractory period in seconds.

    Returns
    -------
    spike_times: 1-D array
        Simulated spike times in seconds, relative to the first
        stimulus sample.
    """
    # initial conditions:
    # NOTE(review): v_dend starts at the unrectified first sample because
    # rectification happens below - confirm this is intended.
    v_dend = stimulus[0]
    v_mem = v_zero
    adapt = a_zero
    # prepare noise:
    noise = np.random.randn(len(stimulus))
    noise *= noise_strength / np.sqrt(deltat)
    # rectify stimulus array:
    stimulus = stimulus.copy()
    stimulus[stimulus < 0.0] = 0.0
    # integrate:
    spike_times = []
    for i in range(len(stimulus)):
        # dendritic low-pass filter of the rectified stimulus:
        v_dend += (-v_dend + stimulus[i]) / dend_tau * deltat
        # membrane: leak, offset, scaled dendritic input, adaptation, noise:
        v_mem += (v_base - v_mem + v_offset + (
            v_dend * input_scaling) - adapt + noise[i]) / mem_tau * deltat
        # adaptation decays exponentially between spikes:
        adapt += -adapt / tau_a * deltat
        # refractory period:
        # clamp to v_base while less than ref_period has elapsed since the
        # last spike (the extra deltat/2 presumably guards against rounding
        # of i*deltat).
        if len(spike_times) > 0 and (deltat * i) - spike_times[-1] < ref_period + deltat/2:
            v_mem = v_base
        # threshold crossing:
        if v_mem > threshold:
            v_mem = v_base
            spike_times.append(i * deltat)
            adapt += delta_a / tau_a
    return np.array(spike_times)
def punit_spikes(parameter, alpha, beatf1, beatf2, tmax, trials):
    """ Simulate P-unit responses to an EOD with two superimposed beats.

    Parameters
    ----------
    parameter: dict
        Model parameters of one cell: 'cell' (name), 'EODf' (EOD
        frequency in Hz), and keyword arguments of `simulate()`. Must
        contain 'deltat' and 'a_zero'. Not modified.
    alpha: float
        Amplitude of each of the two additional sine waves relative to the
        unit-amplitude EOD.
    beatf1, beatf2: float
        Frequency differences in Hz of the two additional sine waves
        relative to the EOD frequency.
    tmax: float
        Duration of the returned response in seconds.
    trials: int
        Number of trials to simulate.

    Returns
    -------
    spikes: list of 1-D arrays
        Spike times in seconds of each trial. The first `tini` = 0.2 s of
        each simulation are discarded and times are shifted so that the
        returned window starts at zero.
    """
    tini = 0.2  # initial transient (s), simulated but discarded
    model_params = dict(parameter)
    # Remove entries that are not keyword arguments of simulate():
    model_params.pop('cell')
    eodf0 = model_params.pop('EODf')
    time = np.arange(-tini, tmax, model_params['deltat'])
    stimulus = np.sin(2*np.pi*eodf0*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf1)*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf2)*time)
    a_zero = parameter['a_zero']
    spikes = []
    for _ in range(trials):
        # Randomize initial conditions independently in each trial. The
        # adaptation jitter is drawn around the cell's nominal value; adding
        # it onto the previous trial's value would let a_zero drift as a
        # random walk across trials.
        model_params['v_zero'] = np.random.rand()
        model_params['a_zero'] = a_zero + 0.02*a_zero*np.random.randn()
        spiket = simulate(stimulus, **model_params)
        spikes.append(spiket[spiket > tini] - tini)
    return spikes
def plot_am(ax, s, alpha, beatf1, beatf2, tmax):
    """ Plot the amplitude modulation caused by the two beats.

    The AM is the sum of two sine waves with amplitude `alpha` at the beat
    frequencies, sampled every 0.1 ms from 0 to `tmax` seconds and drawn
    in percent (sign inverted) over time in milliseconds.

    Parameters
    ----------
    ax: plotstyle axes
        Axes to draw into.
    s: plot style
        Provides the `lsAM` line style.
    alpha: float
        Contrast of each beat.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    tmax: float
        Displayed duration in seconds.
    """
    t = np.arange(0, tmax, 0.0001)
    omega1 = 2*np.pi*beatf1
    omega2 = 2*np.pi*beatf2
    modulation = alpha*np.sin(omega1*t) + alpha*np.sin(omega2*t)
    t_ms = 1000*t
    ax.show_spines('l')
    ax.plot(t_ms, -100*modulation, **s.lsAM)
    ax.set_xlim(0, 1000*tmax)
    ax.set_ylim(-13, 13)
    ax.set_yticks_delta(10)
    ax.set_ylabel('AM', r'\%')
    ax.text(1, 1.2, f'Contrast = {100*alpha:g}\\,\\%',
            transform=ax.transAxes, ha='right')
def plot_raster(ax, s, spikes, tmax, ntrials=16):
    """ Plot a spike raster of the first trials up to `tmax`.

    Parameters
    ----------
    ax: plotstyle axes
        Axes to draw into.
    s: plot style
        Provides the `lsRaster` style.
    spikes: list of 1-D arrays
        Spike times in seconds, one array per trial.
    tmax: float
        Displayed duration in seconds; later spikes are dropped.
    ntrials: int
        Maximum number of trials shown (default 16).
    """
    # Convert to milliseconds within the displayed window. The loop variable
    # is named `spk` so it does not shadow the plot style `s`.
    spikes_ms = [1000*spk[spk < tmax] for spk in spikes[:ntrials]]
    ax.show_spines('')
    ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster)
    ax.set_xlim(0, 1000*tmax)
def compute_power(path, contrast, spikes, nfft, dt):
    """ Compute or load the trial-averaged power spectrum of spike trains.

    Each spike train is binned into `nfft` bins of width `dt` (spikes at
    or beyond `nfft*dt` are ignored) and converted to a rate. Its mean is
    removed before the FFT. The squared FFT magnitudes are averaged over
    trials and scaled by `dt/nfft`. The result is cached in `path`. An
    existing cache file is reused only if it was computed with the same
    `nfft` and `dt`; otherwise the spectrum is recomputed and overwritten.

    Parameters
    ----------
    path: Path or str
        npz file used as cache.
    contrast: float
        Stimulus contrast, only used in the progress message.
    spikes: list of 1-D arrays
        Spike times in seconds, one array per trial.
    nfft: int
        Number of bins (FFT length).
    dt: float
        Bin width in seconds.

    Returns
    -------
    freqs: 1-D array
        Frequencies of the spectrum in Hz.
    prr: 1-D array
        Trial-averaged power spectrum of the spike-train rates.
    """
    path = Path(path)
    if path.exists():
        with np.load(path) as data:
            # Guard against a stale cache written with other resolution.
            if int(data['nfft']) == nfft and float(data['deltat']) == dt:
                print(f' Load power spectrum for contrast = {100*contrast:4.1f}%')
                return data['freqs'], data['prr']
    print(f' Compute power spectrum for contrast = {100*contrast:4.1f}%')
    bins = np.arange(nfft + 1)*dt
    psds = []
    for spk in spikes:
        counts, _ = np.histogram(spk, bins)
        rate = counts / dt
        fourier = np.fft.rfft(rate - np.mean(rate))
        psds.append(np.abs(fourier)**2)
    freqs = np.fft.rfftfreq(nfft, dt)
    prr = np.mean(psds, 0)*dt/nfft
    np.savez(path, nfft=nfft, deltat=dt, nsegs=len(spikes),
             freqs=freqs, prr=prr)
    return freqs, prr
def decibel(x):
    """ Convert power values to decibel relative to a reference of 1e8.

    Parameters
    ----------
    x: float or array
        Power values.

    Returns
    -------
    db: float or array
        `10*log10(x/1e8)`.
    """
    relative = x / 1e8
    return 10*np.log10(relative)
def plot_psd(ax, s, path, contrast, spikes, nfft, dt, beatf1, beatf2):
    """ Plot the spike-train power spectrum and mark the beat peaks.

    The spectrum from `compute_power()` is divided by the frequency
    resolution, converted to dB and plotted up to 300 Hz. Markers are drawn
    4 dB above the peak within +-2 Hz of each frequency of interest. The
    Delta f_2 marker is raised another 5.5 dB so that it sits above the
    marker labeled r, which is drawn at the same frequency.

    Parameters
    ----------
    ax: plotstyle axes
        Axes to draw into.
    s: plot style
        Provides line and marker styles.
    path, contrast, spikes, nfft, dt:
        Passed on to `compute_power()`.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    """
    offs = 4
    freqs, psd = compute_power(path, contrast, spikes, nfft, dt)
    psd /= freqs[1]
    ax.plot(freqs, decibel(psd), **s.lsPower)
    # (frequency, extra offset in dB, legend label, marker style)
    markers = (
        (beatf2, 0, r'$r$', s.psF0),
        (beatf1, 0, r'$\Delta f_1$', s.psF01),
        (beatf2, 5.5, r'$\Delta f_2$', s.psF02),
        (beatf2 - beatf1, 0, r'$\Delta f_2 - \Delta f_1$', s.psF01_2),
        (beatf1 + beatf2, 0, r'$\Delta f_1 + \Delta f_2$', s.psF012),
    )
    for freq, extra, label, style in markers:
        level = decibel(peak_ampl(freqs, psd, freq)) + offs + extra
        ax.plot(freq, level, label=label, clip_on=False, **style)
    ax.set_xlim(0, 300)
    ax.set_ylim(-60, 0)
    ax.set_xlabel('Frequency', 'Hz')
    ax.set_ylabel('Power [dB]')
def plot_example(axs, axr, axp, s, path, cell, alpha, beatf1, beatf2,
                 nfft, trials):
    """ Simulate one contrast and plot AM, spike raster and power spectrum.

    Parameters
    ----------
    axs: plotstyle axes
        Axes for the amplitude modulation.
    axr: plotstyle axes
        Axes for the spike raster.
    axp: plotstyle axes
        Axes for the power spectrum.
    s: plot style
        Line and marker styles.
    path: Path
        Directory for the cached power spectrum.
    cell: dict
        Model parameters of the cell. Its 'cell' entry (the cell name) is
        used in the cache file name.
    alpha: float
        Contrast of each beat.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    nfft: int
        FFT length; the simulated duration is `nfft*dt` with dt = 0.1 ms.
    trials: int
        Number of simulated trials.
    """
    # Take the name from the parameters passed in rather than relying on a
    # global that only exists when this file is run as a script.
    cell_name = cell['cell']
    sim_path = path / f'{cell_name}-contrastspectrum-{1000*alpha:03.0f}.npz'
    dt = 0.0001  # bin width of the spike trains for the spectrum (s)
    tmax = nfft*dt
    t1 = 0.1  # duration shown in the AM and raster panels (s)
    spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials)
    plot_am(axs, s, alpha, beatf1, beatf2, t1)
    plot_raster(axr, s, spikes, t1)
    plot_psd(axp, s, sim_path, alpha, spikes, nfft, dt, beatf1, beatf2)
def peak_ampl(freqs, psd, f, df=2):
    """ Return the maximum of a spectrum near a given frequency.

    Parameters
    ----------
    freqs: 1-D array
        Frequencies of the spectrum.
    psd: 1-D array
        Spectrum values, same length as `freqs`.
    f: float
        Frequency around which to search.
    df: float
        Half-width of the open search window (f - df, f + df); default 2.

    Returns
    -------
    peak: float
        Maximum of `psd` inside the window.

    Raises
    ------
    ValueError
        If no frequency lies inside the window.
    """
    window = (freqs > f - df) & (freqs < f + df)
    if not np.any(window):
        raise ValueError(f'no frequencies within {df:g}Hz of {f:g}Hz')
    return np.max(psd[window])
def amplitude(power):
    """ Convert power above the first entry into an amplitude.

    The first value is subtracted as a baseline, negative differences are
    set to zero, and the square root is returned. The input array is not
    modified.

    Parameters
    ----------
    power: array
        Power values; `power[0]` serves as the baseline.

    Returns
    -------
    ampl: array
        `sqrt(max(power - power[0], 0))`.
    """
    excess = np.maximum(power - power[0], 0)
    return np.sqrt(excess)
def amplitude_linearfit(contrast, power, max_contrast):
    """ Linear fit of amplitude versus contrast at low contrasts.

    The amplitude is `sqrt(max(power - power[0], 0))`. A straight line is
    fitted to the amplitudes of all contrasts up to `max_contrast` and
    evaluated at every contrast. The input arrays are not modified.

    Parameters
    ----------
    contrast: 1-D array
        Stimulus contrasts.
    power: 1-D array
        Power values, one per contrast; `power[0]` is the baseline.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: 1-D array
        Fitted amplitudes for all `contrast` values.
    """
    ampl = np.sqrt(np.maximum(power - power[0], 0))
    mask = contrast <= max_contrast
    r = linregress(contrast[mask], ampl[mask])
    return r.intercept + r.slope*contrast
def amplitude_squarefit(contrast, power, max_contrast):
    """ Quadratic fit of amplitude versus contrast at low contrasts.

    The amplitude is `sqrt(max(power - power[0], 0))`. Its square root is
    fitted by a straight line over all contrasts up to `max_contrast`. The
    squared line, i.e. `(b0 + b1*contrast)**2`, is returned for every
    contrast. The input arrays are not modified.

    Parameters
    ----------
    contrast: 1-D array
        Stimulus contrasts.
    power: 1-D array
        Power values, one per contrast; `power[0]` is the baseline.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: 1-D array
        Fitted amplitudes for all `contrast` values.
    """
    ampl = np.sqrt(np.maximum(power - power[0], 0))
    mask = contrast <= max_contrast
    r = linregress(contrast[mask], np.sqrt(ampl[mask]))
    return (r.intercept + r.slope*contrast)**2
def plot_peaks(ax, s, alphas, contrasts, powerf1, powerf2, powerfsum,
               powerfdiff):
    """ Plot response amplitudes at the beat frequencies versus contrast.

    Measured amplitudes are drawn together with linear fits (Delta f_1,
    Delta f_2) and quadratic fits (sum and difference frequencies). The
    contrasts of the example panels are marked with tags A-D. The
    hard-coded regime boundaries at 1.2 % and 3.5 % contrast are drawn as
    vertical lines.

    Parameters
    ----------
    ax: plotstyle axes
        Axes to draw into.
    s: plot style
        Line styles.
    alphas: sequence of float
        Contrasts (fractions) of the example panels, at most four.
    contrasts: 1-D array
        Stimulus contrasts as fractions; not modified.
    powerf1, powerf2, powerfsum, powerfdiff: 1-D arrays
        Power at Delta f_1, Delta f_2, their sum and difference for each
        contrast.
    """
    cmax = 10  # largest contrast shown (%)
    # Convert to percent without modifying the caller's array:
    contrasts = 100*np.asarray(contrasts)
    ax.plot(contrasts, amplitude_linearfit(contrasts, powerf1, 4),
            **s.lsF01m)
    ax.plot(contrasts, amplitude_linearfit(contrasts, powerf2, 2),
            **s.lsF02m)
    ax.plot(contrasts, amplitude_squarefit(contrasts, powerfsum, 4),
            **s.lsF012m)
    ax.plot(contrasts, amplitude_squarefit(contrasts, powerfdiff, 4),
            **s.lsF01_2m)
    ax.plot(contrasts, amplitude(powerf1), **s.lsF01)
    ax.plot(contrasts, amplitude(powerf2), **s.lsF02)
    mask = contrasts < cmax
    ax.plot(contrasts[mask], amplitude(powerfsum)[mask],
            clip_on=False, **s.lsF012)
    ax.plot(contrasts[mask], amplitude(powerfdiff)[mask],
            clip_on=False, **s.lsF01_2)
    ymax = 60
    # Mark the contrasts of the example panels:
    for alpha, tag in zip(alphas, ['A', 'B', 'C', 'D']):
        ax.plot(100*alpha, ymax*0.95, 'vk', ms=4, clip_on=False)
        ax.text(100*alpha, ymax, tag, ha='center')
    # Regime boundaries (% contrast):
    ax.axvline(1.2, **s.lsLine)
    ax.axvline(3.5, **s.lsLine)
    yoffs = 35
    ax.text(1.2/2, yoffs, 'linear\nregime',
            ha='center', va='center')
    ax.text((1.2 + 3.5)/2, yoffs, 'weakly\nnonlinear\nregime',
            ha='center', va='center')
    ax.text(5.5, yoffs, 'strongly\nnonlinear\nregime',
            ha='center', va='center')
    ax.set_xlim(0, cmax)
    ax.set_ylim(0, ymax)
    ax.set_xticks_delta(2)
    ax.set_yticks_delta(20)
    ax.set_xlabel('Contrast', r'\%')
    ax.set_ylabel('Amplitude', 'Hz')
if __name__ == '__main__':
    # Cell whose model is simulated and plotted:
    cell_name = '2018-05-08-ad-invivo-1' # 228Hz, CV=0.67
    # Peak powers as a function of contrast, loaded from the simulations
    # directory:
    ratebase, cvbase, beatf1, beatf2, \
        contrasts, powerf1, powerf2, powerfsum, powerfdiff = \
        load_data(sims_path / f'{cell_name}-contrastpeaks.npz')
    # Contrasts (fractions) of the four example columns:
    alphas = [0.002, 0.01, 0.03, 0.06]
    parameters = load_models(data_path / 'punitmodels.csv')
    cell = cell_parameters(parameters, cell_name)
    # FFT length for the example power spectra:
    nfft = 2**18
    print(f'Loaded data for cell {cell_name}: '
          f'baseline rate = {ratebase:.0f}Hz, CV = {cvbase:.2f}')
    s = plot_style()
    # Upper panel: example grid; lower panel (axa): contrast dependence.
    # NOTE(review): `cmsize`, margins in `subplots_adjust`, `axes.subplots`,
    # `show_spines`, `common_*`, `tag` and argument-less `savefig` are
    # presumably plotstyle extensions of matplotlib - confirm there.
    fig, (axes, axa) = plt.subplots(2, 1, height_ratios=[4, 3],
                                    cmsize=(s.plot_width, 0.6*s.plot_width))
    fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5, hspace=0.6)
    # Rows: AM, spike raster, power spectrum; one column per contrast.
    axe = axes.subplots(3, 4, wspace=0.4, hspace=0.2,
                        height_ratios=[1, 2, 3])
    fig.show_spines('lb')
    # example power spectra:
    for c, alpha in enumerate(alphas):
        # 100 trials per contrast:
        plot_example(axe[0, c], axe[1, c], axe[2, c], s, sims_path,
                     cell, alpha, beatf1, beatf2, nfft, 100)
    axe[1, 0].xscalebar(1, -0.1, 20, 'ms', ha='right')
    # One legend for the five peak markers drawn by plot_psd():
    axe[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8),
                     ncol=5, columnspacing=2)
    fig.common_yspines(axe[0, :])
    fig.common_yticks(axe[2, :])
    fig.tag(axe[0, :], xoffs=-3, yoffs=1.6)
    # contrast dependence:
    plot_peaks(axa, s, alphas, contrasts, powerf1, powerf2,
               powerfsum, powerfdiff)
    fig.tag(axa, yoffs=2)
    fig.savefig()