327 lines
11 KiB
Python
327 lines
11 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from pathlib import Path
|
|
from scipy.stats import linregress
|
|
from numba import jit
|
|
from thunderlab.tabledata import TableData
|
|
from plotstyle import plot_style, lighter, darker
|
|
|
|
|
|
# Input data and cached simulation results, relative to the working directory:
data_path = Path('data')
sims_path = data_path.joinpath('simulations')
|
|
|
|
|
|
def load_models(file):
    """ Load model parameters from a csv file.

    The first line of the file is a comma-separated header with the
    parameter names. Each following non-blank line holds the parameters
    of one cell: the first column is kept as a string (the cell name),
    all remaining columns are converted to float.

    Parameters
    ----------
    file: str or pathlib.Path
        Name of file with model parameters.

    Returns
    -------
    parameters: list of dict
        For each cell a dictionary with model parameters.

    Raises
    ------
    ValueError
        If a line has fewer columns than the header, or a parameter
        value cannot be converted to float.
    """
    path = Path(file)
    parameters = []
    with path.open('r') as sf:
        keys = [key.strip() for key in sf.readline().strip().split(',')]
        for line in sf:
            line = line.strip()
            if not line:
                # skip blank lines, e.g. a trailing empty line:
                continue
            values = [value.strip() for value in line.split(',')]
            if len(values) < len(keys):
                raise ValueError(f'{path}: expected {len(keys)} columns, '
                                 f'got {len(values)}')
            parameter = {}
            for i, (key, value) in enumerate(zip(keys, values)):
                # first column is the cell name, all others are numbers:
                parameter[key] = value if i == 0 else float(value)
            parameters.append(parameter)
    return parameters
|
|
|
|
|
|
def cell_parameters(parameters, cell_name):
    """ Look up the model parameters of a single cell.

    Parameters
    ----------
    parameters: list of dict
        Model parameters as returned by `load_models()`; each dictionary
        must have a 'cell' entry.
    cell_name: str
        Name of the cell to look up.

    Returns
    -------
    params: dict
        The first parameter dictionary whose 'cell' entry equals `cell_name`.

    Raises
    ------
    SystemExit
        With exit status 1 if no cell with this name exists.
    """
    for params in parameters:
        if params['cell'] == cell_name:
            return params
    print('cell', cell_name, 'not found!')
    # sys.exit() rather than the site-provided exit() builtin, which is
    # not guaranteed to exist; the non-zero status signals the failure.
    import sys
    sys.exit(1)
|
|
|
|
|
|
@jit(nopython=True)
def simulate(stimulus, deltat=0.00005, v_zero=0.0, a_zero=2.0,
             threshold=1.0, v_base=0.0, delta_a=0.08, tau_a=0.1,
             v_offset=-10.0, mem_tau=0.015, noise_strength=0.05,
             input_scaling=60.0, dend_tau=0.001, ref_period=0.001):
    """ Simulate a P-unit.

    Leaky integrate-and-fire neuron driven through a low-pass filtered
    (dendritic) input stage, with spike-triggered adaptation, additive
    Gaussian white noise, and a refractory period. All state variables
    are integrated with forward Euler steps of size `deltat`. The
    function is compiled by numba in nopython mode.

    Parameters
    ----------
    stimulus: 1-D array
        Input signal sampled with `deltat`. Negative values are set to
        zero (on a copy) before integration.
        NOTE(review): presumably the raw EOD waveform including any
        foreign signals; confirm with callers.
    deltat: float
        Integration time step in seconds.
    v_zero: float
        Initial membrane voltage.
    a_zero: float
        Initial value of the adaptation variable.
    threshold: float
        A spike is generated when the membrane voltage exceeds this value.
    v_base: float
        Membrane voltage after a spike and during the refractory period.
    delta_a: float
        Adaptation strength; each spike increases the adaptation
        variable by `delta_a/tau_a`.
    tau_a: float
        Time constant of the adaptation in seconds.
    v_offset: float
        Constant offset added to the membrane input.
    mem_tau: float
        Membrane time constant in seconds.
    noise_strength: float
        Noise intensity; the noise samples have standard deviation
        `noise_strength/sqrt(deltat)`.
    input_scaling: float
        Gain applied to the dendritic voltage before it drives the membrane.
    dend_tau: float
        Time constant of the dendritic low-pass filter in seconds.
    ref_period: float
        Refractory period in seconds.

    Returns
    -------
    spike_times: 1-D array
        Simulated spike times in seconds, measured from the first
        stimulus sample.
    """
    # initial conditions:
    # NOTE(review): v_dend starts at the first stimulus sample before
    # rectification; rectification is applied only further below.
    v_dend = stimulus[0]
    v_mem = v_zero
    adapt = a_zero

    # prepare noise:
    noise = np.random.randn(len(stimulus))
    noise *= noise_strength / np.sqrt(deltat)

    # rectify stimulus array:
    # (copy first, so that the caller's array is not modified)
    stimulus = stimulus.copy()
    stimulus[stimulus < 0.0] = 0.0

    # integrate:
    spike_times = []
    for i in range(len(stimulus)):
        # forward Euler steps of dendrite, membrane and adaptation;
        # the membrane already uses the updated dendritic voltage:
        v_dend += (-v_dend + stimulus[i]) / dend_tau * deltat
        v_mem += (v_base - v_mem + v_offset + (
            v_dend * input_scaling) - adapt + noise[i]) / mem_tau * deltat
        adapt += -adapt / tau_a * deltat

        # refractory period:
        # clamp the membrane voltage while less than ref_period (plus half
        # a time step, presumably a margin against rounding) has elapsed
        # since the last spike.
        if len(spike_times) > 0 and (deltat * i) - spike_times[-1] < ref_period + deltat/2:
            v_mem = v_base

        # threshold crossing:
        # reset, record spike time, and increment adaptation.
        if v_mem > threshold:
            v_mem = v_base
            spike_times.append(i * deltat)
            adapt += delta_a / tau_a

    return np.array(spike_times)
