"""Plot the responses of one cell to one and to two beat stimuli.

For each of four time windows (baseline before the stimulus, each of the
two difference frequencies alone, and both together) the script draws a
sketch of the stimulus, a spike raster, the firing rate, and the power
spectrum of the spike trains. Spike times and amplitude modulations (AMs)
are read from ``threefish-spikes.dat`` and ``threefish-ams.dat`` in the
cell's data directory.
"""

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import norm
from scipy.optimize import curve_fit
from spectral import rate
from plotstyle import plot_style

cell = '2021-08-03-ac-invivo-1'
data_path = Path('data')


def _header_value(line, unit):
    """Return the number following the first colon of `line`, `unit` removed."""
    return float(line.split(':')[1].strip().replace(unit, ''))


def _finish_last_trial(trials):
    """Convert the last (still growing) list in `trials` into a numpy array."""
    if len(trials) > 0:
        trials[-1] = np.array(trials[-1])


def load_spikes(cell_path, f1=797, f2=631):
    """Load the spike trains recorded with stimuli close to `f1` and `f2`.

    Scans ``threefish-spikes.dat`` in `cell_path` for the first data block
    whose EOD rate plus Deltaf1 and plus Deltaf2 lie within 1 Hz of `f1` and
    `f2`, respectively, and reads all trials of that block. The block ends
    at the next ``# index`` line or at the end of the file.

    Parameters
    ----------
    cell_path: Path
        Directory containing the cell's data files.
    f1: float
        Frequency EODf + Deltaf1 to look for, in Hz.
    f2: float
        Frequency EODf + Deltaf2 to look for, in Hz.

    Returns
    -------
    spikes: list of ndarray
        Spike times of each trial, converted from ms to seconds.
    eodf, df1, df2: float
        EOD rate and the two difference frequencies in Hz.
    t0, t1, t2, t12: float
        Values of the 'before', 'duration1', 'duration2', and 'duration12'
        header entries of the block, converted from ms to seconds.
    index: int
        Number of ``# index`` lines up to the selected block.

    If no block matches, a message is printed and None is returned.
    """
    load = False
    spikes = []
    index = 0
    with open(cell_path / 'threefish-spikes.dat') as sf:
        for line in sf:
            if load:
                if ' before:' in line:
                    t0 = 0.001*_header_value(line, 'ms')
                elif ' duration1 ' in line:
                    t1 = 0.001*_header_value(line, 'ms')
                elif ' duration2 ' in line:
                    t2 = 0.001*_header_value(line, 'ms')
                elif ' duration12 ' in line:
                    t12 = 0.001*_header_value(line, 'ms')
                elif line.startswith('# index '):
                    # next data block starts: the selected one is complete
                    break
                elif line.startswith('# trial:'):
                    _finish_last_trial(spikes)
                    spikes.append([])
                elif len(line.strip()) > 0 and line[0] != '#':
                    spikes[-1].append(0.001*float(line.strip()))
            elif line.startswith('# index '):
                index += 1
            elif line.startswith('# EOD rate '):
                eodf = _header_value(line, 'Hz')
            elif line.startswith('# Deltaf1 '):
                df1 = _header_value(line, 'Hz')
            elif line.startswith('# Deltaf2 '):
                df2 = _header_value(line, 'Hz')
                if abs(eodf + df1 - f1) < 1 and abs(eodf + df2 - f2) < 1:
                    #print(f'EODf={eodf:6.1f}Hz, Df1={df1:6.1f}Hz, Df2={df2:6.1f}Hz, EODf1={eodf + df1:6.1f}Hz, EODf2={eodf + df2:6.1f}Hz')
                    load = True
    if load:
        # reached either via the next '# index' line or the end of the file;
        # the latter previously dropped the data of the last block
        _finish_last_trial(spikes)
        return spikes, eodf, df1, df2, t0, t1, t2, t12, index
    print(f'no spikes found for EODf1={f1:.1f}Hz and EODf2={f2:.1f}Hz')
    return None


