163 lines
5.7 KiB
Python
163 lines
5.7 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
from pathlib import Path
|
|
from scipy.stats import norm
|
|
from spectral import rate
|
|
from plotstyle import plot_style
|
|
|
|
|
|
# directory that contains one sub-directory of recorded data per cell:
data_path = Path('data')
# name of the cell sub-directory (within data_path) that is analysed:
cell = '2021-08-03-ac-invivo-1'
|
|
|
|
|
|
def _header_value(line, unit):
    """Return the number following the first ':' of a header line, with `unit` removed."""
    return float(line.split(':')[1].strip().replace(unit, ''))


def load_data(cell_path, f1=797, f2=631):
    """Load the spike times of the stimulus block matching two frequencies.

    Scans ``cell_path / 'threefish-spikes.dat'`` for the first block whose
    EOD rate plus Deltaf1 and EOD rate plus Deltaf2 are both within 1 Hz
    of `f1` and `f2`, respectively. The block ends at the next line
    starting with ``'# index '`` or at the end of the file.

    Parameters
    ----------
    cell_path: Path
        Directory containing the file ``threefish-spikes.dat``.
    f1: float
        Requested value of EOD rate + Deltaf1 in Hz.
    f2: float
        Requested value of EOD rate + Deltaf2 in Hz.

    Returns
    -------
    spikes: list of ndarray
        Spike times in seconds (file values in ms times 0.001), one array
        per trial.
    eodf, df1, df2: float
        EOD rate, Deltaf1 and Deltaf2 of the matching block in Hz.
    t0, t1, t2, t12: float
        Values of the 'before', 'duration1', 'duration2' and 'duration12'
        header lines, converted from milliseconds to seconds.

    If no block matches, a message is printed and None is returned.
    """
    load = False
    spikes = []
    with open(cell_path / 'threefish-spikes.dat') as sf:
        for line in sf:
            if line.startswith('# EOD rate '):
                eodf = _header_value(line, 'Hz')
            elif line.startswith('# Deltaf1 '):
                df1 = _header_value(line, 'Hz')
            elif line.startswith('# Deltaf2 '):
                df2 = _header_value(line, 'Hz')
                # both stimulus frequencies must match within 1Hz:
                if abs(eodf + df1 - f1) < 1 and abs(eodf + df2 - f2) < 1:
                    load = True
            elif load:
                if ' before:' in line:
                    t0 = 0.001*_header_value(line, 'ms')
                elif ' duration1 ' in line:
                    t1 = 0.001*_header_value(line, 'ms')
                elif ' duration2 ' in line:
                    t2 = 0.001*_header_value(line, 'ms')
                elif ' duration12 ' in line:
                    t12 = 0.001*_header_value(line, 'ms')
                elif line.startswith('# index '):
                    # the next block starts, so the matching one is complete:
                    break
                elif line.startswith('# trial:'):
                    if len(spikes) > 0:
                        spikes[-1] = np.array(spikes[-1])
                    spikes.append([])
                elif len(line.strip()) > 0 and line[0] != '#':
                    spikes[-1].append(0.001*float(line.strip()))
    if not load:
        print(f'no spikes found for EODf1={f1:.1f}Hz and EODf2={f2:.1f}Hz')
        return None
    # the matching block may also end with the file itself, so the last
    # trial is finalized here and not only on the next '# index ' line:
    if len(spikes) > 0:
        spikes[-1] = np.array(spikes[-1])
    return spikes, eodf, df1, df2, t0, t1, t2, t12