|
|
|
|
|
|
def punit_spikes(parameter, alpha, beatf1, beatf2, tmax, trials, tini=0.2):
    """ Simulate spike trains of a P-unit model driven by two beats.

    The stimulus is a sine wave of amplitude one at the cell's EOD
    frequency plus two sine waves of amplitude `alpha` at EODf + beatf1
    and EODf + beatf2. In each trial the initial membrane voltage is
    drawn uniformly from [0, 1), and the initial adaptation is jittered
    by 2 % around the cell's `a_zero`.

    Parameters
    ----------
    parameter: dict
        Model parameters of a cell as returned by `load_models()`.
        Must contain 'cell', 'EODf', 'deltat', and 'a_zero'. All entries
        except 'cell' and 'EODf' are passed on to `simulate()`.
    alpha: float
        Amplitude (contrast) of the two foreign signals.
    beatf1, beatf2: float
        Frequencies of the foreign signals relative to EODf in Hz.
    tmax: float
        Duration of the returned spike trains in seconds.
    trials: int
        Number of trials to simulate.
    tini: float
        Initial transient in seconds that is simulated but discarded.

    Returns
    -------
    spikes: list of 1-D arrays
        For each trial the spike times in seconds, relative to the end
        of the initial transient.
    """
    model_params = dict(parameter)
    model_params.pop('cell')   # not a parameter of simulate()
    eodf0 = model_params.pop('EODf')
    a_zero = parameter['a_zero']
    time = np.arange(-tini, tmax, model_params['deltat'])
    stimulus = np.sin(2*np.pi*eodf0*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf1)*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf2)*time)
    spikes = []
    for _ in range(trials):
        model_params['v_zero'] = np.random.rand()
        # assign (not accumulate) the jittered value, so that the jitter
        # stays centered on the cell's a_zero instead of performing a
        # random walk over trials:
        model_params['a_zero'] = a_zero + 0.02*a_zero*np.random.randn()
        spiket = simulate(stimulus, **model_params)
        spikes.append(spiket[spiket > tini] - tini)
    return spikes