def load_am(cell_path, inx):
    """Load the amplitude modulations (AMs) of data block number `inx`.

    Reads ``threefish-ams.dat`` in `cell_path`. The block starts at the
    `inx`-th ``# index`` line and ends at the next one or at the end of the
    file. EOD rate and difference frequencies found in the block's header
    are printed.

    Parameters
    ----------
    cell_path: Path
        Directory containing the cell's data files.
    inx: int
        Number of the ``# index`` line that starts the requested block
        (as returned by `load_spikes()`).

    Returns
    -------
    ams: list of ndarray
        For each trial an array with one row per sample: time (converted
        from ms to seconds) and AM value.

    If the block is not found, a message is printed and None is returned.
    """
    load = False
    ams = []
    index = 0
    with open(cell_path / 'threefish-ams.dat') as sf:
        for line in sf:
            if load:
                if line.startswith('# index '):
                    # next data block starts: the selected one is complete
                    break
                elif line.startswith('# EOD rate '):
                    # double quotes inside the f-string keep this valid
                    # before Python 3.12 (PEP 701)
                    print(f' EODf = {line.split(":")[1].strip()}')
                elif line.startswith('# Deltaf1 '):
                    print(f' Df1 = {line.split(":")[1].strip()}')
                elif line.startswith('# Deltaf2 '):
                    print(f' DF2 = {line.split(":")[1].strip()}')
                elif line.startswith('# trial:'):
                    _finish_last_trial(ams)
                    ams.append([])
                elif len(line.strip()) > 0 and line[0] != '#':
                    time, am = line.split()
                    ams[-1].append((0.001*float(time.strip()), float(am.strip())))
            elif line.startswith('# index '):
                index += 1
                if inx == index:
                    load = True
    if load:
        # also covers a requested block that is the last one in the file
        _finish_last_trial(ams)
        return ams
    print(f'no AM found at index {inx}')
    return None


def cosine(x, a, f, p, c):
    """Cosine with amplitude `a`, frequency `f`, phase `p`, and offset `c`."""
    return a*np.cos(2*np.pi*f*x + p) + c


def two_cosine(x, a1, f1, p1, a2, f2, p2, c):
    """Sum of two cosines (amplitudes, frequencies, phases) plus offset `c`."""
    return a1*np.cos(2*np.pi*f1*x + p1) + a2*np.cos(2*np.pi*f2*x + p2) + c


def am_phases(ams, eodf, df1, df2, t1, t2, t12):
    """Fit the phases of the beat AMs in the three stimulus segments.

    Each trial's AM is split into three consecutive segments starting at
    time zero, of durations `t1`, `t2`, and `t12`. The first segment is fit
    by a cosine with initial frequency `df1`, the second by one with `df2`,
    and the third by a sum of two cosines with `df1` and `df2`. Phases of
    fits with negative amplitude are shifted by pi.

    Parameters
    ----------
    ams: list of ndarray
        Per trial, time (seconds, first column) and AM (second column),
        as returned by `load_am()`.
    eodf: float
        EOD rate (unused).
    df1, df2: float
        Difference frequencies in Hz used as initial fit frequencies.
    t1, t2, t12: float
        Durations of the three segments in seconds.

    Returns
    -------
    phases: ndarray of shape (len(ams), 4)
        Columns: phase of df1 in segment 1, of df2 in segment 2, and of
        df1 and df2 in segment 3.
    """
    durations = (t1, t2, t12)
    dfs = ((df1,), (df2,), (df1, df2))
    phases = np.zeros((len(ams), len(dfs) + 1))
    for k, trial in enumerate(ams):
        time = trial[:, 0]
        am = trial[:, 1]
        tstart = 0
        for i, (tw, fs) in enumerate(zip(durations, dfs)):
            # each segment has its own duration (was always durations[0])
            tend = tstart + tw
            mask = (time >= tstart) & (time <= tend)
            tam = time[mask] - tstart
            aam = am[mask]
            # initial guesses for amplitude and offset:
            a = 0.5*(np.max(aam) - np.min(aam))
            c = np.mean(aam)
            if len(fs) == 2:
                p0 = [a/2, fs[0], 0, a/2, fs[1], 0, c]
                popt, _ = curve_fit(two_cosine, tam, aam, p0)
                phases[k, i] = popt[2] if popt[0] > 0 else popt[2] + np.pi
                phases[k, i + 1] = popt[5] if popt[3] > 0 else popt[5] + np.pi
            else:
                p0 = [a, fs[0], 0, c]
                popt, _ = curve_fit(cosine, tam, aam, p0)
                phases[k, i] = popt[2] if popt[0] > 0 else popt[2] + np.pi
            tstart = tend
    return phases


