nonlinearbaseline2025/regimes.py
2025-05-16 09:22:34 +02:00

348 lines
12 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import linregress
from numba import jit
from thunderlab.tabledata import TableData
from plotstyle import plot_style, lighter, darker
# Folder holding the input data of this script (model parameters):
data_path = Path('data')
# Sub-folder with per-cell tables:
cell_path = data_path.joinpath('cells')
def load_models(file):
    """ Load model parameter from csv file.

    The first line of the file is a comma-separated header with the
    column names. In every following line the first column is kept as
    a string, all other columns are converted to float.
    Empty lines are skipped.

    Parameters
    ----------
    file: string or Path
        Name of file with model parameters.

    Returns
    -------
    parameters: list of dict
        For each cell a dictionary with model parameters.

    Raises
    ------
    IndexError
        A line has fewer columns than the header.
    ValueError
        A value outside the first column cannot be converted to float.
    """
    parameters = []
    # Path() accepts both strings and Path objects, and a separate name
    # for the stream avoids shadowing the `file` argument:
    with Path(file).open('r') as sf:
        keys = sf.readline().strip().split(',')
        for line in sf:
            line = line.strip()
            if not line:
                # e.g. trailing newline at the end of the file
                continue
            parts = line.split(',')
            if len(parts) < len(keys):
                raise IndexError(f'expected {len(keys)} columns '
                                 f'but got {len(parts)}: {line!r}')
            parameter = {keys[0]: parts[0]}
            for key, value in zip(keys[1:], parts[1:]):
                parameter[key] = float(value)
            parameters.append(parameter)
    return parameters
def cell_parameters(parameters, cell_name):
    """ Select the model parameters of a single cell.

    Parameters
    ----------
    parameters: list of dict
        Model parameters, each dictionary having a 'cell' entry.
    cell_name: string
        Value of the 'cell' entry to look for.

    Returns
    -------
    params: dict
        The first dictionary whose 'cell' entry equals `cell_name`.

    Raises
    ------
    SystemExit
        No dictionary matches `cell_name`. The message is written to
        stderr and the exit status is non-zero when it is not caught.
    """
    for params in parameters:
        if params['cell'] == cell_name:
            return params
    # SystemExit with a message instead of print() + exit(): exit() is
    # only a helper of the site module and terminated with status 0.
    raise SystemExit(f'cell {cell_name} not found!')
@jit(nopython=True)
def simulate(stimulus, deltat=0.00005, v_zero=0.0, a_zero=2.0,
             threshold=1.0, v_base=0.0, delta_a=0.08, tau_a=0.1,
             v_offset=-10.0, mem_tau=0.015, noise_strength=0.05,
             input_scaling=60.0, dend_tau=0.001, ref_period=0.001):
    """ Simulate a P-unit.

    Three state variables are updated once per stimulus sample:
    a low-pass filtered copy of the rectified stimulus (`v_dend`),
    a noisy membrane variable (`v_mem`) that is reset to `v_base`
    whenever it exceeds `threshold`, and an adaptation variable
    (`adapt`) that decays with `tau_a` and is incremented on each spike.

    Parameters
    ----------
    stimulus: 1-D array
        The stimulus, one sample per time step. The passed array is
        not modified, the function works on a copy.
    deltat: float
        Time step between two stimulus samples. Spike times are
        returned as sample index times `deltat`.
    v_zero: float
        Initial value of the membrane variable.
    a_zero: float
        Initial value of the adaptation variable.
    threshold: float
        A spike is emitted when the membrane variable exceeds this value.
    v_base: float
        Value the membrane variable is reset to after a spike and
        during the refractory period. It is also added to the drive
        of the membrane variable.
    delta_a: float
        On each spike `delta_a / tau_a` is added to the adaptation
        variable.
    tau_a: float
        Time constant of the decay of the adaptation variable.
    v_offset: float
        Constant added to the drive of the membrane variable.
    mem_tau: float
        Time constant of the membrane variable.
    noise_strength: float
        Scaling of the Gaussian noise added to the membrane drive.
    input_scaling: float
        Factor applied to the filtered stimulus before it enters
        the membrane variable.
    dend_tau: float
        Time constant of the low-pass filter applied to the
        rectified stimulus.
    ref_period: float
        Refractory period following each spike.

    Returns
    -------
    spike_times: 1-D array
        Simulated spike times in seconds.
    """
    # initial conditions:
    # (the filter state starts at the first, not yet rectified, sample)
    v_dend = stimulus[0]
    v_mem = v_zero
    adapt = a_zero
    # prepare noise:
    # one normally distributed sample per time step, scaled by 1/sqrt(deltat)
    noise = np.random.randn(len(stimulus))
    noise *= noise_strength / np.sqrt(deltat)
    # rectify stimulus array:
    # (on a copy, negative samples are set to zero)
    stimulus = stimulus.copy()
    stimulus[stimulus < 0.0] = 0.0
    # integrate:
    spike_times = []
    for i in range(len(stimulus)):
        v_dend += (-v_dend + stimulus[i]) / dend_tau * deltat
        v_mem += (v_base - v_mem + v_offset + (
            v_dend * input_scaling) - adapt + noise[i]) / mem_tau * deltat
        adapt += -adapt / tau_a * deltat
        # refractory period:
        # clamp the membrane variable as long as the last spike is less
        # than ref_period (plus half a time step of tolerance) ago
        if len(spike_times) > 0 and (deltat * i) - spike_times[-1] < ref_period + deltat/2:
            v_mem = v_base
        # threshold crossing:
        if v_mem > threshold:
            v_mem = v_base
            spike_times.append(i * deltat)
            adapt += delta_a / tau_a
    return np.array(spike_times)
