"""Figure of one cell's spike responses in four stimulus time windows.

Loads the spike trains recorded for one pair of stimulus frequencies from
``threefish-spikes.dat`` and plots, for each of four time windows (before
stimulus onset and three stimulus epochs), a stimulus sketch, a spike
raster, the firing rate and the power spectrum of the spike trains.
"""

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import norm
from spectral import rate
from plotstyle import plot_style


cell = '2021-08-03-ac-invivo-1'
data_path = Path('data')


def _header_value(line, unit):
    """Parse the number following the colon of a metadata line.

    Parameters
    ----------
    line: str
        Metadata line of the form ``'# name : 123.4unit'``.
    unit: str
        Unit suffix removed before conversion (e.g. ``'Hz'`` or ``'ms'``).

    Returns
    -------
    value: float
        The number as written in the file (no unit conversion applied).
    """
    return float(line.split(':')[1].strip().replace(unit, ''))


def load_data(cell_path, f1=797, f2=631):
    """Load the spike trains recorded for one pair of stimulus frequencies.

    Scans ``cell_path/threefish-spikes.dat`` for the first data block whose
    frequencies EODf + Df1 and EODf + Df2 match `f1` and `f2` within 1 Hz,
    and reads its timing metadata and spike times.

    Parameters
    ----------
    cell_path: Path
        Directory containing ``threefish-spikes.dat``.
    f1: float
        Requested frequency EODf + Df1 in Hertz.
    f2: float
        Requested frequency EODf + Df2 in Hertz.

    Returns
    -------
    spikes: list of ndarray
        Spike times in seconds, one array per trial.
    eodf: float
        The 'EOD rate' of the block in Hertz.
    df1, df2: float
        The 'Deltaf1' and 'Deltaf2' of the block in Hertz.
    t0, t1, t2, t12: float
        The 'before', 'duration1', 'duration2' and 'duration12' values of
        the block, converted from milliseconds to seconds.

    If no matching block is found, a message is printed and None is
    returned.
    """
    load = False
    spikes = []
    with open(cell_path / 'threefish-spikes.dat') as sf:
        for line in sf:
            if line.startswith('# EOD rate '):
                eodf = _header_value(line, 'Hz')
            elif line.startswith('# Deltaf1 '):
                df1 = _header_value(line, 'Hz')
            elif line.startswith('# Deltaf2 '):
                df2 = _header_value(line, 'Hz')
                if abs(eodf + df1 - f1) < 1 and abs(eodf + df2 - f2) < 1:
                    load = True
            elif load:
                if ' before:' in line:
                    t0 = 0.001*_header_value(line, 'ms')
                elif ' duration1 ' in line:
                    t1 = 0.001*_header_value(line, 'ms')
                elif ' duration2 ' in line:
                    t2 = 0.001*_header_value(line, 'ms')
                elif ' duration12 ' in line:
                    t12 = 0.001*_header_value(line, 'ms')
                elif line.startswith('# index '):
                    # an '# index' line starts the next block:
                    # the requested block is complete.
                    break
                elif line.startswith('# trial:'):
                    if len(spikes) > 0:
                        spikes[-1] = np.array(spikes[-1])
                    spikes.append([])
                elif len(line.strip()) > 0 and line[0] != '#':
                    # spike time in milliseconds -> seconds:
                    spikes[-1].append(0.001*float(line.strip()))
    if not load:
        print(f'no spikes found for EODf1={f1:.1f}Hz and EODf2={f2:.1f}Hz')
        return None
    # The requested block may be the last one in the file (no '# index'
    # line follows it), so the last trial is finished here, after the loop:
    if len(spikes) > 0:
        spikes[-1] = np.array(spikes[-1])
    return spikes, eodf, df1, df2, t0, t1, t2, t12