def align_spikes(spikes, freqs, phases):
    """Shift each trial's spike times by the phase of its beat.

    With the beat modeled as cos(2 pi df t + p), adding p/(2 pi df) to the
    spike times makes the beat phase zero in every trial.

    Parameters
    ----------
    spikes: list of ndarray
        Spike times of each trial in seconds. Not modified.
    freqs: tuple of (float or None, float or None)
        Stimulus frequencies; the second one is used if not None,
        otherwise the first one. If both are None, nothing is shifted.
    phases: tuple of (ndarray or None, ndarray or None)
        Per-trial phases belonging to the two entries of `freqs`.

    Returns
    -------
    spikes: list of ndarray
        Shifted spike times (new arrays), or `spikes` itself if both
        frequencies are None.
    """
    f1, f2 = freqs
    if f1 is None and f2 is None:
        return spikes
    if f2 is None:
        df, p = f1, phases[0]
    else:
        df, p = f2, phases[1]
    # build new arrays instead of shifting the caller's arrays in place
    return [times + p[k]/2/np.pi/df for k, times in enumerate(spikes)]


def baseline_rate(spikes, t0, t1):
    """Return the trial-averaged firing rate (Hz) between `t0` and `t1`.

    Spikes exactly at `t0` or `t1` are not counted.
    """
    rates = [np.sum((times > t0) & (times < t1))/(t1 - t0) for times in spikes]
    return np.mean(rates)


def power_spectrum(spikes, tmax, dt=0.0005, nfft=512, p_ref=4000):
    """Average power spectrum of spike trains.

    Each trial is binned with bin width `dt` over [0, `tmax`), converted to
    a rate, its mean rate subtracted, and split into non-overlapping
    segments of `nfft` bins. Squared Fourier magnitudes are averaged over
    all segments of all trials and scaled by dt/nfft.

    Parameters
    ----------
    spikes: list of ndarray
        Spike times of each trial in seconds.
    tmax: float
        Duration of the analysis window in seconds.
    dt: float
        Bin width in seconds.
    nfft: int
        Number of bins per FFT segment.
    p_ref: float
        Reference power for the conversion to decibel.

    Returns
    -------
    freqs: ndarray
        Non-negative frequencies in Hz.
    power: ndarray
        Power in dB relative to `p_ref`.

    Raises
    ------
    ValueError
        If not a single complete segment of `nfft` bins is available.
    """
    time = np.arange(0, tmax, dt)
    if nfft > len(time):
        print('nfft too large:', nfft, len(time))
    freqs = np.fft.fftshift(np.fft.fftfreq(nfft, dt))
    segments = range(0, len(time) - nfft, nfft)
    p_rr = np.zeros(len(freqs))
    n = 0
    for spiket in spikes:
        b, _ = np.histogram(spiket, time)
        b = b/dt
        for k in segments:
            fourier_r = np.fft.fftshift(np.fft.fft(b[k:k + nfft] - np.mean(b), n=nfft))
            p_rr += np.abs(fourier_r*np.conj(fourier_r))
            n += 1
    if n == 0:
        # previously an opaque ZeroDivisionError in the scaling below
        raise ValueError(f'no complete segment of {nfft} bins for tmax={tmax}s '
                         f'and {len(spikes)} trials')
    mask = freqs >= 0.0
    freqs = freqs[mask]
    scale = dt/nfft/n
    p_rr = p_rr[mask]*scale
    power = 10*np.log10(p_rr/p_ref)
    return freqs, power


def plot_symbols(ax, s, freqs):
    """Draw a box-and-arrow sketch of the stimulus and response symbols.

    `freqs` selects the labels: no stimulus, s_1 alone (second entry None),
    s_2 alone (first entry None), or both stimuli.
    """
    f1, f2 = freqs
    ax.show_spines('')
    ax.add_artist(plt.Rectangle((-1, -0.5), 2, 1, color=s.colors['black']))
    ax.harrow(1.6, 0, 1.3, **s.asLine)
    ax.set_xlim(-6, 14)
    ax.set_ylim(-1, 1)
    if f1 is None and f2 is None:
        ax.text(3.5, 0, '$r$', va='center')
        return
    ax.harrow(-2.8, 0, 1.3, **s.asLine)
    if f2 is None:
        ax.text(-3.2, 0, '$s_1(t)$', ha='right', va='center')
        ax.text(3.3, 0, '$r + r_1(t)$', va='center')
    elif f1 is None:
        ax.text(-3.2, 0, '$s_2(t)$', ha='right', va='center')
        ax.text(3.3, 0, '$r + r_2(t)$', va='center')
    else:
        ax.text(-3.2, 0, '$s_1(t) + s_2(t)$', ha='right', va='center')
        ax.text(3.3, 0, '$\\ne r + r_1(t) + r_2(t)$', va='center')