def punit_spikes(parameter, alpha, beatf1, beatf2, tmax, trials):
    """ Simulate responses of a P-unit model to two superimposed beats.

    The stimulus is a sine wave of amplitude one at the frequency given
    by the 'EODf' entry of `parameter`, plus two sine waves of amplitude
    `alpha` at that frequency plus `beatf1` and plus `beatf2`.
    Each trial starts with a random initial membrane value and an
    initial adaptation value jittered around the cell's 'a_zero'.
    The first 0.2 s of each simulation are discarded.

    Parameters
    ----------
    parameter: dict
        Model parameters of a cell. Needs the entries 'cell', 'EODf',
        'deltat' and 'a_zero'. All entries except 'cell' and 'EODf'
        are passed on as keyword arguments to `simulate()`.
        The dictionary is not modified.
    alpha: float
        Amplitude of the two additional sine waves.
    beatf1: float
        Frequency difference of the first additional sine wave
        to 'EODf'.
    beatf2: float
        Frequency difference of the second additional sine wave
        to 'EODf'.
    tmax: float
        Duration of the returned spike trains.
    trials: int
        Number of spike trains to simulate.

    Returns
    -------
    spikes: list of 1-D arrays
        For each trial the spike times, with time zero at the end of
        the discarded initial 0.2 s.
    """
    tini = 0.2
    model_params = dict(parameter)
    # 'cell' and 'EODf' are no arguments of simulate():
    model_params.pop('cell')
    eodf0 = model_params.pop('EODf')
    time = np.arange(-tini, tmax, model_params['deltat'])
    stimulus = np.sin(2*np.pi*eodf0*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf1)*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf2)*time)
    a_zero = parameter['a_zero']
    spikes = []
    for _ in range(trials):
        model_params['v_zero'] = np.random.rand()
        # Jitter every trial independently around the cell's a_zero.
        # Accumulating the jitter with `+=` made the initial adaptation
        # a random walk that drifts further away with each trial.
        model_params['a_zero'] = a_zero + 0.02*a_zero*np.random.randn()
        spiket = simulate(stimulus, **model_params)
        spikes.append(spiket[spiket > tini] - tini)
    return spikes