|
|
|
|
|
|
def plot_am(ax, s, alpha, beatf1, beatf2, tmax):
    """ Plot the amplitude modulation generated by two beats.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into (presumably extended by plotstyle, since
        `show_spines()` and a two-argument `set_ylabel()` are used).
    s: plot style
        Provides the line style `lsAM`.
    alpha: float
        Contrast of each of the two beats.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    tmax: float
        Duration of the plotted AM in seconds.
    """
    t = np.arange(0, tmax, 0.0001)
    modulation = alpha*np.sin(2*np.pi*beatf1*t) + alpha*np.sin(2*np.pi*beatf2*t)
    to_ms = 1000
    label = f'Contrast = {100*alpha:g}\\,\\%'
    ax.show_spines('l')
    # AM in percent, plotted with inverted sign:
    ax.plot(to_ms*t, -100*modulation, **s.lsAM)
    ax.set_xlim(0, to_ms*tmax)
    ax.set_ylim(-50, 50)
    ax.set_ylabel('AM', r'\%')
    ax.text(1, 1.2, label, transform=ax.transAxes, ha='right')
|
|
|
|
|
|
def plot_raster(ax, s, spikes, tmax, max_trials=16):
    """ Plot a spike raster of the first trials.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into (presumably extended by plotstyle, since
        `show_spines()` is used).
    s: plot style
        Provides the event style `lsRaster`.
    spikes: list of 1-D arrays
        Spike times of each trial in seconds.
    tmax: float
        Only spikes before this time (in seconds) are plotted.
    max_trials: int
        Maximum number of trials shown.
    """
    # convert to milliseconds; named `trial` to not shadow the style `s`:
    spikes_ms = [1000*trial[trial < tmax] for trial in spikes[:max_trials]]
    ax.show_spines('')
    ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster)
    ax.set_xlim(0, 1000*tmax)
|
|
|
|
|
|
def compute_power(spikes, nfft, dt):
    """ Trial-averaged power spectrum of binned spike trains.

    Each spike train is binned into `nfft` bins of width `dt` starting
    at time zero (spikes outside of [0, nfft*dt] are ignored). The mean
    count is subtracted, and the squared magnitude of the real FFT is
    averaged over trials. The spectrum is not normalized.

    Parameters
    ----------
    spikes: list of 1-D arrays
        Spike times of each trial in seconds.
    nfft: int
        Number of bins, i.e. length of the FFT.
    dt: float
        Bin width in seconds.

    Returns
    -------
    freqs: 1-D array
        Frequencies of the power spectrum in Hz.
    psd: 1-D array
        Power averaged over trials for each frequency.
    """
    bins = np.arange(nfft + 1)*dt
    psds = []
    for trial in spikes:
        counts, _ = np.histogram(trial, bins)
        fourier = np.fft.rfft(counts - np.mean(counts))
        psds.append(np.abs(fourier)**2)
    freqs = np.fft.rfftfreq(nfft, dt)
    return freqs, np.mean(psds, 0)
|
|
|
|
|
|
def decibel(x, ref=1e8):
    """ Convert power to decibel relative to a reference power.

    Parameters
    ----------
    x: float or array-like
        Power values.
    ref: float
        Reference power that maps to 0 dB. The default of 1e8 is the
        fixed reference used for the plots of this script.

    Returns
    -------
    db: float or ndarray
        `10*log10(x/ref)`.
    """
    return 10*np.log10(np.asarray(x)/ref)
|
|
|
|
def plot_psd(ax, s, spikes, nfft, dt, beatf1, beatf2):
    """ Plot the power spectrum of the spike trains with peak markers.

    Markers are drawn slightly above the spectral peaks at the beat
    frequencies, their difference, and their sum. The marker labelled
    $r$ is placed at `beatf2`; presumably `beatf2` is chosen to equal
    the cell's baseline rate (confirm with the caller), and the
    $\\Delta f_2$ marker is lifted further so both stay visible.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into (presumably extended by plotstyle, since a
        two-argument `set_xlabel()` is used).
    s: plot style
        Provides `lsPower` and the marker styles `psF0`, `psF01`,
        `psF02`, `psF01_2`, and `psF012`.
    spikes: list of 1-D arrays
        Spike times of each trial in seconds.
    nfft: int
        Number of bins used for the power spectrum.
    dt: float
        Bin width in seconds.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    """
    offs = 3
    freqs, spectrum = compute_power(spikes, nfft, dt)
    # divide by the frequency resolution:
    spectrum /= freqs[1]
    ax.plot(freqs, decibel(spectrum), **s.lsPower)
    # (frequency, additional offset in dB, label, name of marker style):
    markers = (
        (beatf2, None, r'$r$', 'psF0'),
        (beatf1, None, r'$\Delta f_1$', 'psF01'),
        (beatf2, 4.5, r'$\Delta f_2$', 'psF02'),
        (beatf2 - beatf1, None, r'$\Delta f_2 - \Delta f_1$', 'psF01_2'),
        (beatf1 + beatf2, None, r'$\Delta f_1 + \Delta f_2$', 'psF012'),
    )
    for freq, lift, label, style_name in markers:
        level = decibel(peak_ampl(freqs, spectrum, freq)) + offs
        if lift is not None:
            level = level + lift
        ax.plot(freq, level, label=label, clip_on=False,
                **getattr(s, style_name))
    ax.set_xlim(0, 300)
    ax.set_ylim(-40, 0)
    ax.set_xlabel('Frequency', 'Hz')
    ax.set_ylabel('Power [dB]')
