started to work on two beats figure

This commit is contained in:
Jan Benda 2026-01-26 19:01:59 +01:00
parent 3b13588595
commit 34fdd9872c
2 changed files with 164 additions and 1 deletions

View File

@ -125,7 +125,7 @@ def plot_style():
pt.make_line_styles(ns, 'ls', 'AMsplit', '', palette['orange'], '-', lwmid) pt.make_line_styles(ns, 'ls', 'AMsplit', '', palette['orange'], '-', lwmid)
pt.make_line_styles(ns, 'ls', 'Noise', '', palette['gray'], '-', lwmid) pt.make_line_styles(ns, 'ls', 'Noise', '', palette['gray'], '-', lwmid)
pt.make_line_styles(ns, 'ls', 'Median', '', palette['black'], '-', lwthick) pt.make_line_styles(ns, 'ls', 'Median', '', palette['black'], '-', lwthick)
pt.make_line_styles(ns, 'ls', 'Max', '', palette['black'], '-', lwmid) pt.make_line_styles(ns, 'ls', 'Max', '', palette['black'], '--', lwthin)
ns.psC1 = dict(color=palette['red'], marker='o', linestyle='none', markersize=3, mec='none', mew=0) ns.psC1 = dict(color=palette['red'], marker='o', linestyle='none', markersize=3, mec='none', mew=0)
ns.psC3 = dict(color=palette['orange'], marker='o', linestyle='none', markersize=3, mec='none', mew=0) ns.psC3 = dict(color=palette['orange'], marker='o', linestyle='none', markersize=3, mec='none', mew=0)
@ -135,6 +135,7 @@ def plot_style():
ns.lsStim = dict(color=palette['gray'], lw=ns.lwmid) ns.lsStim = dict(color=palette['gray'], lw=ns.lwmid)
ns.lsRaster = dict(color=palette['black'], lw=ns.lwthin) ns.lsRaster = dict(color=palette['black'], lw=ns.lwthin)
ns.lsRate = dict(color=palette['blue'], lw=ns.lwmid)
ns.lsPower = dict(color=palette['gray'], lw=ns.lwmid) ns.lsPower = dict(color=palette['gray'], lw=ns.lwmid)
ns.lsF0 = dict(color='blue', lw=ns.lwthick) ns.lsF0 = dict(color='blue', lw=ns.lwthick)
ns.lsF01 = dict(color='green', lw=ns.lwthick) ns.lsF01 = dict(color='green', lw=ns.lwthick)

162
twobeats.py Normal file
View File