def plot_am(ax, s, alpha, beatf1, beatf2, tmax):
    """ Plot the amplitude modulation made up by the two beats.

    The sum of two sine waves with amplitude `alpha` and frequencies
    `beatf1` and `beatf2` is sampled every 0.1 ms, multiplied by -100
    and plotted against time in milliseconds. The contrast,
    `100*alpha`, is written above the right edge of the axes.

    Parameters
    ----------
    ax: axes
        Axes used for plotting. Needs to support `show_spines()`
        and `set_ylabel()` with a unit argument.
    s: plot style
        Provides the line style `lsStim`.
    alpha: float
        Amplitude of each of the two sine waves.
    beatf1: float
        Frequency of the first sine wave.
    beatf2: float
        Frequency of the second sine wave.
    tmax: float
        Duration to be plotted in seconds.
    """
    t = np.arange(0, tmax, 0.0001)
    first_beat = alpha*np.sin(2*np.pi*beatf1*t)
    second_beat = alpha*np.sin(2*np.pi*beatf2*t)
    am = first_beat + second_beat
    ax.show_spines('l')
    ax.plot(1000*t, -100*am, **s.lsStim)
    ax.set_xlim(0, 1000*tmax)
    ax.set_ylim(-50, 50)
    ax.set_ylabel('AM', r'\%')
    title = f'Contrast = {100*alpha:g}\\,\\%'
    ax.text(1, 1.2, title, transform=ax.transAxes, ha='right')
def plot_raster(ax, s, spikes, tmax):
    """ Plot a spike raster of the first 16 trials.

    Parameters
    ----------
    ax: axes
        Axes used for plotting. Needs to support `show_spines()`.
    s: plot style
        Provides the line style `lsRaster`.
    spikes: list of 1-D arrays
        Spike times of each trial in seconds.
    tmax: float
        Only spikes earlier than `tmax` seconds are shown.
    """
    trains_ms = []
    for train in spikes[:16]:
        # keep the visible spikes only and convert to milliseconds:
        trains_ms.append(1000*train[train < tmax])
    ax.show_spines('')
    ax.eventplot(trains_ms, linelengths=0.9, **s.lsRaster)
    ax.set_xlim(0, 1000*tmax)
def compute_power(spikes, nfft, dt):
    """ Trial-averaged power spectrum of binned spike trains.

    Each spike train is binned into `nfft` bins of width `dt`
    starting at time zero. The mean is subtracted from the bin counts
    and the squared magnitude of their real-valued Fourier transform
    is averaged over trials. No further normalization is applied.

    Parameters
    ----------
    spikes: list of 1-D arrays
        Spike times of each trial.
    nfft: int
        Number of bins, i.e. number of samples of the Fourier transform.
    dt: float
        Width of the bins.

    Returns
    -------
    freqs: 1-D array
        Frequencies of the power spectrum as returned by
        `np.fft.rfftfreq(nfft, dt)`.
    psd: 1-D array
        Squared magnitudes of the Fourier transforms averaged over
        trials.
    """
    # Firing rates and CVs were computed here as well but never used.
    # The CV computation in addition raised RuntimeWarnings for trials
    # with less than two spikes, so both have been removed.
    edges = np.arange(nfft + 1)*dt
    psds = []
    for train in spikes:
        counts, _ = np.histogram(train, edges)
        fourier = np.fft.rfft(counts - np.mean(counts))
        psds.append(np.abs(fourier)**2)
    freqs = np.fft.rfftfreq(nfft, dt)
    return freqs, np.mean(psds, 0)
def decibel(x, ref_power=1e8):
    """ Transform power to decibel relative to a reference power.

    Parameters
    ----------
    x: float or array
        Power values.
    ref_power: float
        Power that is mapped onto 0 dB. The default of 1e8 is the
        reference that was hard-coded before.

    Returns
    -------
    level: float or array
        `10*log10(x/ref_power)`.
    """
    return 10*np.log10(x/ref_power)
