diff --git a/plotstyle.py b/plotstyle.py index ab1c1a8..dbafeac 100644 --- a/plotstyle.py +++ b/plotstyle.py @@ -125,7 +125,7 @@ def plot_style(): pt.make_line_styles(ns, 'ls', 'AMsplit', '', palette['orange'], '-', lwmid) pt.make_line_styles(ns, 'ls', 'Noise', '', palette['gray'], '-', lwmid) pt.make_line_styles(ns, 'ls', 'Median', '', palette['black'], '-', lwthick) - pt.make_line_styles(ns, 'ls', 'Max', '', palette['black'], '-', lwmid) + pt.make_line_styles(ns, 'ls', 'Max', '', palette['black'], '--', lwthin) ns.psC1 = dict(color=palette['red'], marker='o', linestyle='none', markersize=3, mec='none', mew=0) ns.psC3 = dict(color=palette['orange'], marker='o', linestyle='none', markersize=3, mec='none', mew=0) @@ -135,6 +135,7 @@ def plot_style(): ns.lsStim = dict(color=palette['gray'], lw=ns.lwmid) ns.lsRaster = dict(color=palette['black'], lw=ns.lwthin) + ns.lsRate = dict(color=palette['blue'], lw=ns.lwmid) ns.lsPower = dict(color=palette['gray'], lw=ns.lwmid) ns.lsF0 = dict(color='blue', lw=ns.lwthick) ns.lsF01 = dict(color='green', lw=ns.lwthick) diff --git a/twobeats.py b/twobeats.py new file mode 100644 index 0000000..95bfc80 --- /dev/null +++ b/twobeats.py @@ -0,0 +1,162 @@ +import numpy as np +import matplotlib.pyplot as plt + +from pathlib import Path +from scipy.stats import norm +from spectral import rate +from plotstyle import plot_style + + +cell = '2021-08-03-ac-invivo-1' + +data_path = Path('data') + + +def load_data(cell_path, f1=797, f2=631): + load = False + spikes = [] + with open(cell_path / 'threefish-spikes.dat') as sf: + for line in sf: + if line.startswith('# EOD rate '): + eodf = float(line.split(':')[1].strip().replace('Hz', '')) + elif line.startswith('# Deltaf1 '): + df1 = float(line.split(':')[1].strip().replace('Hz', '')) + elif line.startswith('# Deltaf2 '): + df2 = float(line.split(':')[1].strip().replace('Hz', '')) + if abs(eodf + df1 - f1) < 1 and abs(eodf + df2 - f2) < 1: + #print(f'EODf={eodf:6.1f}Hz, Df1={df1:6.1f}Hz, Df2={df2:6.1f}Hz, EODf1={eodf + df1:6.1f}Hz, EODf2={eodf + df2:6.1f}Hz') + load = True + elif load: + if ' before:' in line: + t0 = 0.001*float(line.split(':')[1].strip().replace('ms', '')) + elif ' duration1 ' in line: + t1 = 0.001*float(line.split(':')[1].strip().replace('ms', '')) + elif ' duration2 ' in line: + t2 = 0.001*float(line.split(':')[1].strip().replace('ms', '')) + elif ' duration12 ' in line: + t12 = 0.001*float(line.split(':')[1].strip().replace('ms', '')) + elif line.startswith('# index '): + if len(spikes) > 0: + spikes[-1] = np.array(spikes[-1]) + return spikes, eodf, df1, df2, t0, t1, t2, t12 + elif line.startswith('# trial:'): + if len(spikes) > 0: + spikes[-1] = np.array(spikes[-1]) + spikes.append([]) + elif len(line.strip()) > 0 and line[0] != '#': + t = 0.001*float(line.strip()) + spikes[-1].append(t) + print(f'no spikes found for EODf1={f1:.1f}Hz and EODf2={f2:.1f}Hz') + + +def align_spikes(spikes, period): + # compute rates for each trial: + tmax = np.max([s[-1] for s in spikes]) + time = np.arange(0, tmax, 0.0002) + sigma = 0.001 + kernel = norm.pdf(time[time < 8*sigma], loc=4*sigma, scale=sigma) + rates = [] + xtime = np.append(time, time[-1] + time[1] - time[0]) + for i, spiket in enumerate(spikes): + b, _ = np.histogram(spiket, xtime) + r = np.convolve(b, kernel, 'same') + rates.append(r) + # align them on the first trial: + nrates = len(rates[0]) + for i in range(1, len(rates)): + rs = [] + n = len(time[time <= period]) + if n < 2: + n = 2 + for k in range(1, 1 + n): + r = np.corrcoef(rates[0][:-k], rates[i][k:])[0, 1] + rs.append(r) + k = 1 + np.argmax(rs) + dt = time[k] + spikes[i] -= dt + print(f' shift trial {i} by {1000*dt:.0f}ms') + return spikes + + +def plot_raster(ax, s, spikes, tmin, tmax): + spikes_ms = [1000*(s[(s > tmin) & (s < tmax)] - tmin) for s in spikes] + ax.show_spines('') + ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster) + ax.set_xlim(0, 1000*(tmax - tmin)) + + +def plot_rate(ax, s, spikes, tmin, tmax, sigma=0.002): + time = np.arange(0, tmin + tmax, 0.001) + r, rsd = rate(time, spikes, sigma) + mask = (time >= tmin) & (time <= tmax) + time = time[mask] - tmin + r = r[mask] + ax.show_spines('') + ax.plot(1000*time, r, **s.lsRate) + ax.set_xlim(0, 1000*(tmax - tmin)) + ax.set_ylim(0, 500) + + +def plot_psd(ax, s, spikes, tmax, fmax, dt=0.0005, nfft=512): + time = np.arange(0, tmax, dt) + if nfft > len(time): + print('nfft too large:', nfft, len(time)) + # power spectrum: + freqs = np.fft.fftfreq(nfft, dt) + freqs = np.fft.fftshift(freqs) + f0 = len(freqs)//4 + f1 = 3*len(freqs)//4 + segments = range(0, len(time) - nfft, nfft) + p_rr = np.zeros(len(freqs)) + n = 0 + for i, spiket in enumerate(spikes): + b, _ = np.histogram(spiket, time) + b = b / dt + for j, k in enumerate(segments): + fourier_r = np.fft.fft(b[k:k + nfft] - np.mean(b), n=nfft) + fourier_r = np.fft.fftshift(fourier_r) + p_rr += np.abs(fourier_r*np.conj(fourier_r)) + n += 1 + freqs = freqs[f0:f1] + scale = dt/nfft/n + p_rr = p_rr[f0:f1]*scale + # plot: + mask = (freqs > 0) & (freqs <= fmax) + freqs = freqs[mask] + p_rr = p_rr[mask] + #print(np.max(p_rr)) + p_ref = 4000 + ax.plot(freqs, 10*np.log10(p_rr/p_ref), **s.lsPower) + ax.set_xlim(0, fmax) + ax.set_ylim(-20, 0) + + +if __name__ == '__main__': + + spikes, eodf, df1, df2, t0, t1, t2, t12 = load_data(data_path / cell) + print(f'Loaded spike data for cell {cell}: ') + print(f' EODf = {eodf:.1f}Hz') + print(f' Df1 = {df1:.1f}Hz') + print(f' Df2 = {df2:.1f}Hz') + print(f' {len(spikes)} trials') + + + s = plot_style() + fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.6*s.plot_width), + height_ratios=[1, 2, 1, 3, 6]) + fmax = 250 + tmin = 0.1 + tmax = 0.2 + twins = [[-t0, 0], [t1, t1 + t2], [0, t1], [t1 + t2, t1 + t2 + t12]] + freqs = [eodf, df2, df1, df2] + for i in range(axs.shape[1]): + tstart, tend = twins[i] + sub_spikes = [times[(times >= tstart) & (times <= tend)] - tstart for times in spikes] + print(f'align spikes for frequency {freqs[i]:.0f}Hz:') + sub_spikes = align_spikes(sub_spikes, abs(1/freqs[i])) + plot_raster(axs[2, i], s, sub_spikes, tmin, tmax) + plot_rate(axs[3, i], s, sub_spikes, tmin, tmax) + plot_psd(axs[4, i], s, sub_spikes, tend - tstart, fmax) + #fig.savefig() + plt.show() + print()