|
|
|
|
|
|
def align_spikes(spikes, period):
    """Align the spike trains of all trials on the first trial.

    The spikes of each trial are binned at 0.2 ms and convolved with a
    Gaussian kernel (sigma = 1 ms). For every further trial, shifts of
    1 to n bins are tried, n being the number of bins within one `period`
    (at least 2). The shift that maximizes the correlation coefficient
    with the first trial's smoothed response is subtracted from that
    trial's spike times.

    Parameters
    ----------
    spikes: list of array-like
        Spike times in seconds, one sequence per trial. The entries of the
        list are replaced by shifted arrays; the arrays passed in are not
        modified.
    period: float
        Time span in seconds covered by the tried shifts.

    Returns
    -------
    spikes: list of ndarray
        The same list, with trials 1 to N-1 shifted. With fewer than two
        trials, or no spikes at all, it is returned unchanged.
    """
    nonempty = [st for st in spikes if len(st) > 0]
    if len(spikes) < 2 or len(nonempty) == 0:
        return spikes
    # compute rates for each trial:
    bin_width = 0.0002
    tmax = max(np.max(st) for st in nonempty)
    time = np.arange(0, tmax, bin_width)
    sigma = 0.001
    kernel = norm.pdf(time[time < 8*sigma], loc=4*sigma, scale=sigma)
    edges = np.append(time, time[-1] + bin_width)
    rates = [np.convolve(np.histogram(st, edges)[0], kernel, 'same')
             for st in spikes]
    # number of shifts (in bins) to try, at least 2:
    nshifts = max(2, len(time[time <= period]))
    # align them on the first trial:
    for i in range(1, len(rates)):
        rs = [np.corrcoef(rates[0][:-k], rates[i][k:])[0, 1]
              for k in range(1, 1 + nshifts)]
        k = 1 + np.argmax(rs)
        dt = time[k]
        # assign a new array instead of subtracting in place, so that
        # arrays still referenced by the caller stay untouched:
        spikes[i] = np.asarray(spikes[i]) - dt
        print(f'  shift trial {i} by {1000*dt:.0f}ms')
    return spikes
|
|
|
|
|
|
def plot_raster(ax, s, spikes, tmin, tmax):
    """Draw a spike raster of the time window between `tmin` and `tmax`.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into; ``show_spines()`` is presumably added by the
        plotstyle module — TODO confirm.
    s: plot style namespace
        Provides the ``lsRaster`` keyword arguments for ``eventplot()``.
    spikes: list of ndarray
        Spike times in seconds, one array per trial.
    tmin, tmax: float
        Time window in seconds. Only spikes strictly inside it are drawn,
        in milliseconds relative to `tmin`.
    """
    window_ms = []
    for trial in spikes:
        inside = (trial > tmin) & (trial < tmax)
        window_ms.append(1000*(trial[inside] - tmin))
    ax.show_spines('')
    ax.eventplot(window_ms, linelengths=0.9, **s.lsRaster)
    ax.set_xlim(0, 1000*(tmax - tmin))
|
|
|
|
|
|
def plot_rate(ax, s, spikes, tmin, tmax, sigma=0.002, dt=0.001, rmax=500):
    """Plot the firing rate of the trials within the window [tmin, tmax].

    The rate is computed by ``rate()`` from the spectral module on a time
    grid that starts at zero with spacing `dt`; the window is shown in
    milliseconds relative to `tmin`.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into; ``show_spines()`` is presumably added by the
        plotstyle module — TODO confirm.
    s: plot style namespace
        Provides the ``lsRate`` line style.
    spikes: list of ndarray
        Spike times in seconds, one array per trial.
    tmin, tmax: float
        Time window in seconds to be shown.
    sigma: float
        Passed on to ``rate()``; presumably the kernel width in seconds —
        TODO confirm in spectral.
    dt: float
        Spacing of the time grid in seconds.
    rmax: float
        Upper limit of the y-axis (the lower limit is 0).
    """
    # the rate is evaluated from 0 on; the window is cut out afterwards:
    time = np.arange(0, tmin + tmax, dt)
    # second return value (presumably a standard deviation) is not used:
    r, _ = rate(time, spikes, sigma)
    mask = (time >= tmin) & (time <= tmax)
    ax.show_spines('')
    ax.plot(1000*(time[mask] - tmin), r[mask], **s.lsRate)
    ax.set_xlim(0, 1000*(tmax - tmin))
    ax.set_ylim(0, rmax)