def plot_psd(ax, s, spikes, nfft, dt, beatf1, beatf2):
    """ Plot the power spectrum of spike trains and mark its peaks.

    The power returned by `compute_power()` is divided by the frequency
    resolution and plotted in decibel. Markers are placed 3 dB above
    the peaks at `beatf2` (labeled as baseline frequency), `beatf1`,
    `beatf2 - beatf1` and `beatf1 + beatf2`. A further marker for
    `beatf2` is placed another 4.5 dB higher, right above the first one.

    Parameters
    ----------
    ax: axes
        Axes used for plotting. Needs to support `set_xlabel()`
        with a unit argument.
    s: plot style
        Provides the styles `lsPower`, `psF0`, `psF01`, `psF02`,
        `psF01_2` and `psF012`.
    spikes: list of 1-D arrays
        Spike times of each trial.
    nfft: int
        Number of samples used for the Fourier transform.
    dt: float
        Bin width used for binning the spikes.
    beatf1: float
        First frequency to be marked.
    beatf2: float
        Second frequency to be marked.
    """
    offs = 3
    freqs, power = compute_power(spikes, nfft, dt)
    psd = power / freqs[1]
    ax.plot(freqs, decibel(psd), **s.lsPower)
    # frequency, additional vertical shift, legend label, marker style:
    markers = (
        (beatf2, 0, r'$f_{\rm base}$', s.psF0),
        (beatf1, 0, r'$\Delta f_1$', s.psF01),
        (beatf2, 4.5, r'$\Delta f_2$', s.psF02),
        (beatf2 - beatf1, 0, r'$\Delta f_2 - \Delta f_1$', s.psF01_2),
        (beatf1 + beatf2, 0, r'$\Delta f_1 + \Delta f_2$', s.psF012),
    )
    for freq, shift, label, style in markers:
        level = decibel(peak_ampl(freqs, psd, freq)) + offs + shift
        ax.plot(freq, level, label=label, clip_on=False, **style)
    ax.set_xlim(0, 300)
    ax.set_ylim(-40, 0)
    ax.set_xlabel('Frequency', 'Hz')
    ax.set_ylabel('Power [dB]')
def plot_example(axs, axr, axp, s, cell, alpha, beatf1, beatf2, nfft, trials):
    """ Simulate one stimulus condition and plot it in a column of axes.

    Spike trains of duration `nfft` times 0.1 ms are simulated by
    `punit_spikes()`. The first 100 ms of the amplitude modulation and
    of the spike raster are shown together with the power spectrum
    of the full spike trains.

    Parameters
    ----------
    axs: axes
        Axes for the amplitude modulation.
    axr: axes
        Axes for the spike raster.
    axp: axes
        Axes for the power spectrum.
    s: plot style
        Styles passed on to the plot functions.
    cell: dict
        Model parameters of the cell to be simulated.
    alpha: float
        Amplitude of each of the two beats.
    beatf1: float
        Frequency of the first beat.
    beatf2: float
        Frequency of the second beat.
    nfft: int
        Number of samples used for the Fourier transform.
    trials: int
        Number of simulated trials.
    """
    dt = 0.0001      # bin width of the power spectrum
    t_shown = 0.1    # duration shown in stimulus and raster plot
    spikes = punit_spikes(cell, alpha, beatf1, beatf2, nfft*dt, trials)
    plot_am(axs, s, alpha, beatf1, beatf2, t_shown)
    plot_raster(axr, s, spikes, t_shown)
    plot_psd(axp, s, spikes, nfft, dt, beatf1, beatf2)
def peak_ampl(freqs, psd, f, df=2):
    """ Maximum of a power spectrum close to a given frequency.

    Parameters
    ----------
    freqs: 1-D array
        Frequencies of the power spectrum.
    psd: 1-D array
        Power spectrum.
    f: float
        Frequency around which to search for the maximum.
    df: float
        Half-width of the search window. Only frequencies strictly
        between `f - df` and `f + df` are considered. The default of 2
        is the width that was hard-coded before.

    Returns
    -------
    peak: float
        Largest value of `psd` within the search window.

    Raises
    ------
    ValueError
        No frequency falls into the search window.
    """
    psd_snippet = psd[(freqs > f - df) & (freqs < f + df)]
    if len(psd_snippet) == 0:
        # np.max() raises a ValueError as well, but without telling why
        raise ValueError(f'no frequencies within {df} of {f}')
    return np.max(psd_snippet)