|
|
|
|
|
|
def plot_example(axs, axr, axp, s, cell, alpha, beatf1, beatf2, nfft, trials):
    """ Simulate one contrast and plot its AM, spike raster and spectrum.

    Parameters
    ----------
    axs: matplotlib axes
        Axes for the amplitude modulation.
    axr: matplotlib axes
        Axes for the spike raster.
    axp: matplotlib axes
        Axes for the power spectrum.
    s: plot style
        Plot style passed on to the plot functions.
    cell: dict
        Model parameters of the cell (see `punit_spikes()`).
    alpha: float
        Contrast of the two beats.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    nfft: int
        Number of bins for the power spectrum; together with the bin
        width it also sets the simulated duration.
    trials: int
        Number of simulated trials.
    """
    bin_width = 0.0001     # bin width of the spike trains in seconds
    shown_duration = 0.1   # duration of AM and raster plots in seconds
    spikes = punit_spikes(cell, alpha, beatf1, beatf2, nfft*bin_width, trials)
    plot_am(axs, s, alpha, beatf1, beatf2, shown_duration)
    plot_raster(axr, s, spikes, shown_duration)
    plot_psd(axp, s, spikes, nfft, bin_width, beatf1, beatf2)
|
|
|
|
|
|
def peak_ampl(freqs, psd, f, df=2):
    """ Maximum power within a frequency window around `f`.

    Parameters
    ----------
    freqs: 1-D array
        Frequencies of the power spectrum in Hz.
    psd: 1-D array
        Power spectrum, same length as `freqs`.
    f: float
        Center frequency in Hz.
    df: float
        Half-width of the (open) window in Hz.

    Returns
    -------
    peak: float
        Largest power value with `f - df < freqs < f + df`.

    Raises
    ------
    ValueError
        If no frequency falls into the window.
    """
    window = (freqs > f - df) & (freqs < f + df)
    if not np.any(window):
        raise ValueError(f'no frequencies within {df:g}Hz of {f:g}Hz')
    return np.max(psd[window])
|
|
|
|
|
|
def compute_peaks(name, cell, alpha_max, beatf1, beatf2, nfft, trials,
                  nalphas=200):
    """ Peak power at the beat frequencies as a function of contrast.

    Results are cached in the file `<name>-contrastpeaks.csv` in
    `sims_path`. If that file exists, it is loaded and returned without
    running any simulation. Otherwise the P-unit model is simulated for
    `nalphas` contrasts between zero and `alpha_max`. The peak power at
    beatf1, beatf2, beatf1 + beatf2, and beatf2 - beatf1 is extracted
    from the trial-averaged power spectra, and the resulting table is
    written to the cache file.

    Parameters
    ----------
    name: str
        Name used for the cache file (the script passes the cell name).
    cell: dict
        Model parameters of the cell (see `punit_spikes()`).
    alpha_max: float
        Largest contrast to simulate.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    nfft: int
        Number of 0.1ms bins used for the power spectra; also sets the
        duration of each simulated trial.
    trials: int
        Number of trials per contrast.
    nalphas: int
        Number of contrasts to simulate.

    Returns
    -------
    data: TableData
        Columns 'contrast' (in percent), 'f1', 'f2', 'f1+f2', and
        'f2-f1' holding the peak power at the respective frequencies.
    """
    data_file = sims_path / f'{name}-contrastpeaks.csv'
    if data_file.exists():
        return TableData(data_file)
    dt = 0.0001
    tmax = nfft*dt
    alphas = np.linspace(0, alpha_max, nalphas)
    ampl_f1 = np.zeros(len(alphas))
    ampl_f2 = np.zeros(len(alphas))
    ampl_sum = np.zeros(len(alphas))
    ampl_diff = np.zeros(len(alphas))
    for k, alpha in enumerate(alphas):
        print(alpha)   # progress report
        spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials)
        freqs, psd = compute_power(spikes, nfft, dt)
        ampl_f1[k] = peak_ampl(freqs, psd, beatf1)
        ampl_f2[k] = peak_ampl(freqs, psd, beatf2)
        ampl_sum[k] = peak_ampl(freqs, psd, beatf1 + beatf2)
        ampl_diff[k] = peak_ampl(freqs, psd, beatf2 - beatf1)
    data = TableData()
    # NOTE: the unit 'Hz' is kept for compatibility with existing cache
    # files, although the columns hold unnormalized power values.
    data.append('contrast', '%', '%.1f', 100*alphas)
    data.append('f1', 'Hz', '%g', ampl_f1)
    data.append('f2', 'Hz', '%g', ampl_f2)
    data.append('f1+f2', 'Hz', '%g', ampl_sum)
    data.append('f2-f1', 'Hz', '%g', ampl_diff)
    data_file.parent.mkdir(parents=True, exist_ok=True)
    data.write(data_file)
    return data