|
|
|
|
|
|
def plot_psd(ax, s, spikes, tmax, fmax, dt=0.0005, nfft=512, p_ref=4000):
    """Plot the power spectrum of the spike trains in decibel.

    The spike trains are binned with width `dt` on [0, tmax), converted to
    rates (counts / dt), and split into non-overlapping segments of `nfft`
    bins. The power spectra of all segments of all trials (each segment
    minus the mean rate of its trial) are averaged.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into.
    s: plot style namespace
        Provides the ``lsPower`` line style.
    spikes: list of ndarray
        Spike times in seconds, one array per trial.
    tmax: float
        Duration in seconds of the analysed spike trains.
    fmax: float
        Largest frequency in Hz that is plotted.
    dt: float
        Bin width in seconds.
    nfft: int
        Number of bins per FFT segment.
    p_ref: float
        Reference power for the conversion to decibel.

    Nothing is plotted if `tmax` is too short for a single segment of
    `nfft` bins (a message is printed) or if there are no trials.
    """
    time = np.arange(0, tmax, dt)
    # time holds the bin edges, giving len(time) - 1 bins; only complete
    # segments of nfft bins are used:
    segments = range(0, len(time) - nfft, nfft)
    if len(segments) == 0:
        print('nfft too large:', nfft, len(time))
        return
    if len(spikes) == 0:
        return
    # power spectrum:
    freqs = np.fft.fftshift(np.fft.fftfreq(nfft, dt))
    # keep the central half of the shifted frequencies, -fs/4 <= f < fs/4:
    f0 = len(freqs)//4
    f1 = 3*len(freqs)//4
    p_rr = np.zeros(len(freqs))
    n = 0
    for spiket in spikes:
        b, _ = np.histogram(spiket, time)
        b = b / dt
        b_mean = np.mean(b)
        for k in segments:
            fourier_r = np.fft.fftshift(np.fft.fft(b[k:k + nfft] - b_mean, n=nfft))
            p_rr += np.abs(fourier_r*np.conj(fourier_r))
            n += 1
    freqs = freqs[f0:f1]
    scale = dt/nfft/n
    p_rr = p_rr[f0:f1]*scale
    # plot:
    mask = (freqs > 0) & (freqs <= fmax)
    ax.plot(freqs[mask], 10*np.log10(p_rr[mask]/p_ref), **s.lsPower)
    ax.set_xlim(0, fmax)
    ax.set_ylim(-20, 0)
|
|
|
|
|
|
if __name__ == '__main__':
    data = load_data(data_path / cell)
    if data is None:
        # load_data() already reported the missing stimulus frequencies;
        # stop cleanly instead of failing on unpacking None:
        raise SystemExit(1)
    spikes, eodf, df1, df2, t0, t1, t2, t12 = data
    print(f'Loaded spike data for cell {cell}: ')
    print(f' EODf = {eodf:.1f}Hz')
    print(f' Df1 = {df1:.1f}Hz')
    print(f' Df2 = {df2:.1f}Hz')
    print(f' {len(spikes)} trials')

    s = plot_style()
    # cmsize is not a standard matplotlib argument; presumably provided via
    # the plotstyle module — TODO confirm
    fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.6*s.plot_width),
                            height_ratios=[1, 2, 1, 3, 6])
    fmax = 250     # largest frequency of the power spectra in Hz
    tmin = 0.1     # start of the raster and rate plots in seconds
    tmax = 0.2     # end of the raster and rate plots in seconds
    # one column per time window (seconds, built from the 'before' and
    # 'duration' header values), each paired with the frequency whose
    # period bounds the shifts used for aligning the trials:
    twins = [[-t0, 0], [t1, t1 + t2], [0, t1], [t1 + t2, t1 + t2 + t12]]
    freqs = [eodf, df2, df1, df2]
    for i, ((tstart, tend), freq) in enumerate(zip(twins, freqs)):
        sub_spikes = [times[(times >= tstart) & (times <= tend)] - tstart
                      for times in spikes]
        print(f'align spikes for frequency {freq:.0f}Hz:')
        sub_spikes = align_spikes(sub_spikes, abs(1/freq))
        plot_raster(axs[2, i], s, sub_spikes, tmin, tmax)
        plot_rate(axs[3, i], s, sub_spikes, tmin, tmax)
        plot_psd(axs[4, i], s, sub_spikes, tend - tstart, fmax)
    plt.show()
    print()
|