def compute_peaks(name, cell, alpha_max, beatf1, beatf2, nfft, trials):
    """ Peak power at the beat frequencies as a function of contrast.

    The table is loaded from the file `<name>-contrastpeaks.csv` in
    `cell_path` if this file exists. Otherwise the model is simulated
    for 200 values of `alpha` between zero and `alpha_max`, and the
    resulting table is written to that file before it is returned.
    Note that this simulates `200*trials` spike trains and therefore
    can take long.

    The code generating the table was parked in an unreachable string
    literal, so that a missing file made this function fail and all
    arguments but `name` were ignored.

    Parameters
    ----------
    name: string
        Name of the cell, used for the file name of the table.
    cell: dict
        Model parameters of the cell.
    alpha_max: float
        Largest amplitude of the beats to be simulated.
    beatf1: float
        Frequency of the first beat.
    beatf2: float
        Frequency of the second beat.
    nfft: int
        Number of samples used for the Fourier transform.
    trials: int
        Number of trials simulated for each amplitude.

    Returns
    -------
    data: TableData
        Table with the columns 'contrast' (`100*alpha`), 'f1', 'f2',
        'f1+f2', and 'f2-f1' holding the peak power at the respective
        frequencies.
    """
    data_file = cell_path / f'{name}-contrastpeaks.csv'
    if data_file.exists():
        return TableData(data_file)
    dt = 0.0001
    tmax = nfft*dt
    alphas = np.linspace(0, alpha_max, 200)
    ampl_f1 = np.zeros(len(alphas))
    ampl_f2 = np.zeros(len(alphas))
    ampl_sum = np.zeros(len(alphas))
    ampl_diff = np.zeros(len(alphas))
    for k, alpha in enumerate(alphas):
        # report progress, each amplitude takes a while:
        print(alpha)
        spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials)
        freqs, psd = compute_power(spikes, nfft, dt)
        ampl_f1[k] = peak_ampl(freqs, psd, beatf1)
        ampl_f2[k] = peak_ampl(freqs, psd, beatf2)
        ampl_sum[k] = peak_ampl(freqs, psd, beatf1 + beatf2)
        ampl_diff[k] = peak_ampl(freqs, psd, beatf2 - beatf1)
    data = TableData()
    data.append('contrast', '%', '%.1f', 100*alphas)
    data.append('f1', 'Hz', '%g', ampl_f1)
    data.append('f2', 'Hz', '%g', ampl_f2)
    data.append('f1+f2', 'Hz', '%g', ampl_sum)
    data.append('f2-f1', 'Hz', '%g', ampl_diff)
    # make sure the folder for the table exists:
    data_file.parent.mkdir(parents=True, exist_ok=True)
    data.write(data_file)
    return data
def amplitude(power):
    """ Amplitude from power relative to the first element.

    The first element of `power` is subtracted from all elements,
    negative differences are set to zero, and the square root is taken.

    Parameters
    ----------
    power: 1-D array
        Power values. In contrast to the previous implementation
        the array is not modified.

    Returns
    -------
    ampl: 1-D array
        Square root of the non-negative power differences.
    """
    power = np.asarray(power, dtype=float)
    # the subtraction creates a new array, the caller's data stay untouched:
    shifted = power - power[0]
    shifted[shifted < 0] = 0
    return np.sqrt(shifted)
def amplitude_linearfit(contrast, power, max_contrast):
    """ Straight line fitted to the amplitudes at small contrasts.

    Amplitudes are the square root of `power` minus its first element,
    with negative differences set to zero. A line is fitted to the
    amplitudes at contrasts not exceeding `max_contrast` and is
    evaluated at all contrasts.

    Parameters
    ----------
    contrast: 1-D array
        Contrasts.
    power: 1-D array
        Power for each contrast. In contrast to the previous
        implementation the array is not modified.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: 1-D array
        The fitted line evaluated at each contrast.
    """
    contrast = np.asarray(contrast)
    power = np.asarray(power, dtype=float)
    # the subtraction creates a new array, the caller's data stay untouched:
    shifted = power - power[0]
    shifted[shifted < 0] = 0
    ampl = np.sqrt(shifted)
    selection = contrast <= max_contrast
    r = linregress(contrast[selection], ampl[selection])
    return r.intercept + r.slope*contrast