|
|
|
|
def amplitude(power):
    """ Convert baseline-corrected power into an amplitude.

    The first entry (in this script the power at zero contrast) is
    subtracted from all values, negative differences are clipped to
    zero, and the square root is taken. The input is not modified.

    Parameters
    ----------
    power: array-like
        Power values, the first one serving as baseline.

    Returns
    -------
    ampl: 1-D array
        Square root of the baseline-corrected power.
    """
    # work on a float copy, do not modify the caller's data:
    power = np.array(power, dtype=float)
    power -= power[0]
    power[power < 0] = 0
    return np.sqrt(power)
|
|
|
|
|
|
def amplitude_linearfit(contrast, power, max_contrast):
    """ Linear fit of amplitude versus contrast at low contrasts.

    The power is converted to an amplitude as in `amplitude()` (subtract
    the first value, clip at zero, square root). A straight line is then
    fitted to the amplitudes with contrast <= `max_contrast` and
    evaluated at all contrasts. The inputs are not modified.

    Parameters
    ----------
    contrast: array-like
        Contrasts.
    power: array-like
        Power values for each contrast, the first one serving as baseline.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: 1-D array
        Fitted amplitude for each contrast.
    """
    contrast = np.asarray(contrast, dtype=float)
    # work on a float copy, do not modify the caller's data:
    power = np.array(power, dtype=float)
    power -= power[0]
    power[power < 0] = 0
    ampl = np.sqrt(power)
    fit_range = contrast <= max_contrast
    r = linregress(contrast[fit_range], ampl[fit_range])
    return r.intercept + r.slope*contrast
|
|
|
|
|
|
def amplitude_squarefit(contrast, power, max_contrast):
    """ Quadratic fit of amplitude versus contrast at low contrasts.

    The power is converted to an amplitude as in `amplitude()` (subtract
    the first value, clip at zero, square root). A straight line is
    fitted to the square root of the amplitudes with contrast <=
    `max_contrast`, and its square is evaluated at all contrasts. The
    inputs are not modified.

    Parameters
    ----------
    contrast: array-like
        Contrasts.
    power: array-like
        Power values for each contrast, the first one serving as baseline.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: 1-D array
        Fitted amplitude for each contrast.
    """
    contrast = np.asarray(contrast, dtype=float)
    # work on a float copy, do not modify the caller's data:
    power = np.array(power, dtype=float)
    power -= power[0]
    power[power < 0] = 0
    ampl = np.sqrt(power)
    fit_range = contrast <= max_contrast
    r = linregress(contrast[fit_range], np.sqrt(ampl[fit_range]))
    return (r.intercept + r.slope*contrast)**2