def plot_stimulus(ax, s, tmax, eodf, freqs, c=0.1):
    """Plot a synthetic EOD of frequency `eodf` plus the beating stimuli.

    For each non-None difference frequency in `freqs`, a sine of frequency
    `eodf` + df with relative amplitude `c` is added to the EOD, and the
    resulting AM is drawn on top. Time axis in milliseconds up to `tmax`
    seconds.
    """
    time = np.arange(0, tmax, 0.0001)
    eod = np.cos(2*np.pi*eodf*time)
    am = np.ones(len(time))
    ams = {}
    f1, f2 = freqs
    label = '$f_{EOD}$'
    if f1 is not None:
        eod += c*np.cos(2*np.pi*(eodf + f1)*time)
        am += c*np.cos(2*np.pi*f1*time)
        ams = s.lsF02
        label += r' \& $f_1$'
    if f2 is not None:
        eod += c*np.cos(2*np.pi*(eodf + f2)*time)
        am += c*np.cos(2*np.pi*f2*time)
        ams = s.lsF01
        label += r' \& $f_2$'
    if f1 is not None and f2 is not None:
        ams = s.lsF01_2
    ax.show_spines('')
    ax.plot(1000*time, eod, **s.lsEOD)
    if len(ams) > 0:
        ax.plot(1000*time, am, **ams)
    ax.set_xlim(0, 1000*tmax)
    ax.set_ylim(-1.02 - 2*c, 1.02 + 2*c)
    ax.text(0.5, 1.2, label, ha='center', transform=ax.transAxes)


def plot_raster(ax, s, spikes, tmin, tmax):
    """Spike raster of the spikes between `tmin` and `tmax` (in ms from `tmin`)."""
    # `times`, not `s`, so the style argument is not shadowed
    spikes_ms = [1000*(times[(times > tmin) & (times < tmax)] - tmin)
                 for times in spikes]
    ax.show_spines('')
    ax.eventplot(spikes_ms, linelengths=0.8, **s.lsRaster)
    ax.set_xlim(0, 1000*(tmax - tmin))


def plot_rate(ax, s, spikes, tmin, tmax, sigma=0.002):
    """Plot the firing rate computed by `rate()` with kernel width `sigma`.

    The rate is evaluated on a 1 ms grid from zero to `tmin` + `tmax` and
    shown between `tmin` and `tmax`.
    """
    time = np.arange(0, tmin + tmax, 0.001)
    r, _ = rate(time, spikes, sigma)
    mask = (time >= tmin) & (time <= tmax)
    time = time[mask] - tmin
    r = r[mask]
    ax.show_spines('l')
    ax.plot(1000*time, r, clip_on=False, **s.lsRate)
    ax.set_xlim(0, 1000*(tmax - tmin))
    ax.set_ylim(0, 550)
    ax.set_ylabel('Rate', 'Hz')
    ax.set_yticks_delta(200)


def plot_psd(ax, s, freqs, power, fmax, dt=0.0005, nfft=512):
    """Plot the power spectrum up to `fmax`.

    `dt` and `nfft` are not used; they are kept for interface compatibility.
    """
    mask = freqs <= fmax
    freqs = freqs[mask]
    power = power[mask]
    ax.show_spines('lb')
    ax.plot(freqs, power, **s.lsPower)
    ax.set_xlim(0, fmax)
    ax.set_ylim(-20, 0)
    ax.set_xlabel('Frequency', 'Hz')
    ax.set_ylabel('Power', 'dB')
    ax.set_yticks_delta(10)


def mark_freq(ax, freqs, power, f, label, style, xoffs=10, yoffs=0, toffs=0,
              angle=0):
    """Mark the spectral peak closest to |`f`| with a marker and optional label.

    The marker is placed 1 dB + `yoffs` above the power at that frequency;
    the label is shifted left by `xoffs` and up by further `toffs` dB,
    rotated by `angle` degrees, and drawn in the marker's color.
    """
    i = np.argmin(np.abs(freqs - abs(f)))
    p = power[i]
    f = freqs[i]
    ax.plot(f, p + 1 + yoffs, clip_on=False, **style)
    if label:
        yoffs += 3 + toffs
        if angle > 0:
            yoffs -= 1
        ax.text(f - xoffs, p + yoffs, label, color=style['color'], rotation=angle)