def align_spikes(spikes, period):
    """Shift trials in time so that their firing rates match the first trial.

    For each trial a firing rate is estimated by binning the spike times at
    0.2 ms and convolving with a Gaussian kernel (sigma = 1 ms). Each later
    trial is then shifted earlier by the lag, between zero and `period`, at
    which its rate correlates best with the rate of the first trial.
    Empty trials are left unchanged.

    Parameters
    ----------
    spikes: list of ndarray
        Spike times in seconds, one array per trial. Not modified.
    period: float
        Largest lag in seconds considered for the alignment.

    Returns
    -------
    aligned: list of ndarray
        Copies of the spike trains, shifted by the lags found
        (the first trial is never shifted).
    """
    aligned = [np.array(spiket, dtype=float) for spiket in spikes]
    # last spike over all trials; empty trials do not have one:
    tmax = max((st[-1] for st in aligned if len(st) > 0), default=0.0)
    time = np.arange(0, tmax, 0.0002)
    if len(aligned) < 2 or len(time) < 2:
        return aligned
    # firing rate of each trial:
    sigma = 0.001
    kernel = norm.pdf(time[time < 8*sigma], loc=4*sigma, scale=sigma)
    bins = np.append(time, time[-1] + time[1] - time[0])
    rates = [np.convolve(np.histogram(st, bins)[0], kernel, 'same')
             for st in aligned]
    # candidate lags 0, 1, ... bins, covering one period:
    nlags = max(len(time[time <= period]), 2)
    nrates = len(rates[0])
    # align on the first trial:
    for i in range(1, len(aligned)):
        if len(aligned[i]) == 0:
            continue
        # correlate the rate of trial i, delayed by k bins, with the rate
        # of the first trial (k = 0 means no shift):
        corrs = [np.corrcoef(rates[0][:nrates - k], rates[i][k:])[0, 1]
                 for k in range(nlags)]
        shift = time[int(np.argmax(corrs))]
        aligned[i] -= shift
        print(f' shift trial {i} by {1000*shift:.0f}ms')
    return aligned


def plot_symbols(ax, s):
    """Prepare the symbol row of the figure (currently only hides spines).

    `s` (plot style) is not used yet. `show_spines()` is presumably
    provided by the plotstyle module -- confirm.
    """
    ax.show_spines('')


def plot_stimulus(ax, s, tmax, eodf, f1, f2, c=0.1):
    """Sketch the stimulus waveform and its amplitude modulation.

    Plots cos(2 pi eodf t) multiplied by the modulation
    1 + c cos(2 pi f1 t) + c cos(2 pi f2 t), where terms with a frequency
    of None are omitted. If any modulation is present, the modulation
    itself is plotted on top. A label naming the present components is
    written above the axes.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into.
    s: plot style
        Provides the line styles `lsStim`, `lsF01`, `lsF02`, `lsF012`.
    tmax: float
        Duration of the sketch in seconds (plotted in milliseconds).
    eodf: float
        Carrier frequency in Hertz.
    f1, f2: float or None
        Modulation frequencies in Hertz.
    c: float
        Modulation amplitude of each component.
    """
    time = np.arange(0, tmax, 0.0001)
    eod = np.cos(2*np.pi*eodf*time)
    am = np.ones(len(time))
    ams = {}
    label = '$f_{EOD}$'
    if f1 is not None:
        am += c*np.cos(2*np.pi*f1*time)
        ams = s.lsF01
        label += r' \& $f_1$'
    if f2 is not None:
        am += c*np.cos(2*np.pi*f2*time)
        ams = s.lsF02
        label += r' \& $f_2$'
    if f1 is not None and f2 is not None:
        ams = s.lsF012
    ax.show_spines('')
    ax.plot(1000*time, am*eod, **s.lsStim)
    if len(ams) > 0:
        ax.plot(1000*time, am, **ams)
    ax.set_xlim(0, 1000*tmax)
    ax.set_ylim(-1.02 - 2*c, 1.02 + 2*c)
    ax.text(0, 1.1, label, transform=ax.transAxes)


def plot_raster(ax, s, spikes, tmin, tmax):
    """Plot a spike raster of the window tmin < t < tmax.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into.
    s: plot style
        Provides the event style `lsRaster`.
    spikes: list of ndarray
        Spike times in seconds, one array per trial.
    tmin, tmax: float
        Time window in seconds; plotted in milliseconds relative to `tmin`.
    """
    spikes_ms = [1000*(st[(st > tmin) & (st < tmax)] - tmin)
                 for st in spikes]
    ax.show_spines('')
    ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster)
    ax.set_xlim(0, 1000*(tmax - tmin))


def plot_rate(ax, s, spikes, tmin, tmax, sigma=0.002):
    """Plot the firing rate of the spike trains in the window [tmin, tmax].

    The rate is computed by `spectral.rate()` on a 1 ms time grid with
    kernel width `sigma` (presumably trial-averaged -- confirm in spectral).

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into.
    s: plot style
        Provides the line style `lsRate`.
    spikes: list of ndarray
        Spike times in seconds, one array per trial.
    tmin, tmax: float
        Time window in seconds; plotted in milliseconds relative to `tmin`.
    sigma: float
        Kernel width in seconds passed on to `rate()`.
    """
    # NOTE(review): the grid extends to tmin + tmax, beyond the plotted
    # window; kept as is (possibly to avoid edge effects) -- confirm.
    time = np.arange(0, tmin + tmax, 0.001)
    r, _ = rate(time, spikes, sigma)
    mask = (time >= tmin) & (time <= tmax)
    time = time[mask] - tmin
    r = r[mask]
    ax.show_spines('')
    ax.plot(1000*time, r, **s.lsRate)
    ax.set_xlim(0, 1000*(tmax - tmin))
    ax.set_ylim(0, 500)