|
|
|
|
|
|
def plot_peaks(ax, s, data, alphas):
    """ Plot peak amplitudes versus contrast together with low-contrast fits.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into (presumably extended by plotstyle, since
        `set_xticks_delta()` and two-argument labels are used).
    s: plot style
        Provides the line styles for fits (`lsF01m`, `lsF02m`,
        `lsF012m`, `lsF01_2m`), data (`lsF01`, `lsF02`, `lsF012`,
        `lsF01_2`), and regime boundaries (`lsLine`).
    data: TableData
        Table as returned by `compute_peaks()`.
    alphas: list of float
        Contrasts of the example columns, marked and tagged A-D.
    """
    contrasts = data[:, 'contrast']
    # low-contrast fits: (column, fit function, max contrast of fit, style)
    fits = (('f1', amplitude_linearfit, 4, 'lsF01m'),
            ('f2', amplitude_linearfit, 2, 'lsF02m'),
            ('f1+f2', amplitude_squarefit, 4, 'lsF012m'),
            ('f2-f1', amplitude_squarefit, 4, 'lsF01_2m'))
    for column, fit, fit_max, style_name in fits:
        ax.plot(contrasts, fit(contrasts, data[:, column], fit_max),
                **getattr(s, style_name))
    # measured amplitudes:
    for column, style_name in (('f1', 'lsF01'), ('f2', 'lsF02'),
                               ('f1+f2', 'lsF012'), ('f2-f1', 'lsF01_2')):
        ax.plot(contrasts, amplitude(data[:, column]),
                **getattr(s, style_name))
    # mark the contrasts of the example columns:
    for alpha, tag in zip(alphas, 'ABCD'):
        x = 100*alpha
        ax.plot(x, 630, 'vk', ms=4, clip_on=False)
        ax.text(x, 660, tag, ha='center')
    # boundaries between response regimes (contrast in percent):
    linear_end, weak_end = 1.5, 4
    ax.axvline(linear_end, **s.lsLine)
    ax.axvline(weak_end, **s.lsLine)
    yoffs = 340
    regimes = ((linear_end/2, 'linear\nregime'),
               ((linear_end + weak_end)/2, 'weakly\nnonlinear\nregime'),
               (10, 'strongly\nnonlinear\nregime'))
    for x, text in regimes:
        ax.text(x, yoffs, text, ha='center', va='center')
    ax.set_xlim(0, 16.5)
    ax.set_ylim(0, 600)
    ax.set_xticks_delta(5)
    ax.set_yticks_delta(300)
    ax.set_xlabel('Contrast', r'\%')
    ax.set_ylabel('Amplitude', 'Hz')
|
|
|
|
|
|
if __name__ == '__main__':
    # model cell and the two beat frequencies in Hz:
    parameters = load_models(data_path / 'punitmodels.csv')
    cell_name = '2013-01-08-aa-invivo-1'  # 132Hz, CV=0.16: perfect!
    # alternative cells:
    # cell_name = '2012-07-03-ak-invivo-1'  # 128Hz, CV=0.24
    # cell_name = '2018-05-08-ae-invivo-1'  # 142Hz, CV=0.48
    beatf1, beatf2 = 40, 132
    cell = cell_parameters(parameters, cell_name)

    s = plot_style()
    nfft = 2**18
    example_trials = 100   # trials per example column
    peak_trials = 1000     # trials per contrast of the peak analysis
    alpha_max = 0.2        # largest contrast of the peak analysis

    fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.8*s.plot_width),
                            height_ratios=[1, 1.5, 2, 1.5, 4])
    fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5,
                        wspace=0.3, hspace=0.3)
    # invisible fourth row, presumably leaving room for the legend:
    spacer = fig.merge(axs[3, :])
    spacer.set_visible(False)
    axa = fig.merge(axs[4, :])
    fig.show_spines('lb')

    # example columns: AM, raster, and spectrum for each contrast:
    alphas = [0.01, 0.03, 0.05, 0.16]
    for col, alpha in enumerate(alphas):
        plot_example(axs[0, col], axs[1, col], axs[2, col], s, cell,
                     alpha, beatf1, beatf2, nfft, example_trials)
    axs[1, 0].xscalebar(1, -0.1, 30, 'ms', ha='right')
    axs[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8),
                     ncol=5, columnspacing=2)

    # peak amplitudes as a function of contrast:
    data = compute_peaks(cell_name, cell, alpha_max, beatf1, beatf2, nfft,
                         peak_trials)
    plot_peaks(axa, s, data, alphas)

    fig.common_yspines(axs[0, :])
    fig.common_yticks(axs[2, :])
    fig.tag(axs[0, :], xoffs=-2, yoffs=1.6)
    fig.tag(axa)
    fig.savefig()
|