updated regimes fiugre

This commit is contained in:
2025-05-19 10:09:01 +02:00
parent aeb840a4e7
commit 2e3d373100
6 changed files with 121 additions and 143 deletions

View File

@@ -11,6 +11,21 @@ data_path = Path('data')
sims_path = data_path / 'simulations'
def load_data(file_path):
data = np.load(file_path)
ratebase = float(data['ratebase'])
cvbase = float(data['cvbase'])
beatf1 = float(data['beatf1'])
beatf2 = float(data['beatf2'])
contrasts = data['contrasts']
powerf1 = data['powerf1']
powerf2 = data['powerf2']
powerfsum = data['powerfsum']
powerfdiff = data['powerfdiff']
return (ratebase, cvbase, beatf1, beatf2,
contrasts, powerf1, powerf2, powerfsum, powerfdiff)
def load_models(file):
""" Load model parameter from csv file.
@@ -25,7 +40,7 @@ def load_models(file):
For each cell a dictionary with model parameters.
"""
parameters = []
with file.open('r') as file:
with open(file, 'r') as file:
header_line = file.readline()
header_parts = header_line.strip().split(",")
keys = header_parts
@@ -118,7 +133,8 @@ def plot_am(ax, s, alpha, beatf1, beatf2, tmax):
ax.show_spines('l')
ax.plot(1000*time, -100*am, **s.lsAM)
ax.set_xlim(0, 1000*tmax)
ax.set_ylim(-50, 50)
ax.set_ylim(-13, 13)
ax.set_yticks_delta(10)
#ax.set_xlabel('Time', 'ms')
ax.set_ylabel('AM', r'\%')
ax.text(1, 1.2, f'Contrast = {100*alpha:g}\\,\\%',
@@ -134,59 +150,64 @@ def plot_raster(ax, s, spikes, tmax):
#ax.set_ylabel('Trials')
def compute_power(spikes, nfft, dt):
psds = []
time = np.arange(nfft + 1)*dt
tmax = nfft*dt
rates = []
cvs = []
for s in spikes:
rates.append(len(s)/tmax)
isis = np.diff(s)
cvs.append(np.std(isis)/np.mean(isis))
b, _ = np.histogram(s, time)
fourier = np.fft.rfft(b - np.mean(b))
psds.append(np.abs(fourier)**2)
#psds.append(fourier)
freqs = np.fft.rfftfreq(nfft, dt)
#print('mean rate', np.mean(rates))
#print('CV', np.mean(cvs))
return freqs, np.mean(psds, 0)
#return freqs, np.abs(np.mean(psds, 0))**2/dt
def compute_power(path, contrast, spikes, nfft, dt):
if not path.exists():
print(f' Compute power spectrum for contrast = {100*contrast:4.1f}%')
psds = []
time = np.arange(nfft + 1)*dt
tmax = nfft*dt
for s in spikes:
b, _ = np.histogram(s, time)
b = b / dt
fourier = np.fft.rfft(b - np.mean(b))
psds.append(np.abs(fourier)**2)
freqs = np.fft.rfftfreq(nfft, dt)
prr = np.mean(psds, 0)*dt/nfft
np.savez(path, nfft=nfft, deltat=dt, nsegs=len(spikes),
freqs=freqs, prr=prr)
else:
print(f' Load power spectrum for contrast = {100*contrast:4.1f}%')
data = np.load(path)
freqs = data['freqs']
prr = data['prr']
return freqs, prr
def decibel(x):
return 10*np.log10(x/1e8)
def plot_psd(ax, s, spikes, nfft, dt, beatf1, beatf2):
offs = 3
freqs, psd = compute_power(spikes, nfft, dt)
def plot_psd(ax, s, path, contrast, spikes, nfft, dt, beatf1, beatf2):
offs = 4
freqs, psd = compute_power(path, contrast, spikes, nfft, dt)
psd /= freqs[1]
ax.plot(freqs, decibel(psd), **s.lsPower)
ax.plot(beatf2, decibel(peak_ampl(freqs, psd, beatf2)) + offs,
label=r'$r$', clip_on=False, **s.psF0)
ax.plot(beatf1, decibel(peak_ampl(freqs, psd, beatf1)) + offs,
label=r'$\Delta f_1$', clip_on=False, **s.psF01)
ax.plot(beatf2, decibel(peak_ampl(freqs, psd, beatf2)) + offs + 4.5,
ax.plot(beatf2, decibel(peak_ampl(freqs, psd, beatf2)) + offs + 5.5,
label=r'$\Delta f_2$', clip_on=False, **s.psF02)
ax.plot(beatf2 - beatf1, decibel(peak_ampl(freqs, psd, beatf2 - beatf1)) + offs,
label=r'$\Delta f_2 - \Delta f_1$', clip_on=False, **s.psF01_2)
ax.plot(beatf1 + beatf2, decibel(peak_ampl(freqs, psd, beatf1 + beatf2)) + offs,
label=r'$\Delta f_1 + \Delta f_2$', clip_on=False, **s.psF012)
ax.set_xlim(0, 300)
ax.set_ylim(-40, 0)
ax.set_ylim(-60, 0)
ax.set_xlabel('Frequency', 'Hz')
ax.set_ylabel('Power [dB]')
def plot_example(axs, axr, axp, s, cell, alpha, beatf1, beatf2, nfft, trials):
def plot_example(axs, axr, axp, s, path, cell, alpha, beatf1, beatf2,
nfft, trials):
sim_path = path / f'{cell_name}-contrastspectrum-{1000*alpha:03.0f}.npz'
dt = 0.0001
tmax = nfft*dt
t1 = 0.1
spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials)
plot_am(axs, s, alpha, beatf1, beatf2, t1)
plot_raster(axr, s, spikes, t1)
plot_psd(axp, s, spikes, nfft, dt, beatf1, beatf2)
plot_psd(axp, s, sim_path, alpha, spikes, nfft, dt, beatf1, beatf2)
def peak_ampl(freqs, psd, f):
@@ -195,39 +216,6 @@ def peak_ampl(freqs, psd, f):
return np.max(psd_snippet)
def compute_peaks(name, cell, alpha_max, beatf1, beatf2, nfft, trials):
data_file = sims_path / f'{name}-contrastpeaks.csv'
data = TableData(data_file)
return data
"""
if data_file.exists():
data = TableData(data_file)
return data
dt = 0.0001
tmax = nfft*dt
alphas = np.linspace(0, alpha_max, 200)
ampl_f1 = np.zeros(len(alphas))
ampl_f2 = np.zeros(len(alphas))
ampl_sum = np.zeros(len(alphas))
ampl_diff = np.zeros(len(alphas))
for k, alpha in enumerate(alphas):
print(alpha)
spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials)
freqs, psd = compute_power(spikes, nfft, dt)
ampl_f1[k] = peak_ampl(freqs, psd, beatf1)
ampl_f2[k] = peak_ampl(freqs, psd, beatf2)
ampl_sum[k] = peak_ampl(freqs, psd, beatf1 + beatf2)
ampl_diff[k] = peak_ampl(freqs, psd, beatf2 - beatf1)
data = TableData()
data.append('contrast', '%', '%.1f', 100*alphas)
data.append('f1', 'Hz', '%g', ampl_f1)
data.append('f2', 'Hz', '%g', ampl_f2)
data.append('f1+f2', 'Hz', '%g', ampl_sum)
data.append('f2-f1', 'Hz', '%g', ampl_diff)
data.write(data_file)
return data
"""
def amplitude(power):
power -= power[0]
power[power<0] = 0
@@ -254,73 +242,83 @@ def amplitude_squarefit(contrast, power, max_contrast):
return (r.intercept + r.slope*contrast)**2
def plot_peaks(ax, s, data, alphas):
contrast = data[:, 'contrast']
ax.plot(contrast, amplitude_linearfit(contrast, data[:, 'f1'], 4), **s.lsF01m)
ax.plot(contrast, amplitude_linearfit(contrast, data[:, 'f2'], 2), **s.lsF02m)
ax.plot(contrast, amplitude_squarefit(contrast, data[:, 'f1+f2'], 4), **s.lsF012m)
ax.plot(contrast, amplitude_squarefit(contrast, data[:, 'f2-f1'], 4), **s.lsF01_2m)
ax.plot(contrast, amplitude(data[:, 'f1']), **s.lsF01)
ax.plot(contrast, amplitude(data[:, 'f2']), **s.lsF02)
ax.plot(contrast, amplitude(data[:, 'f1+f2']), **s.lsF012)
ax.plot(contrast, amplitude(data[:, 'f2-f1']), **s.lsF01_2)
def plot_peaks(ax, s, alphas, contrasts, powerf1, powerf2, powerfsum,
powerfdiff):
cmax = 10
contrasts *= 100
ax.plot(contrasts, amplitude_linearfit(contrasts, powerf1, 4),
**s.lsF01m)
ax.plot(contrasts, amplitude_linearfit(contrasts, powerf2, 2),
**s.lsF02m)
ax.plot(contrasts, amplitude_squarefit(contrasts, powerfsum, 4),
**s.lsF012m)
ax.plot(contrasts, amplitude_squarefit(contrasts, powerfdiff, 4),
**s.lsF01_2m)
ax.plot(contrasts, amplitude(powerf1), **s.lsF01)
ax.plot(contrasts, amplitude(powerf2), **s.lsF02)
mask = contrasts < cmax
ax.plot(contrasts[mask], amplitude(powerfsum)[mask],
clip_on=False, **s.lsF012)
ax.plot(contrasts[mask], amplitude(powerfdiff)[mask],
clip_on=False, **s.lsF01_2)
ymax = 60
for alpha, tag in zip(alphas, ['A', 'B', 'C', 'D']):
contrast = 100*alpha
ax.plot(contrast, 630, 'vk', ms=4, clip_on=False)
ax.text(contrast, 660, tag, ha='center')
ax.plot(100*alpha, ymax*0.95, 'vk', ms=4, clip_on=False)
ax.text(100*alpha, ymax, tag, ha='center')
#ax.axvline(contrast, **s.lsGrid)
#ax.text(contrast, 630, tag, ha='center')
ax.axvline(1.5, **s.lsLine)
ax.axvline(4, **s.lsLine)
yoffs = 340
ax.text(1.5/2, yoffs, 'linear\nregime',
ax.axvline(1.2, **s.lsLine)
ax.axvline(3.5, **s.lsLine)
yoffs = 35
ax.text(1.2/2, yoffs, 'linear\nregime',
ha='center', va='center')
ax.text((1.5 + 4)/2, yoffs, 'weakly\nnonlinear\nregime',
ax.text((1.2 + 3.5)/2, yoffs, 'weakly\nnonlinear\nregime',
ha='center', va='center')
ax.text(10, yoffs, 'strongly\nnonlinear\nregime',
ax.text(5.5, yoffs, 'strongly\nnonlinear\nregime',
ha='center', va='center')
ax.set_xlim(0, 16.5)
ax.set_ylim(0, 600)
ax.set_xticks_delta(5)
ax.set_yticks_delta(300)
ax.set_xlim(0, cmax)
ax.set_ylim(0, ymax)
ax.set_xticks_delta(2)
ax.set_yticks_delta(20)
ax.set_xlabel('Contrast', r'\%')
ax.set_ylabel('Amplitude', 'Hz')
if __name__ == '__main__':
cell_name = '2018-05-08-ad-invivo-1' # 228Hz, CV=0.67
ratebase, cvbase, beatf1, beatf2, \
contrasts, powerf1, powerf2, powerfsum, powerfdiff = \
load_data(sims_path / f'{cell_name}-contrastpeaks.npz')
alphas = [0.002, 0.01, 0.03, 0.06]
parameters = load_models(data_path / 'punitmodels.csv')
cell_name = '2013-01-08-aa-invivo-1' # 132Hz, CV=0.16: perfect!
beatf1 = 40
beatf2 = 132
# cell_name = '2012-07-03-ak-invivo-1' # 128Hz, CV=0.24
# cell_name = '2018-05-08-ae-invivo-1' # 142Hz, CV=0.48
cell = cell_parameters(parameters, cell_name)
nfft = 2**18
print(f'Loaded data for cell {cell_name}: '
f'baseline rate = {ratebase:.0f}Hz, CV = {cvbase:.2f}')
s = plot_style()
nfft = 2**18
fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.8*s.plot_width),
height_ratios=[1, 1.5, 2, 1.5, 4])
fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5,
wspace=0.3, hspace=0.3)
ax0 = fig.merge(axs[3, :])
ax0.set_visible(False)
axa = fig.merge(axs[4, :])
fig, (axes, axa) = plt.subplots(2, 1, height_ratios=[4, 3],
cmsize=(s.plot_width, 0.6*s.plot_width))
fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5, hspace=0.6)
axe = axes.subplots(3, 4, wspace=0.4, hspace=0.2,
height_ratios=[1, 2, 3])
fig.show_spines('lb')
alphas = [0.01, 0.03, 0.05, 0.16]
#alphas = [0.002, 0.01, 0.05, 0.1]
# example power spectra:
for c, alpha in enumerate(alphas):
plot_example(axs[0, c], axs[1, c], axs[2, c], s, cell,
alpha, beatf1, beatf2, nfft, 100)
axs[1, 0].xscalebar(1, -0.1, 30, 'ms', ha='right')
axs[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8),
plot_example(axe[0, c], axe[1, c], axe[2, c], s, sims_path,
cell, alpha, beatf1, beatf2, nfft, 100)
axe[1, 0].xscalebar(1, -0.1, 20, 'ms', ha='right')
axe[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8),
ncol=5, columnspacing=2)
data = compute_peaks(cell_name, cell, 0.2, beatf1, beatf2, nfft, 1000)
plot_peaks(axa, s, data, alphas)
fig.common_yspines(axs[0, :])
fig.common_yticks(axs[2, :])
#fig.common_xlabels(axs[2, :])
fig.tag(axs[0, :], xoffs=-2, yoffs=1.6)
fig.tag(axa)
fig.common_yspines(axe[0, :])
fig.common_yticks(axe[2, :])
fig.tag(axe[0, :], xoffs=-3, yoffs=1.6)
# contrast dependence:
plot_peaks(axa, s, alphas, contrasts, powerf1, powerf2,
powerfsum, powerfdiff)
fig.tag(axa, yoffs=2)
fig.savefig()