def _spike_train_psd(spikes, tmax, dt, nfft):
    """Average power spectrum of binned spike trains over segments and trials.

    Returns the frequencies (middle half of the fft-shifted frequencies)
    and the averaged spectrum, or None if no complete segment fits into
    `tmax` or there are no trials.
    """
    time = np.arange(0, tmax, dt)
    # np.histogram() returns len(time) - 1 bins; use all complete segments:
    segments = range(0, len(time) - nfft, nfft)
    if len(segments) == 0:
        print('nfft too large:', nfft, len(time))
        return None
    if len(spikes) == 0:
        return None
    freqs = np.fft.fftshift(np.fft.fftfreq(nfft, dt))
    i0 = len(freqs)//4
    i1 = 3*len(freqs)//4
    p_rr = np.zeros(len(freqs))
    n = 0
    for spiket in spikes:
        # spike counts per bin converted to rates:
        b = np.histogram(spiket, time)[0] / dt
        for k in segments:
            fourier_r = np.fft.fft(b[k:k + nfft] - np.mean(b), n=nfft)
            fourier_r = np.fft.fftshift(fourier_r)
            p_rr += np.abs(fourier_r*np.conj(fourier_r))
            n += 1
    scale = dt/nfft/n
    return freqs[i0:i1], p_rr[i0:i1]*scale


def plot_psd(ax, s, spikes, tmax, fmax, dt=0.0005, nfft=512):
    """Plot the power spectrum of spike trains in decibel.

    Spike trains are binned with `dt` on [0, tmax) and the power spectra of
    non-overlapping segments of `nfft` bins are averaged over segments and
    trials. Plotted is 10 log10(P / 4000) for 0 < f <= fmax. If no complete
    segment fits, a message is printed and only the axis limits are set.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into.
    s: plot style
        Provides the line style `lsPower`.
    spikes: list of ndarray
        Spike times in seconds, one array per trial.
    tmax: float
        Duration of the spike trains in seconds.
    fmax: float
        Largest frequency to plot in Hertz.
    dt: float
        Bin width in seconds.
    nfft: int
        Number of bins per segment.
    """
    psd = _spike_train_psd(spikes, tmax, dt, nfft)
    if psd is not None:
        freqs, p_rr = psd
        mask = (freqs > 0) & (freqs <= fmax)
        p_ref = 4000  # reference power of the decibel scale
        ax.plot(freqs[mask], 10*np.log10(p_rr[mask]/p_ref), **s.lsPower)
    ax.set_xlim(0, fmax)
    ax.set_ylim(-20, 0)


if __name__ == '__main__':
    spikes, eodf, df1, df2, t0, t1, t2, t12 = load_data(data_path / cell)
    print(f'Loaded spike data for cell {cell}: ')
    print(f' EODf = {eodf:.1f}Hz')
    print(f' Df1 = {df1:.1f}Hz')
    print(f' Df2 = {df2:.1f}Hz')
    print(f' {len(spikes)} trials')
    s = plot_style()
    # cmsize is not a standard matplotlib argument; presumably added by
    # plotstyle -- confirm.
    fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.6*s.plot_width),
                            height_ratios=[1, 2, 1, 3, 6])
    fmax = 250
    tmin = 0.1
    tmax = 0.2
    # time windows of the four columns:
    twins = [[-t0, 0], [t1, t1 + t2], [0, t1], [t1 + t2, t1 + t2 + t12]]
    # frequencies whose periods bound the trial alignment:
    freqs = [eodf, df2, df1, df2]
    # NOTE(review): column 1 (window [t1, t1 + t2]) gets the Df2 modulation
    # but plot_stimulus() labels/styles its first frequency as f_1 (and
    # column 2 as f_2) -- confirm this swap is intended.
    stim_freqs = [[None, None], [df2, None], [None, df1], [df1, df2]]
    for i in range(axs.shape[1]):
        tstart, tend = twins[i]
        plot_symbols(axs[0, i], s)
        plot_stimulus(axs[1, i], s, tmax - tmin, eodf, *stim_freqs[i])
        sub_spikes = [times[(times >= tstart) & (times <= tend)] - tstart
                      for times in spikes]
        plot_psd(axs[4, i], s, sub_spikes, tend - tstart, fmax)
        print(f'align spikes for frequency {freqs[i]:.0f}Hz:')
        sub_spikes = align_spikes(sub_spikes, abs(1/freqs[i]))
        plot_raster(axs[2, i], s, sub_spikes, tmin, tmax)
        plot_rate(axs[3, i], s, sub_spikes, tmin, tmax)
    #fig.savefig()
    plt.show()
    print()