if __name__ == '__main__':
    spikes, eodf, df1, df2, t0, t1, t2, t12, index = load_spikes(data_path / cell)
    print(f'Loaded spike data for cell {cell} @ index {index}:')
    print(f' EODf = {eodf:.1f}Hz')
    print(f' Df1 = {df1:.1f}Hz')
    print(f' Df2 = {df2:.1f}Hz')
    print(f' {len(spikes)} trials')
    print(f'Load AMs for cell {cell} @ index {index}:')
    ams = load_am(data_path / cell, index)
    phases = am_phases(ams, eodf, df1, df2, t1, t2, t12)

    s = plot_style()
    fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.6*s.plot_width),
                            height_ratios=[1, 0, 2, 1.3, 3, 0.7, 5])
    fig.subplots_adjust(leftm=7.5, rightm=4.5, topm=1.5, bottomm=4,
                        wspace=0.4, hspace=0.4)
    fmax = 250
    tmin = 0.106
    tmax = 0.206
    # columns: baseline before the stimulus, df2 alone, df1 alone, both:
    twins = [[-t0, 0], [t1, t1 + t2], [0, t1], [t1 + t2, t1 + t2 + t12]]
    stim_freqs = [[None, None], [df2, None], [None, df1], [df1, df2]]
    stim_phases = [[None, None], [phases[:, 1], None], [None, phases[:, 0]],
                   [phases[:, 2], phases[:, 3]]]
    base_rate = baseline_rate(spikes, *twins[0])
    print(f'Baseline firing rate: {base_rate:.1f}Hz')
    powers = []
    for i in range(axs.shape[1]):
        tstart, tend = twins[i]
        plot_symbols(axs[0, i], s, stim_freqs[i])
        plot_stimulus(axs[1, i], s, tmax - tmin, eodf, stim_freqs[i])
        sub_spikes = [times[(times >= tstart) & (times <= tend)] - tstart
                      for times in spikes]
        freqs, power = power_spectrum(sub_spikes, tend - tstart)
        powers.append(power)
        plot_psd(axs[4, i], s, freqs, power, fmax)
        sub_spikes = align_spikes(sub_spikes, stim_freqs[i], stim_phases[i])
        plot_raster(axs[2, i], s, sub_spikes, tmin, tmax)
        plot_rate(axs[3, i], s, sub_spikes, tmin, tmax)
    mark_freq(axs[4, 0], freqs, powers[0], base_rate,
              f'$r={base_rate:.0f}$\\,Hz', s.psF0, 30)
    mark_freq(axs[4, 1], freqs, powers[1], df2,
              f'$\\Delta f_1=f_1 - f_{{EOD}}={abs(df2):.0f}$\\,Hz', s.psF02)
    mark_freq(axs[4, 1], freqs, powers[1], 2*df2,
              f'$2\\Delta f_1={abs(2*df2):.0f}$\\,Hz', s.psF02)
    mark_freq(axs[4, 2], freqs, powers[2], df1, '', s.psF0)
    mark_freq(axs[4, 2], freqs, powers[2], df1,
              f'$\\Delta f_2=f_2 - f_{{EOD}}={abs(df1):.0f}$\\,Hz',
              s.psF01, 130, 1.5)
    mark_freq(axs[4, 3], freqs, powers[3], df2, '', s.psF02)
    mark_freq(axs[4, 3], freqs, powers[3], 2*df2, '', s.psF02)
    mark_freq(axs[4, 3], freqs, powers[3], df1, '', s.psF0)
    mark_freq(axs[4, 3], freqs, powers[3], df1, '', s.psF01, 130, 1.5)
    mark_freq(axs[4, 3], freqs, powers[3], abs(df1) + abs(df2) - 2,
              f'$\\Delta f_2 + \\Delta f_1={abs(df1) + abs(df2):.0f}$\\,Hz',
              s.psF012, 20, angle=40)
    mark_freq(axs[4, 3], freqs, powers[3], abs(df1) - abs(df2),
              f'$\\Delta f_2 - \\Delta f_1={abs(df1) - abs(df2):.0f}$\\,Hz',
              s.psF01_2, 50, toffs=5, angle=40)
    fig.common_yticks(axs[3, :])
    fig.common_yticks(axs[4, :])
    axs[3, 0].xscalebar(1, 0, 20, 'ms', ha='right')
    #axs[3, 0].scalebars(-0.03, 0, 20, 500, 'ms', 'Hz')
    #axs[4, 0].yscalebar(-0.03, 0.5, 10, 'dB', va='center')
    #fig.tag(axs.T)
    fig.tag(axs[0])
    fig.savefig()
    #plt.show()
    print()