@ -0,0 +1,162 @@
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import norm
from spectral import rate
from plotstyle import plot_style
cell = '2021-08-03-ac-invivo-1'
data_path = Path('data')
def load_data(cell_path, f1=797, f2=631):
load = False
spikes = []
with open(cell_path / 'threefish-spikes.dat') as sf:
for line in sf:
if line.startswith('# EOD rate '):
eodf = float(line.split(':')[1].strip().replace('Hz', ''))
elif line.startswith('# Deltaf1 '):
df1 = float(line.split(':')[1].strip().replace('Hz', ''))
elif line.startswith('# Deltaf2 '):
df2 = float(line.split(':')[1].strip().replace('Hz', ''))
if abs(eodf + df1 - f1) < 1 and abs(eodf + df2 - f2) < 1:
#print(f'EODf={eodf:6.1f}Hz, Df1={df1:6.1f}Hz, Df2={df2:6.1f}Hz, EODf1={eodf + df1:6.1f}Hz, EODf2={eodf + df2:6.1f}Hz')
load = True
elif load:
if ' before:' in line:
t0 = 0.001*float(line.split(':')[1].strip().replace('ms', ''))
elif ' duration1 ' in line:
t1 = 0.001*float(line.split(':')[1].strip().replace('ms', ''))
elif ' duration2 ' in line:
t2 = 0.001*float(line.split(':')[1].strip().replace('ms', ''))
elif ' duration12 ' in line:
t12 = 0.001*float(line.split(':')[1].strip().replace('ms', ''))
elif line.startswith('# index '):
if len(spikes) > 0:
spikes[-1] = np.array(spikes[-1])
return spikes, eodf, df1, df2, t0, t1, t2, t12
elif line.startswith('# trial:'):
if len(spikes) > 0:
spikes[-1] = np.array(spikes[-1])
spikes.append([])
elif len(line.strip()) > 0 and line[0] != '#':
t = 0.001*float(line.strip())
spikes[-1].append(t)
print(f'no spikes found for EODf1={f1:.1f}Hz and EODf2={f2:.1f}Hz')
def align_spikes(spikes, period):
# compute rates for each trial:
tmax = np.max([s[-1] for s in spikes])
time = np.arange(0, tmax, 0.0002)
sigma = 0.001
kernel = norm.pdf(time[time < 8*sigma], loc=4*sigma, scale=sigma)
rates = []
xtime = np.append(time, time[-1] + time[1] - time[0])
for i, spiket in enumerate(spikes):
b, _ = np.histogram(spiket, xtime)
r = np.convolve(b, kernel, 'same')
rates.append(r)
# align them on the first trial:
nrates = len(rates[0])
for i in range(1, len(rates)):
rs = []
n = len(time[time <= period])
if n < 2:
n = 2
for k in range(1, 1 + n):
r = np.corrcoef(rates[0][:-k], rates[i][k:])[0, 1]
rs.append(r)
k = 1 + np.argmax(rs)
dt = time[k]
spikes[i] -= dt
print(f' shift trial {i} by {1000*dt:.0f}ms')
return spikes
def plot_raster(ax, s, spikes, tmin, tmax):
spikes_ms = [1000*(s[(s > tmin) & (s < tmax)] - tmin) for s in spikes]
ax.show_spines('')
ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster)
ax.set_xlim(0, 1000*(tmax - tmin))
def plot_rate(ax, s, spikes, tmin, tmax, sigma=0.002):
time = np.arange(0, tmin + tmax, 0.001)
r, rsd = rate(time, spikes, sigma)
mask = (time >= tmin) & (time <= tmax)
time = time[mask] - tmin
r = r[mask]
ax.show_spines('')
ax.plot(1000*time, r, **s.lsRate)
ax.set_xlim(0, 1000*(tmax - tmin))
ax.set_ylim(0, 500)
def plot_psd(ax, s, spikes, tmax, fmax, dt=0.0005, nfft=512):
time = np.arange(0, tmax, dt)
if nfft > len(time):
print('nfft too large:', nfft, len(time))
# power spectrum:
freqs = np.fft.fftfreq(nfft, dt)
freqs = np.fft.fftshift(freqs)
f0 = len(freqs)//4
f1 = 3*len(freqs)//4
segments = range(0, len(time) - nfft, nfft)
p_rr = np.zeros(len(freqs))
n = 0
for i, spiket in enumerate(spikes):
b, _ = np.histogram(spiket, time)
b = b / dt
for j, k in enumerate(segments):
fourier_r = np.fft.fft(b[k:k + nfft] - np.mean(b), n=nfft)
fourier_r = np.fft.fftshift(fourier_r)
p_rr += np.abs(fourier_r*np.conj(fourier_r))
n += 1
freqs = freqs[f0:f1]
scale = dt/nfft/n
p_rr = p_rr[f0:f1]*scale
# plot:
mask = (freqs > 0) & (freqs <= fmax)
freqs = freqs[mask]
p_rr = p_rr[mask]
#print(np.max(p_rr))
p_ref = 4000
ax.plot(freqs, 10*np.log10(p_rr/p_ref), **s.lsPower)
ax.set_xlim(0, fmax)
ax.set_ylim(-20, 0)
if __name__ == '__main__':
spikes, eodf, df1, df2, t0, t1, t2, t12 = load_data(data_path / cell)
print(f'Loaded spike data for cell {cell}: ')
print(f' EODf = {eodf:.1f}Hz')
print(f' Df1 = {df1:.1f}Hz')
print(f' Df2 = {df2:.1f}Hz')
print(f' {len(spikes)} trials')
s = plot_style()
fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.6*s.plot_width),
height_ratios=[1, 2, 1, 3, 6])
fmax = 250
tmin = 0.1
tmax = 0.2
twins = [[-t0, 0], [t1, t1 + t2], [0, t1], [t1 + t2, t1 + t2 + t12]]
freqs = [eodf, df2, df1, df2]
for i in range(axs.shape[1]):
tstart, tend = twins[i]
sub_spikes = [times[(times >= tstart) & (times <= tend)] - tstart for times in spikes]
print(f'align spikes for frequency {freqs[i]:.0f}Hz:')
sub_spikes = align_spikes(sub_spikes, abs(1/freqs[i]))
plot_raster(axs[2, i], s, sub_spikes, tmin, tmax)
plot_rate(axs[3, i], s, sub_spikes, tmin, tmax)
plot_psd(axs[4, i], s, sub_spikes, tend - tstart, fmax)
#fig.savefig()
plt.show()
print()