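"""Plot the responses of one recorded cell to one and two additional stimulus frequencies.

Loads spike times (``threefish-spikes.dat``) and stimulus amplitude
modulations (``threefish-ams.dat``) from ``data/<cell>``. For the baseline
window, the two single-frequency windows and the combined window, it shows
the stimulus, a spike raster, the firing rate and the power spectrum of the
spike trains.
"""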
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
from pathlib import Path
|
|
from scipy.stats import norm
|
|
from scipy.optimize import curve_fit
|
|
from spectral import rate
|
|
from plotstyle import plot_style
|
|
|
|
|
|
# Recording (a sub-directory of `data_path`) that this script analyses.
cell: str = '2021-08-03-ac-invivo-1'
# Directory that holds one sub-directory per recorded cell.
data_path: Path = Path('data')
# Kernel width handed to `rate()` for the firing-rate panels, in seconds like
# the spike times -- presumably the standard deviation of a Gaussian kernel.
sigma: float = 0.002
|
|
|
|
|
|
def _header_value(line, unit):
    """Return the number after the first ':' of a header `line`, with `unit` removed."""
    return float(line.split(':')[1].strip().replace(unit, ''))


def load_spikes(cell_path, f1=797, f2=631):
    """Load the spike times of the data set recorded at frequencies `f1` and `f2`.

    Parses ``cell_path / 'threefish-spikes.dat'``. Before a data set is
    selected, ``# index `` lines are counted and the ``# EOD rate``,
    ``# Deltaf1`` and ``# Deltaf2`` header values are tracked. The data set is
    selected at its ``# Deltaf2`` line when EOD rate + Deltaf1 and
    EOD rate + Deltaf2 are each within 1 Hz of `f1` and `f2`. Its durations
    and trials are then read until the next ``# index `` line or the end of
    the file. Reaching the end of the file is handled too, so the last data
    set in the file can be selected.

    Parameters
    ----------
    cell_path: pathlib.Path
        Directory containing ``threefish-spikes.dat``.
    f1, f2: float
        Target absolute stimulus frequencies (EODf + Deltaf) in Hz.

    Returns
    -------
    tuple or None
        ``(spikes, eodf, df1, df2, t0, t1, t2, t12, index)``:

        - ``spikes``: one array per ``# trial:`` block. The file values are
          scaled by 0.001, presumably converting ms to s.
        - ``eodf``, ``df1``, ``df2``: header values in Hz.
        - ``t0``, ``t1``, ``t2``, ``t12``: the ``before:``, ``duration1``,
          ``duration2`` and ``duration12`` header values, converted from ms
          to s.
        - ``index``: the number of ``# index `` lines read before the
          selection. This is the index that `load_am()` expects.

        If no data set matches, a message is printed and None is returned.
    """
    load = False
    spikes = []
    index = 0
    with open(cell_path / 'threefish-spikes.dat') as sf:
        for line in sf:
            if not load:
                # Scan headers until a data set with the requested frequencies shows up.
                if line.startswith('# index '):
                    index += 1
                elif line.startswith('# EOD rate '):
                    eodf = _header_value(line, 'Hz')
                elif line.startswith('# Deltaf1 '):
                    df1 = _header_value(line, 'Hz')
                elif line.startswith('# Deltaf2 '):
                    df2 = _header_value(line, 'Hz')
                    if abs(eodf + df1 - f1) < 1 and abs(eodf + df2 - f2) < 1:
                        load = True
                continue
            if ' before:' in line:
                t0 = 0.001*_header_value(line, 'ms')
            elif ' duration1 ' in line:
                t1 = 0.001*_header_value(line, 'ms')
            elif ' duration2 ' in line:
                t2 = 0.001*_header_value(line, 'ms')
            elif ' duration12 ' in line:
                t12 = 0.001*_header_value(line, 'ms')
            elif line.startswith('# index '):
                # The next data set begins: the selected one is complete.
                break
            elif line.startswith('# trial:'):
                if len(spikes) > 0:
                    spikes[-1] = np.array(spikes[-1])
                spikes.append([])
            elif len(line.strip()) > 0 and line[0] != '#':
                spikes[-1].append(0.001*float(line.strip()))
    if not load:
        print(f'no spikes found for EODf1={f1:.1f}Hz and EODf2={f2:.1f}Hz')
        return None
    # Also reached at end of file, when the selected data set was the last one.
    if len(spikes) > 0:
        spikes[-1] = np.array(spikes[-1])
    return spikes, eodf, df1, df2, t0, t1, t2, t12, index
|
|
|
|
|
|
def load_am(cell_path, inx):
    """Load the stimulus amplitude modulations (AMs) of data set number `inx`.

    Parses ``cell_path / 'threefish-ams.dat'``. Data sets are counted by
    their ``# index `` lines, starting at 1. The data set whose count equals
    `inx` is read until the next ``# index `` line or the end of the file.
    Reaching the end of the file is handled too, so the last data set can be
    loaded. Its ``# EOD rate``, ``# Deltaf1`` and ``# Deltaf2`` header values
    are printed.

    Parameters
    ----------
    cell_path: pathlib.Path
        Directory containing ``threefish-ams.dat``.
    inx: int
        Data-set count as returned by `load_spikes()`.

    Returns
    -------
    list of ndarray or None
        One ``(n, 2)`` array per ``# trial:`` block. Column 0 holds the time
        (file value * 0.001, presumably ms -> s) and column 1 the AM value.
        If no data set has count `inx`, a message is printed and None is
        returned.
    """
    load = False
    ams = []
    index = 0
    with open(cell_path / 'threefish-ams.dat') as sf:
        for line in sf:
            if not load:
                if line.startswith('# index '):
                    index += 1
                    if inx == index:
                        load = True
                continue
            if line.startswith('# index '):
                # The next data set begins: the selected one is complete.
                break
            # Double quotes inside the replacement fields keep these f-strings
            # valid before Python 3.12.
            elif line.startswith('# EOD rate '):
                print(f' EODf = {line.split(":")[1].strip()}')
            elif line.startswith('# Deltaf1 '):
                print(f' Df1 = {line.split(":")[1].strip()}')
            elif line.startswith('# Deltaf2 '):
                print(f' DF2 = {line.split(":")[1].strip()}')
            elif line.startswith('# trial:'):
                if len(ams) > 0:
                    ams[-1] = np.array(ams[-1])
                ams.append([])
            elif len(line.strip()) > 0 and line[0] != '#':
                time, am = line.split()
                ams[-1].append((0.001*float(time.strip()), float(am.strip())))
    if not load:
        print(f'no AM found at index {inx}')
        return None
    # Also reached at end of file, when the selected data set was the last one.
    if len(ams) > 0:
        ams[-1] = np.array(ams[-1])
    return ams
|
|
|
|
|
|
def cosine(x, a, f, p, c):
    """Return ``a*cos(2*pi*f*x + p) + c``."""
    return a*np.cos(2*np.pi*f*x + p) + c


def two_cosine(x, a1, f1, p1, a2, f2, p2, c):
    """Return the sum of two cosines ``a_i*cos(2*pi*f_i*x + p_i)`` plus offset `c`."""
    return a1*np.cos(2*np.pi*f1*x + p1) + a2*np.cos(2*np.pi*f2*x + p2) + c


def _positive_phase(amplitude, phase):
    """Phase of a fitted cosine, shifted by pi so that it belongs to a positive amplitude."""
    return phase if amplitude > 0 else phase + np.pi


def am_phases(ams, eodf, df1, df2, t1, t2, t12):
    """Fit the phase of each stimulus AM component in each stimulus window.

    Each trial is split into three consecutive windows starting at time 0:
    [0, t1] containing `df1`, [t1, t1 + t2] containing `df2`, and
    [t1 + t2, t1 + t2 + t12] containing both. In each window a cosine
    (`cosine`, or `two_cosine` for the last window) is fitted to the AM, with
    time measured from the window start. Initial guesses are the known
    frequencies, zero phase, half the peak-to-peak range as amplitude and the
    mean as offset.

    Parameters
    ----------
    ams: sequence of ndarray
        One ``(n, 2)`` array per trial: column 0 time, column 1 AM, as
        returned by `load_am()`.
    eodf: float
        Not used; kept for interface compatibility.
    df1, df2: float
        Frequencies of the two AM components in Hz.
    t1, t2, t12: float
        Durations of the three windows, in the time unit of ``ams``.

    Returns
    -------
    ndarray of shape (len(ams), 4)
        Per trial: the phase of `df1` in window 1, the phase of `df2` in
        window 2, and the phases of `df1` and `df2` in window 3. Each phase
        refers to a positive amplitude and is not wrapped.
    """
    twins = (t1, t2, t12)
    dfs = ((df1,), (df2,), (df1, df2))
    phases = np.zeros((len(ams), len(dfs) + 1))
    for k, trial in enumerate(ams):
        time = trial[:, 0]
        am = trial[:, 1]
        start = 0
        for i, (tw, freqs) in enumerate(zip(twins, dfs)):
            # Each window has its own duration (this was always twins[0] before).
            end = start + tw
            mask = (time >= start) & (time <= end)
            tam = time[mask] - start
            aam = am[mask]
            a = 0.5*(np.max(aam) - np.min(aam))
            c = np.mean(aam)
            if len(freqs) == 2:
                p0 = [a/2, freqs[0], 0, a/2, freqs[1], 0, c]
                popt, _ = curve_fit(two_cosine, tam, aam, p0)
                phases[k, i] = _positive_phase(popt[0], popt[2])
                phases[k, i + 1] = _positive_phase(popt[3], popt[5])
            else:
                p0 = [a, freqs[0], 0, c]
                popt, _ = curve_fit(cosine, tam, aam, p0)
                phases[k, i] = _positive_phase(popt[0], popt[2])
            start = end
    return phases
|
|
|
|
|
|
def align_spikes(spikes, freqs, phases):
    """Shift each trial's spike times by the fitted AM phase of that trial.

    If the AM of a trial is ``cos(2*pi*df*t + p)``, shifting the times by
    ``p/(2*pi*df)`` makes it ``cos(2*pi*df*t)`` in the new time axis. All
    trials are thereby aligned to the same AM phase.

    Parameters
    ----------
    spikes: sequence of array-like
        Spike times, one entry per trial.
    freqs: pair (f1, f2)
        Stimulus frequencies; None marks an absent frequency. If both are
        given, only `f2` is used.
    phases: pair
        Per-trial phases belonging to `f1` (``phases[0]``) and `f2`
        (``phases[1]``).

    Returns
    -------
    list of ndarray or `spikes`
        New shifted arrays; the input arrays are left unmodified. If both
        frequencies are None, `spikes` itself is returned unchanged.
    """
    f1, f2 = freqs
    if f1 is None and f2 is None:
        return spikes
    if f2 is None:
        df, trial_phases = f1, phases[0]
    else:
        df, trial_phases = f2, phases[1]
    return [np.asarray(times) + trial_phases[i]/2/np.pi/df
            for i, times in enumerate(spikes)]
|
|
|
|
|
|
def baseline_rate(spikes, t0, t1):
    """Return the trial-averaged firing rate within the open interval (t0, t1).

    For every trial the number of spikes strictly between `t0` and `t1` is
    divided by ``t1 - t0``; the mean over trials is returned.
    """
    duration = t1 - t0
    per_trial = [np.sum((times > t0) & (times < t1))/duration for times in spikes]
    return np.mean(per_trial)
|
|
|
|
|
|
def power_spectrum(spikes, tmax, dt=0.0005, nfft=512, p_ref=4000):
    """Return the power spectrum of binned spike trains in dB.

    Each trial is binned with the edges ``np.arange(0, tmax, dt)`` and
    converted to a rate (counts / dt). The trial's mean rate is subtracted.
    The result is cut into non-overlapping, unwindowed segments of `nfft`
    bins; only complete segments are used. Squared FFT magnitudes are
    averaged over all segments of all trials.

    Parameters
    ----------
    spikes: iterable of array-like
        Spike times of each trial, in the unit of `tmax` and `dt`. Spikes
        outside the bin edges are ignored.
    tmax: float
        Length of the analysis window.
    dt: float
        Bin width.
    nfft: int
        Number of bins per FFT segment.
    p_ref: float
        Reference power for the dB conversion.

    Returns
    -------
    freqs: ndarray
        Non-negative frequencies (the Nyquist frequency is not included).
    power: ndarray
        ``10*log10(P/p_ref)`` with ``P = mean(|FFT|^2)*dt/nfft``.

    Raises
    ------
    ValueError
        If not a single complete segment of `nfft` bins fits into the
        window, or if `spikes` contains no trials.
    """
    time = np.arange(0, tmax, dt)
    # `time` holds bin edges, i.e. len(time) - 1 bins; only complete
    # segments of nfft bins fit in this range.
    segments = range(0, len(time) - nfft, nfft)
    if len(segments) == 0:
        raise ValueError(f'nfft={nfft} too large for {max(len(time) - 1, 0)} bins '
                         f'of width {dt} in a window of length {tmax}')
    freqs = np.fft.fftshift(np.fft.fftfreq(nfft, dt))
    p_rr = np.zeros(len(freqs))
    n = 0
    for spiket in spikes:
        b, _ = np.histogram(spiket, time)
        b = b / dt
        b_mean = np.mean(b)    # mean over the whole trial, not per segment
        for k in segments:
            fourier_r = np.fft.fftshift(np.fft.fft(b[k:k + nfft] - b_mean, n=nfft))
            p_rr += np.abs(fourier_r*np.conj(fourier_r))
            n += 1
    if n == 0:
        raise ValueError('no trials given in spikes')
    mask = freqs >= 0.0
    scale = dt/nfft/n
    power = 10*np.log10(p_rr[mask]*scale/p_ref)
    return freqs[mask], power
|
|
|
|
|
|
def plot_symbols(ax, s, freqs):
    """Draw the stimulus -> box -> response schematic heading one column.

    A black rectangle (presumably standing for the cell) is drawn with an
    arrow to its right and a response label. When a stimulus frequency is
    given, an input arrow and the stimulus label are added on the left.

    ax: axes providing the ``show_spines()`` and ``harrow()`` extensions.
    s: plot style providing ``colors['black']`` and ``asLine``.
    freqs: pair (f1, f2); None marks an absent stimulus component.
    """
    f1, f2 = freqs
    ax.show_spines('')
    ax.add_artist(plt.Rectangle((-1, -0.5), 2, 1, color=s.colors['black']))
    ax.harrow(1.6, 0, 1.3, **s.asLine)
    ax.set_xlim(-6, 14)
    ax.set_ylim(-1, 1)
    if f1 is None and f2 is None:
        # baseline: no stimulus, only the spontaneous rate
        ax.text(3.5, 0, '$r$', va='center')
        return
    ax.harrow(-2.8, 0, 1.3, **s.asLine)
    if f2 is None:
        stim_text, resp_text = '$s_1(t)$', '$r + r_1(t)$'
    elif f1 is None:
        stim_text, resp_text = '$s_2(t)$', '$r + r_2(t)$'
    else:
        stim_text, resp_text = '$s_1(t) + s_2(t)$', '$\\ne r + r_1(t) + r_2(t)$'
    ax.text(-3.2, 0, stim_text, ha='right', va='center')
    ax.text(3.3, 0, resp_text, va='center')
|
|
|
|
|
|
def plot_stimulus(ax, s, tmax, eodf, freqs, c=0.1):
    """Plot the EOD with up to two added sine waves and the resulting AM.

    The EOD is ``cos(2*pi*eodf*t)`` on ``[0, tmax)`` with 0.1 ms steps. For
    each given difference frequency df in `freqs`,
    ``c*cos(2*pi*(eodf + df)*t + pi)`` is added to the EOD and
    ``c*cos(2*pi*df*t + pi)`` to the AM trace, which starts at 1. The AM is
    drawn only if at least one frequency is given. Time is plotted in ms.

    ax: axes providing ``show_spines()``.
    s: plot style providing ``lsEOD``, ``lsF01``, ``lsF02`` and ``lsF01_2``.
    freqs: pair (f1, f2); None marks an absent component.
    c: amplitude of each added sine wave relative to the EOD.
    """
    f1, f2 = freqs
    dp = np.pi    # phase of the added sine waves
    t = np.arange(0, tmax, 0.0001)
    eod = np.cos(2*np.pi*eodf*t)
    envelope = np.ones(len(t))
    title = '$f_{EOD}$'
    for df, tag in ((f1, r' \& $f_1$'), (f2, r' \& $f_2$')):
        if df is None:
            continue
        eod += c*np.cos(2*np.pi*(eodf + df)*t + dp)
        envelope += c*np.cos(2*np.pi*df*t + dp)
        title += tag
    if f1 is not None and f2 is not None:
        am_style = s.lsF01_2
    elif f2 is not None:
        am_style = s.lsF01
    elif f1 is not None:
        am_style = s.lsF02
    else:
        am_style = {}
    ax.show_spines('')
    ax.plot(1000*t, eod, **s.lsEOD)
    if len(am_style) > 0:
        ax.plot(1000*t, envelope, **am_style)
    ax.set_xlim(0, 1000*tmax)
    ax.set_ylim(-1.02 - 2*c, 1.02 + 2*c)
    ax.text(0.5, 1.2, title, ha='center', transform=ax.transAxes)
|
|
|
|
|
|
def plot_raster(ax, s, spikes, tmin, tmax):
    """Plot a spike raster of the open time window (tmin, tmax), one row per trial.

    Spike times strictly inside the window are shifted to start at zero and
    scaled by 1000 (s -> ms).

    ax: axes providing ``show_spines()``.
    s: plot style providing ``lsRaster``.
    spikes: sequence of spike-time arrays.
    """
    # The loop variable is `times`, not `s`, so it does not shadow the style argument.
    spikes_ms = [1000*(times[(times > tmin) & (times < tmax)] - tmin)
                 for times in spikes]
    ax.show_spines('')
    ax.eventplot(spikes_ms, linelengths=0.8, **s.lsRaster)
    ax.set_xlim(0, 1000*(tmax - tmin))
|
|
|
|
|
|
def plot_rate(ax, s, spikes, tmin, tmax, sigma=0.002):
    """Plot the firing rate within [tmin, tmax], time shifted to start at 0 and in ms.

    The rate is computed by `rate()` (from the project module ``spectral``)
    on a grid from 0 to ``tmin + tmax`` with step ``sigma/4``. Only the first
    return value of `rate()` is plotted. The y-axis is fixed to 0-550 Hz.

    ax: axes providing ``show_spines()``, ``set_yticks_delta()`` and a
        two-argument ``set_ylabel()``.
    s: plot style providing ``lsRate``.
    """
    t = np.arange(0, tmin + tmax, sigma/4)
    r, _ = rate(t, spikes, sigma)
    window = (t >= tmin) & (t <= tmax)
    ax.show_spines('l')
    ax.plot(1000*(t[window] - tmin), r[window], clip_on=False, **s.lsRate)
    ax.set_xlim(0, 1000*(tmax - tmin))
    ax.set_ylim(0, 550)
    ax.set_ylabel('Rate', 'Hz')
    ax.set_yticks_delta(200)
|
|
|
|
|
|
def plot_psd(ax, s, freqs, power, fmax, dt=0.0005, nfft=512):
    """Plot a power spectrum (in dB) up to `fmax`.

    The y-axis is fixed to -20 to 0 dB. `dt` and `nfft` are accepted but not
    used.

    ax: axes providing ``show_spines()``, ``set_yticks_delta()`` and
        two-argument ``set_xlabel()``/``set_ylabel()``.
    s: plot style providing ``lsPower``.
    """
    below = freqs <= fmax
    ax.show_spines('lb')
    ax.plot(freqs[below], power[below], **s.lsPower)
    ax.set_xlim(0, fmax)
    ax.set_ylim(-20, 0)
    ax.set_xlabel('Frequency', 'Hz')
    ax.set_ylabel('Power', 'dB')
    ax.set_yticks_delta(10)
|
|
|
|
|
|
def mark_freq(ax, freqs, power, f, label, style, xoffs=10, yoffs=0, toffs=0, angle=0):
    """Mark the spectrum at the frequency bin closest to ``abs(f)``.

    A marker is plotted ``1 + yoffs`` above the power of that bin. If `label`
    is non-empty, it is written ``xoffs`` to the left, ``yoffs + 3 + toffs``
    above the power, and 1 lower when `angle` is positive. The text uses
    ``style['color']`` and is rotated by `angle`.

    style: keyword arguments for ``ax.plot()``; must contain ``'color'`` when
        a label is given.
    """
    idx = np.argmin(np.abs(freqs - abs(f)))
    fpeak = freqs[idx]
    ppeak = power[idx]
    ax.plot(fpeak, ppeak + 1 + yoffs, clip_on=False, **style)
    if not label:
        return
    text_offs = yoffs + (3 + toffs)
    if angle > 0:
        text_offs -= 1
    ax.text(fpeak - xoffs, ppeak + text_offs, label, color=style['color'], rotation=angle)
|
|
|
|
|
|
if __name__ == '__main__':
    # Spike data of the data set matching the default frequencies of
    # load_spikes(). It returns None (after printing why) if nothing matches,
    # so stop here instead of failing on the tuple unpacking.
    data = load_spikes(data_path / cell)
    if data is None:
        raise SystemExit(1)
    spikes, eodf, df1, df2, t0, t1, t2, t12, index = data
    print(f'Loaded spike data for cell {cell} @ index {index}:')
    print(f' EODf = {eodf:.1f}Hz')
    print(f' Df1 = {df1:.1f}Hz')
    print(f' Df2 = {df2:.1f}Hz')
    print(f' {len(spikes)} trials')

    # Stimulus AMs of the same data set, used to fit per-trial AM phases.
    print(f'Load AMs for cell {cell} @ index {index}:')
    ams = load_am(data_path / cell, index)
    if ams is None:
        raise SystemExit(1)

    phases = am_phases(ams, eodf, df1, df2, t1, t2, t12)

    s = plot_style()
    fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.6*s.plot_width),
                            height_ratios=[1, 0, 2, 1.3, 3, 0.7, 5])
    fig.subplots_adjust(leftm=7.5, rightm=4.5, topm=1.5, bottomm=4,
                        wspace=0.4, hspace=0.4)
    fmax = 250
    tmin = 0.106
    tmax = 0.206
    # One column per stimulus window:
    # - the `before` period at negative times (baseline),
    # - the window with df2 only,
    # - the window with df1 only,
    # - the window with both.
    # Note that the figure labels call df2 "Delta f_1" and df1 "Delta f_2".
    twins = [[-t0, 0], [t1, t1 + t2], [0, t1], [t1 + t2, t1 + t2 + t12]]
    stim_freqs = [[None, None], [df2, None], [None, df1], [df1, df2]]
    stim_phases = [[None, None], [phases[:, 1], None], [None, phases[:, 0]], [phases[:, 2], phases[:, 3]]]
    base_rate = baseline_rate(spikes, *twins[0])
    print(f'Baseline firing rate: {base_rate:.1f}Hz')
    powers = []
    for i in range(axs.shape[1]):
        tstart, tend = twins[i]
        plot_symbols(axs[0, i], s, stim_freqs[i])
        plot_stimulus(axs[1, i], s, tmax - tmin, eodf, stim_freqs[i])
        # spike times of this window, relative to the window start
        sub_spikes = [times[(times >= tstart) & (times <= tend)] - tstart for times in spikes]
        freqs, power = power_spectrum(sub_spikes, tend - tstart)
        powers.append(power)
        plot_psd(axs[4, i], s, freqs, power, fmax)
        # raster and rate are drawn after aligning the trials to the AM phase
        sub_spikes = align_spikes(sub_spikes, stim_freqs[i], stim_phases[i])
        plot_raster(axs[2, i], s, sub_spikes, tmin, tmax)
        plot_rate(axs[3, i], s, sub_spikes, tmin, tmax, sigma)
    # All spectra share the same frequency axis (same dt and nfft), so the
    # `freqs` of the last column serves for all of them.
    mark_freq(axs[4, 0], freqs, powers[0], base_rate, f'$r={base_rate:.0f}$\\,Hz', s.psF0, 30)
    mark_freq(axs[4, 1], freqs, powers[1], df2, f'$\\Delta f_1=f_1 - f_{{EOD}}={abs(df2):.0f}$\\,Hz', s.psF02)
    mark_freq(axs[4, 1], freqs, powers[1], 2*df2, f'$2\\Delta f_1={abs(2*df2):.0f}$\\,Hz', s.psF02)
    mark_freq(axs[4, 2], freqs, powers[2], df1, '', s.psF0)
    mark_freq(axs[4, 2], freqs, powers[2], df1, f'$\\Delta f_2=f_2 - f_{{EOD}}={abs(df1):.0f}$\\,Hz',
              s.psF01, 130, 1.5)
    mark_freq(axs[4, 3], freqs, powers[3], df2, '', s.psF02)
    mark_freq(axs[4, 3], freqs, powers[3], 2*df2, '', s.psF02)
    mark_freq(axs[4, 3], freqs, powers[3], df1, '', s.psF0)
    mark_freq(axs[4, 3], freqs, powers[3], df1, '', s.psF01, 130, 1.5)
    # NOTE(review): the -2 Hz nudges the marker to a neighbouring bin,
    # presumably onto the visible peak -- confirm against the data.
    mark_freq(axs[4, 3], freqs, powers[3], abs(df1) + abs(df2) - 2,
              f'$\\Delta f_2 + \\Delta f_1={abs(df1) + abs(df2):.0f}$\\,Hz', s.psF012, 20, angle=40)
    mark_freq(axs[4, 3], freqs, powers[3], abs(df1) - abs(df2),
              f'$\\Delta f_2 - \\Delta f_1={abs(df1) - abs(df2):.0f}$\\,Hz', s.psF01_2, 50, toffs=5, angle=40)
    fig.common_yticks(axs[3, :])
    fig.common_yticks(axs[4, :])
    axs[3, 0].xscalebar(1, 0, 20, 'ms', ha='right')
    #axs[3, 0].scalebars(-0.03, 0, 20, 500, 'ms', 'Hz')
    #axs[4, 0].yscalebar(-0.03, 0.5, 10, 'dB', va='center')
    #fig.tag(axs.T)
    fig.tag(axs[0])
    fig.savefig()
    #plt.show()
    print()
|