def amplitude_squarefit(contrast, power, max_contrast):
    """ Squared line fitted to the amplitudes at small contrasts.

    Amplitudes are the square root of `power` minus its first element,
    with negative differences set to zero. A line is fitted to the
    square root of the amplitudes at contrasts not exceeding
    `max_contrast`. The square of this line is evaluated at
    all contrasts.

    Parameters
    ----------
    contrast: 1-D array
        Contrasts.
    power: 1-D array
        Power for each contrast. In contrast to the previous
        implementation the array is not modified.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: 1-D array
        The squared fitted line evaluated at each contrast.
    """
    contrast = np.asarray(contrast)
    power = np.asarray(power, dtype=float)
    # the subtraction creates a new array, the caller's data stay untouched:
    shifted = power - power[0]
    shifted[shifted < 0] = 0
    ampl = np.sqrt(shifted)
    selection = contrast <= max_contrast
    r = linregress(contrast[selection], np.sqrt(ampl[selection]))
    return (r.intercept + r.slope*contrast)**2
def plot_peaks(ax, s, data, alphas):
    """ Plot peak amplitudes and their fits as a function of contrast.

    First the fits are drawn: straight lines for the columns 'f1'
    and 'f2', squared lines for 'f1+f2' and 'f2-f1'. The amplitudes
    themselves are plotted on top. Triangles tagged 'A' to 'D' mark
    the contrasts `100*alpha` of the values in `alphas`. Vertical
    lines at contrasts of 1.5 and 4 separate three labeled regimes.

    Parameters
    ----------
    ax: axes
        Axes used for plotting. Needs to support `set_xticks_delta()`,
        `set_yticks_delta()`, and axis labels with a unit argument.
    s: plot style
        Provides the line styles `lsF01`, `lsF02`, `lsF012`, `lsF01_2`,
        the corresponding styles with an appended 'm' for the fits,
        and `lsLine`.
    data: table
        Indexed by `data[:, column]` with the columns 'contrast',
        'f1', 'f2', 'f1+f2' and 'f2-f1'.
    alphas: list of float
        Up to four beat amplitudes to be marked.
    """
    contrasts = data[:, 'contrast']
    # fit function, data column, largest contrast of the fit, line style:
    fits = (
        (amplitude_linearfit, 'f1', 4, s.lsF01m),
        (amplitude_linearfit, 'f2', 2, s.lsF02m),
        (amplitude_squarefit, 'f1+f2', 4, s.lsF012m),
        (amplitude_squarefit, 'f2-f1', 4, s.lsF01_2m),
    )
    for fit, column, max_contrast, style in fits:
        ax.plot(contrasts, fit(contrasts, data[:, column], max_contrast),
                **style)
    measured = (
        ('f1', s.lsF01),
        ('f2', s.lsF02),
        ('f1+f2', s.lsF012),
        ('f2-f1', s.lsF01_2),
    )
    for column, style in measured:
        ax.plot(contrasts, amplitude(data[:, column]), **style)
    # mark the contrasts of the examples:
    for alpha, tag in zip(alphas, ['A', 'B', 'C', 'D']):
        x = 100*alpha
        ax.plot(x, 630, 'vk', ms=4, clip_on=False)
        ax.text(x, 660, tag, ha='center')
    # borders and labels of the regimes:
    linear_end = 1.5
    weak_end = 4
    ax.axvline(linear_end, **s.lsLine)
    ax.axvline(weak_end, **s.lsLine)
    yoffs = 340
    regimes = (
        (linear_end/2, 'linear\nregime'),
        ((linear_end + weak_end)/2, 'weakly\nnonlinear\nregime'),
        (10, 'strongly\nnonlinear\nregime'),
    )
    for x, label in regimes:
        ax.text(x, yoffs, label, ha='center', va='center')
    ax.set_xlim(0, 16.5)
    ax.set_ylim(0, 600)
    ax.set_xticks_delta(5)
    ax.set_yticks_delta(300)
    ax.set_xlabel('Contrast', r'\%')
    ax.set_ylabel('Amplitude', 'Hz')
