finished figure 3
This commit is contained in:
313
twobeats.py
313
twobeats.py
@@ -3,6 +3,7 @@ import matplotlib.pyplot as plt
|
||||
|
||||
from pathlib import Path
|
||||
from scipy.stats import norm
|
||||
from scipy.optimize import curve_fit
|
||||
from spectral import rate
|
||||
from plotstyle import plot_style
|
||||
|
||||
@@ -12,21 +13,13 @@ cell = '2021-08-03-ac-invivo-1'
|
||||
data_path = Path('data')
|
||||
|
||||
|
||||
def load_data(cell_path, f1=797, f2=631):
|
||||
def load_spikes(cell_path, f1=797, f2=631):
|
||||
load = False
|
||||
spikes = []
|
||||
index = 0
|
||||
with open(cell_path / 'threefish-spikes.dat') as sf:
|
||||
for line in sf:
|
||||
if line.startswith('# EOD rate '):
|
||||
eodf = float(line.split(':')[1].strip().replace('Hz', ''))
|
||||
elif line.startswith('# Deltaf1 '):
|
||||
df1 = float(line.split(':')[1].strip().replace('Hz', ''))
|
||||
elif line.startswith('# Deltaf2 '):
|
||||
df2 = float(line.split(':')[1].strip().replace('Hz', ''))
|
||||
if abs(eodf + df1 - f1) < 1 and abs(eodf + df2 - f2) < 1:
|
||||
#print(f'EODf={eodf:6.1f}Hz, Df1={df1:6.1f}Hz, Df2={df2:6.1f}Hz, EODf1={eodf + df1:6.1f}Hz, EODf2={eodf + df2:6.1f}Hz')
|
||||
load = True
|
||||
elif load:
|
||||
if load:
|
||||
if ' before:' in line:
|
||||
t0 = 0.001*float(line.split(':')[1].strip().replace('ms', ''))
|
||||
elif ' duration1 ' in line:
|
||||
@@ -38,7 +31,7 @@ def load_data(cell_path, f1=797, f2=631):
|
||||
elif line.startswith('# index '):
|
||||
if len(spikes) > 0:
|
||||
spikes[-1] = np.array(spikes[-1])
|
||||
return spikes, eodf, df1, df2, t0, t1, t2, t12
|
||||
return spikes, eodf, df1, df2, t0, t1, t2, t12, index
|
||||
elif line.startswith('# trial:'):
|
||||
if len(spikes) > 0:
|
||||
spikes[-1] = np.array(spikes[-1])
|
||||
@@ -46,71 +39,197 @@ def load_data(cell_path, f1=797, f2=631):
|
||||
elif len(line.strip()) > 0 and line[0] != '#':
|
||||
t = 0.001*float(line.strip())
|
||||
spikes[-1].append(t)
|
||||
elif line.startswith('# index '):
|
||||
index += 1
|
||||
elif line.startswith('# EOD rate '):
|
||||
eodf = float(line.split(':')[1].strip().replace('Hz', ''))
|
||||
elif line.startswith('# Deltaf1 '):
|
||||
df1 = float(line.split(':')[1].strip().replace('Hz', ''))
|
||||
elif line.startswith('# Deltaf2 '):
|
||||
df2 = float(line.split(':')[1].strip().replace('Hz', ''))
|
||||
if abs(eodf + df1 - f1) < 1 and abs(eodf + df2 - f2) < 1:
|
||||
#print(f'EODf={eodf:6.1f}Hz, Df1={df1:6.1f}Hz, Df2={df2:6.1f}Hz, EODf1={eodf + df1:6.1f}Hz, EODf2={eodf + df2:6.1f}Hz')
|
||||
load = True
|
||||
print(f'no spikes found for EODf1={f1:.1f}Hz and EODf2={f2:.1f}Hz')
|
||||
|
||||
|
||||
def align_spikes(spikes, period):
|
||||
# compute rates for each trial:
|
||||
tmax = np.max([s[-1] for s in spikes])
|
||||
time = np.arange(0, tmax, 0.0002)
|
||||
sigma = 0.001
|
||||
kernel = norm.pdf(time[time < 8*sigma], loc=4*sigma, scale=sigma)
|
||||
rates = []
|
||||
xtime = np.append(time, time[-1] + time[1] - time[0])
|
||||
for i, spiket in enumerate(spikes):
|
||||
b, _ = np.histogram(spiket, xtime)
|
||||
r = np.convolve(b, kernel, 'same')
|
||||
rates.append(r)
|
||||
# align them on the first trial:
|
||||
nrates = len(rates[0])
|
||||
for i in range(1, len(rates)):
|
||||
rs = []
|
||||
n = len(time[time <= period])
|
||||
if n < 2:
|
||||
n = 2
|
||||
for k in range(1, 1 + n):
|
||||
r = np.corrcoef(rates[0][:-k], rates[i][k:])[0, 1]
|
||||
rs.append(r)
|
||||
k = 1 + np.argmax(rs)
|
||||
dt = time[k]
|
||||
spikes[i] -= dt
|
||||
print(f' shift trial {i} by {1000*dt:.0f}ms')
|
||||
def load_am(cell_path, inx):
|
||||
load = False
|
||||
ams = []
|
||||
index = 0
|
||||
with open(cell_path / 'threefish-ams.dat') as sf:
|
||||
for line in sf:
|
||||
if load:
|
||||
if line.startswith('# index '):
|
||||
if len(ams) > 0:
|
||||
ams[-1] = np.array(ams[-1])
|
||||
return ams
|
||||
elif line.startswith('# EOD rate '):
|
||||
print(f' EODf = {line.split(':')[1].strip()}')
|
||||
elif line.startswith('# Deltaf1 '):
|
||||
print(f' Df1 = {line.split(':')[1].strip()}')
|
||||
elif line.startswith('# Deltaf2 '):
|
||||
print(f' DF2 = {line.split(':')[1].strip()}')
|
||||
elif line.startswith('# trial:'):
|
||||
if len(ams) > 0:
|
||||
ams[-1] = np.array(ams[-1])
|
||||
ams.append([])
|
||||
elif len(line.strip()) > 0 and line[0] != '#':
|
||||
time, am = line.split()
|
||||
t = 0.001*float(time.strip())
|
||||
a = float(am.strip())
|
||||
ams[-1].append((t, a))
|
||||
elif line.startswith('# index '):
|
||||
index += 1
|
||||
if inx == index:
|
||||
load = True
|
||||
print(f'no AM found at index {inx}')
|
||||
|
||||
|
||||
def cosine(x, a, f, p, c):
|
||||
return a*np.cos(2*np.pi*f*x + p) + c
|
||||
|
||||
|
||||
def two_cosine(x, a1, f1, p1, a2, f2, p2, c):
|
||||
return a1*np.cos(2*np.pi*f1*x + p1) + a2*np.cos(2*np.pi*f2*x + p2) + c
|
||||
|
||||
|
||||
def am_phases(ams, eodf, df1, df2, t1, t2, t12):
|
||||
twins = (t1, t2, t12)
|
||||
dfs = ((df1,), (df2,), (df1, df2))
|
||||
phases = np.zeros((len(ams), len(dfs) + 1))
|
||||
for k in range(len(ams)):
|
||||
t0 = 0
|
||||
time = ams[k][:, 0]
|
||||
am = ams[k][:, 1]
|
||||
for i in range(len(twins)):
|
||||
tw = twins[0]
|
||||
t1 = t0 + tw
|
||||
mask = (time >= t0) & (time <= t1)
|
||||
tam = time[mask] - t0
|
||||
aam = am[mask]
|
||||
a = 0.5*(np.max(aam) - np.min(aam))
|
||||
c = np.mean(aam)
|
||||
tt = np.linspace(0, tw, 1000)
|
||||
if len(dfs[i]) == 2:
|
||||
popt = [a/2, dfs[i][0], 0, a/2, dfs[i][1], 0, c]
|
||||
popt, _ = curve_fit(two_cosine, tam, aam, popt)
|
||||
aa = two_cosine(tt, *popt)
|
||||
phases[k, i] = popt[2] if popt[0] > 0 else popt[2] + np.pi
|
||||
phases[k, i + 1] = popt[5] if popt[3] > 0 else popt[5] + np.pi
|
||||
else:
|
||||
popt = [a, dfs[i][0], 0, c]
|
||||
popt, _ = curve_fit(cosine, tam, aam, popt)
|
||||
aa = cosine(tt, *popt)
|
||||
phases[k, i] = popt[2] if popt[0] > 0 else popt[2] + np.pi
|
||||
t0 = t1
|
||||
return phases
|
||||
|
||||
|
||||
def align_spikes(spikes, freqs, phases):
|
||||
f1, f2 = freqs
|
||||
if f1 is None and f2 is None:
|
||||
return spikes
|
||||
p1 = phases[0]
|
||||
p2 = phases[1]
|
||||
if f2 is None:
|
||||
df = f1
|
||||
p = p1
|
||||
else:
|
||||
df = f2
|
||||
p = p2
|
||||
for i in range(len(spikes)):
|
||||
spikes[i] += p[i]/2/np.pi/df
|
||||
return spikes
|
||||
|
||||
|
||||
def plot_symbols(ax, s):
|
||||
def baseline_rate(spikes, t0, t1):
|
||||
rates = []
|
||||
for times in spikes:
|
||||
c = np.sum((times > t0) & (times < t1))
|
||||
rates.append(c/(t1 - t0))
|
||||
return np.mean(rates)
|
||||
|
||||
|
||||
def power_spectrum(spikes, tmax, dt=0.0005, nfft=512, p_ref=4000):
|
||||
time = np.arange(0, tmax, dt)
|
||||
if nfft > len(time):
|
||||
print('nfft too large:', nfft, len(time))
|
||||
freqs = np.fft.fftfreq(nfft, dt)
|
||||
freqs = np.fft.fftshift(freqs)
|
||||
segments = range(0, len(time) - nfft, nfft)
|
||||
p_rr = np.zeros(len(freqs))
|
||||
n = 0
|
||||
for i, spiket in enumerate(spikes):
|
||||
b, _ = np.histogram(spiket, time)
|
||||
b = b / dt
|
||||
for j, k in enumerate(segments):
|
||||
fourier_r = np.fft.fft(b[k:k + nfft] - np.mean(b), n=nfft)
|
||||
fourier_r = np.fft.fftshift(fourier_r)
|
||||
p_rr += np.abs(fourier_r*np.conj(fourier_r))
|
||||
n += 1
|
||||
mask = freqs >= 0.0
|
||||
freqs = freqs[mask]
|
||||
scale = dt/nfft/n
|
||||
p_rr = p_rr[mask]*scale
|
||||
power = 10*np.log10(p_rr/p_ref)
|
||||
return freqs, power
|
||||
|
||||
|
||||
def plot_symbols(ax, s, freqs):
|
||||
f1, f2 = freqs
|
||||
ax.show_spines('')
|
||||
ax.add_artist(plt.Rectangle((-1, -0.5), 2, 1, color=s.colors['black']))
|
||||
ax.harrow(1.6, 0, 1.3, **s.asLine)
|
||||
ax.set_xlim(-6, 14)
|
||||
ax.set_ylim(-1, 1)
|
||||
if f1 is None and f2 is None:
|
||||
ax.text(3.5, 0, '$r$', va='center')
|
||||
else:
|
||||
ax.harrow(-2.8, 0, 1.3, **s.asLine)
|
||||
if f2 is None:
|
||||
ax.text(-3.2, 0, '$s_1(t)$', ha='right', va='center')
|
||||
ax.text(3.3, 0, '$r + r_1(t)$', va='center')
|
||||
elif f1 is None:
|
||||
ax.text(-3.2, 0, '$s_2(t)$', ha='right', va='center')
|
||||
ax.text(3.3, 0, '$r + r_2(t)$', va='center')
|
||||
else:
|
||||
ax.text(-3.2, 0, '$s_1(t) + s_2(t)$', ha='right', va='center')
|
||||
ax.text(3.3, 0, '$\\ne r + r_1(t) + r_2(t)$', va='center')
|
||||
|
||||
|
||||
def plot_stimulus(ax, s, tmax, eodf, f1, f2, c=0.1):
|
||||
def plot_stimulus(ax, s, tmax, eodf, freqs, c=0.1):
|
||||
time = np.arange(0, tmax, 0.0001)
|
||||
eod = np.cos(2*np.pi*eodf*time)
|
||||
am = np.ones(len(time))
|
||||
ams = {}
|
||||
f1, f2 = freqs
|
||||
label = '$f_{EOD}$'
|
||||
if f1 is not None:
|
||||
eod += c*np.cos(2*np.pi*(eodf + f1)*time)
|
||||
am += c*np.cos(2*np.pi*f1*time)
|
||||
ams = s.lsF01
|
||||
ams = s.lsF02
|
||||
label += r' \& $f_1$'
|
||||
if f2 is not None:
|
||||
eod += c*np.cos(2*np.pi*(eodf + f2)*time)
|
||||
am += c*np.cos(2*np.pi*f2*time)
|
||||
ams = s.lsF02
|
||||
ams = s.lsF01
|
||||
label += r' \& $f_2$'
|
||||
if f1 is not None and f2 is not None:
|
||||
ams = s.lsF012
|
||||
ams = s.lsF01_2
|
||||
ax.show_spines('')
|
||||
ax.plot(1000*time, am*eod, **s.lsStim)
|
||||
ax.plot(1000*time, eod, **s.lsEOD)
|
||||
if len(ams) > 0:
|
||||
ax.plot(1000*time, am, **ams)
|
||||
ax.set_xlim(0, 1000*tmax)
|
||||
ax.set_ylim(-1.02 - 2*c, 1.02 + 2*c)
|
||||
ax.text(0, 1.1, label, transform=ax.transAxes)
|
||||
ax.text(0.5, 1.2, label, ha='center', transform=ax.transAxes)
|
||||
|
||||
|
||||
def plot_raster(ax, s, spikes, tmin, tmax):
|
||||
spikes_ms = [1000*(s[(s > tmin) & (s < tmax)] - tmin) for s in spikes]
|
||||
ax.show_spines('')
|
||||
ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster)
|
||||
ax.eventplot(spikes_ms, linelengths=0.8, **s.lsRaster)
|
||||
ax.set_xlim(0, 1000*(tmax - tmin))
|
||||
|
||||
|
||||
@@ -123,71 +242,89 @@ def plot_rate(ax, s, spikes, tmin, tmax, sigma=0.002):
|
||||
ax.show_spines('')
|
||||
ax.plot(1000*time, r, **s.lsRate)
|
||||
ax.set_xlim(0, 1000*(tmax - tmin))
|
||||
ax.set_ylim(0, 500)
|
||||
ax.set_ylim(-10, 550)
|
||||
|
||||
|
||||
def plot_psd(ax, s, spikes, tmax, fmax, dt=0.0005, nfft=512):
|
||||
time = np.arange(0, tmax, dt)
|
||||
if nfft > len(time):
|
||||
print('nfft too large:', nfft, len(time))
|
||||
# power spectrum:
|
||||
freqs = np.fft.fftfreq(nfft, dt)
|
||||
freqs = np.fft.fftshift(freqs)
|
||||
f0 = len(freqs)//4
|
||||
f1 = 3*len(freqs)//4
|
||||
segments = range(0, len(time) - nfft, nfft)
|
||||
p_rr = np.zeros(len(freqs))
|
||||
n = 0
|
||||
for i, spiket in enumerate(spikes):
|
||||
b, _ = np.histogram(spiket, time)
|
||||
b = b / dt
|
||||
for j, k in enumerate(segments):
|
||||
fourier_r = np.fft.fft(b[k:k + nfft] - np.mean(b), n=nfft)
|
||||
fourier_r = np.fft.fftshift(fourier_r)
|
||||
p_rr += np.abs(fourier_r*np.conj(fourier_r))
|
||||
n += 1
|
||||
freqs = freqs[f0:f1]
|
||||
scale = dt/nfft/n
|
||||
p_rr = p_rr[f0:f1]*scale
|
||||
def plot_psd(ax, s, freqs, power, fmax, dt=0.0005, nfft=512):
|
||||
# plot:
|
||||
mask = (freqs > 0) & (freqs <= fmax)
|
||||
mask = freqs <= fmax
|
||||
freqs = freqs[mask]
|
||||
p_rr = p_rr[mask]
|
||||
#print(np.max(p_rr))
|
||||
p_ref = 4000
|
||||
ax.plot(freqs, 10*np.log10(p_rr/p_ref), **s.lsPower)
|
||||
power = power[mask]
|
||||
ax.show_spines('b')
|
||||
ax.plot(freqs, power, **s.lsPower)
|
||||
ax.set_xlim(0, fmax)
|
||||
ax.set_ylim(-20, 0)
|
||||
ax.set_xlabel('Frequency', 'Hz')
|
||||
|
||||
|
||||
def mark_freq(ax, freqs, power, f, label, style, xoffs=10, yoffs=0, toffs=0, angle=0):
|
||||
i = np.argmin(np.abs(freqs - abs(f)))
|
||||
p = power[i]
|
||||
f = freqs[i]
|
||||
ax.plot(f, p + 1 + yoffs, clip_on=False, **style)
|
||||
if label:
|
||||
yoffs += 3 + toffs
|
||||
if angle > 0:
|
||||
yoffs -= 1
|
||||
ax.text(f - xoffs, p + yoffs, label, color=style['color'], rotation=angle)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
spikes, eodf, df1, df2, t0, t1, t2, t12 = load_data(data_path / cell)
|
||||
print(f'Loaded spike data for cell {cell}: ')
|
||||
spikes, eodf, df1, df2, t0, t1, t2, t12, index = load_spikes(data_path / cell)
|
||||
print(f'Loaded spike data for cell {cell} @ index {index}:')
|
||||
print(f' EODf = {eodf:.1f}Hz')
|
||||
print(f' Df1 = {df1:.1f}Hz')
|
||||
print(f' Df2 = {df2:.1f}Hz')
|
||||
print(f' {len(spikes)} trials')
|
||||
|
||||
print(f'Load AMs for cell {cell} @ index {index}:')
|
||||
ams = load_am(data_path / cell, index)
|
||||
|
||||
phases = am_phases(ams, eodf, df1, df2, t1, t2, t12)
|
||||
|
||||
s = plot_style()
|
||||
fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.6*s.plot_width),
|
||||
height_ratios=[1, 2, 1, 3, 6])
|
||||
height_ratios=[1, 0, 2, 1.3, 3, 0.7, 5])
|
||||
fig.subplots_adjust(leftm=3, rightm=4.5, topm=1.5, bottomm=4, wspace=0.4, hspace=0.4)
|
||||
fmax = 250
|
||||
tmin = 0.1
|
||||
tmax = 0.2
|
||||
tmin = 0.106
|
||||
tmax = 0.206
|
||||
twins = [[-t0, 0], [t1, t1 + t2], [0, t1], [t1 + t2, t1 + t2 + t12]]
|
||||
freqs = [eodf, df2, df1, df2]
|
||||
stim_freqs = [[None, None], [df2, None], [None, df1], [df1, df2]]
|
||||
stim_phases = [[None, None], [phases[:, 1], None], [None, phases[:, 0]], [phases[:, 2], phases[:, 3]]]
|
||||
base_rate = baseline_rate(spikes, *twins[0])
|
||||
print(f'Baseline firing rate: {base_rate:.1f}Hz')
|
||||
powers = []
|
||||
for i in range(axs.shape[1]):
|
||||
tstart, tend = twins[i]
|
||||
plot_symbols(axs[0, i], s)
|
||||
plot_stimulus(axs[1, i], s, tmax - tmin, eodf, *stim_freqs[i])
|
||||
plot_symbols(axs[0, i], s, stim_freqs[i])
|
||||
plot_stimulus(axs[1, i], s, tmax - tmin, eodf, stim_freqs[i])
|
||||
sub_spikes = [times[(times >= tstart) & (times <= tend)] - tstart for times in spikes]
|
||||
plot_psd(axs[4, i], s, sub_spikes, tend - tstart, fmax)
|
||||
print(f'align spikes for frequency {freqs[i]:.0f}Hz:')
|
||||
sub_spikes = align_spikes(sub_spikes, abs(1/freqs[i]))
|
||||
freqs, power = power_spectrum(sub_spikes, tend - tstart)
|
||||
powers.append(power)
|
||||
plot_psd(axs[4, i], s, freqs, power, fmax)
|
||||
sub_spikes = align_spikes(sub_spikes, stim_freqs[i], stim_phases[i])
|
||||
plot_raster(axs[2, i], s, sub_spikes, tmin, tmax)
|
||||
plot_rate(axs[3, i], s, sub_spikes, tmin, tmax)
|
||||
#fig.savefig()
|
||||
plt.show()
|
||||
mark_freq(axs[4, 0], freqs, powers[0], base_rate, f'$r={base_rate:.0f}$\\,Hz', s.psF0, 30)
|
||||
mark_freq(axs[4, 1], freqs, powers[1], df2, f'$\\Delta f_1=f_1 - f_{{EOD}}={abs(df2):.0f}$\\,Hz', s.psF02)
|
||||
mark_freq(axs[4, 1], freqs, powers[1], 2*df2, f'$2\\Delta f_1={abs(2*df2):.0f}$\\,Hz', s.psF02)
|
||||
mark_freq(axs[4, 2], freqs, powers[2], df1, '', s.psF0)
|
||||
mark_freq(axs[4, 2], freqs, powers[2], df1, f'$\\Delta f_2=f_2 - f_{{EOD}}={abs(df1):.0f}$\\,Hz',
|
||||
s.psF01, 130, 1.5)
|
||||
mark_freq(axs[4, 3], freqs, powers[3], df2, '', s.psF02)
|
||||
mark_freq(axs[4, 3], freqs, powers[3], 2*df2, '', s.psF02)
|
||||
mark_freq(axs[4, 3], freqs, powers[3], df1, '', s.psF0)
|
||||
mark_freq(axs[4, 3], freqs, powers[3], df1, '', s.psF01, 130, 1.5)
|
||||
mark_freq(axs[4, 3], freqs, powers[3], abs(df1) + abs(df2) - 2,
|
||||
f'$\\Delta f_1 + \\Delta f_2={abs(df1) + abs(df2):.0f}$\\,Hz', s.psF012, 20, angle=40)
|
||||
mark_freq(axs[4, 3], freqs, powers[3], abs(df1) - abs(df2),
|
||||
f'$\\Delta f_1 + \\Delta f_2={abs(df1) - abs(df2):.0f}$\\,Hz', s.psF01_2, 50, toffs=5, angle=40)
|
||||
axs[3, 0].scalebars(-0.03, 0, 20, 500, 'ms', 'Hz')
|
||||
axs[4, 0].yscalebar(-0.03, 0.5, 10, 'dB', va='center')
|
||||
#fig.tag(axs.T)
|
||||
fig.tag(axs[0])
|
||||
fig.savefig()
|
||||
#plt.show()
|
||||
print()
|
||||
|
||||
Reference in New Issue
Block a user