if __name__ == '__main__':
    # Model cell and beat frequencies used for the whole figure:
    parameters = load_models(data_path / 'punitmodels.csv')
    cell_name = '2013-01-08-aa-invivo-1'  # 138Hz, CV=0.26: perfect!
    # Alternative cells:
    #   '2012-07-03-ak-invivo-1'  # 128Hz, CV=0.24
    #   '2018-05-08-ae-invivo-1'  # 142Hz, CV=0.48
    beatf1 = 40
    beatf2 = 138
    cell = cell_parameters(parameters, cell_name)

    # Plot styles. s.lwthin is not set here and has to be provided
    # by plot_style().
    s = plot_style()
    s.lwmid = 1.0
    s.lwthick = 1.6
    s.lsStim = {'color': 'gray', 'lw': s.lwmid}
    s.lsRaster = {'color': 'black', 'lw': s.lwthin}
    s.lsPower = {'color': 'gray', 'lw': s.lwmid}
    # thick lines for the measured amplitudes:
    s.lsF0 = {'color': 'blue', 'lw': s.lwthick}
    s.lsF01 = {'color': 'green', 'lw': s.lwthick}
    s.lsF02 = {'color': 'purple', 'lw': s.lwthick}
    s.lsF012 = {'color': 'orange', 'lw': s.lwthick}
    s.lsF01_2 = {'color': 'red', 'lw': s.lwthick}
    # thin lines in modified colors for the fits:
    s.lsF0m = {'color': lighter('blue', 0.5), 'lw': s.lwthin}
    s.lsF01m = {'color': lighter('green', 0.6), 'lw': s.lwthin}
    s.lsF02m = {'color': lighter('purple', 0.5), 'lw': s.lwthin}
    s.lsF012m = {'color': darker('orange', 0.9), 'lw': s.lwthin}
    s.lsF01_2m = {'color': darker('red', 0.9), 'lw': s.lwthin}
    # markers on the peaks of the power spectra share everything but color:
    dot = dict(marker='o', linestyle='none', markersize=5, mec='none', mew=0)
    s.psF0 = {'color': 'blue', **dot}
    s.psF01 = {'color': 'green', **dot}
    s.psF02 = {'color': 'purple', **dot}
    s.psF012 = {'color': 'orange', **dot}
    s.psF01_2 = {'color': 'red', **dot}

    # Layout: three rows with the examples, an invisible row as spacer,
    # and a bottom row merged into a single panel for the peak amplitudes.
    # NOTE(review): cmsize, the margin arguments, merge(), show_spines(),
    # xscalebar(), common_yspines(), common_yticks(), tag() and savefig()
    # without arguments are not plain matplotlib, presumably they are
    # installed by plot_style() - confirm.
    nfft = 2**18
    fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.8*s.plot_width),
                            height_ratios=[1, 1.5, 2, 1.5, 4])
    fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5,
                        wspace=0.3, hspace=0.3)
    ax0 = fig.merge(axs[3, :])
    ax0.set_visible(False)
    axa = fig.merge(axs[4, :])
    fig.show_spines('lb')

    # One column of examples for each beat amplitude:
    alphas = [0.01, 0.03, 0.05, 0.16]
    example_trials = 100
    for c, alpha in enumerate(alphas):
        ax_stim, ax_raster, ax_power = axs[0, c], axs[1, c], axs[2, c]
        plot_example(ax_stim, ax_raster, ax_power, s, cell,
                     alpha, beatf1, beatf2, nfft, example_trials)
    axs[1, 0].xscalebar(1, -0.1, 30, 'ms', ha='right')
    axs[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8),
                     ncol=5, columnspacing=2)

    # Peak amplitudes as a function of contrast:
    data = compute_peaks(cell_name, cell, 0.2, beatf1, beatf2, nfft, 1000)
    plot_peaks(axa, s, data, alphas)

    fig.common_yspines(axs[0, :])
    fig.common_yticks(axs[2, :])
    fig.tag(axs[0, :], xoffs=-2, yoffs=1.6)
    fig.tag(axa)
    fig.savefig()