import numpy as np import matplotlib.pyplot as plt from pathlib import Path from scipy.stats import linregress from numba import jit from plotstyle import plot_style, lighter, darker model_cell = '2018-05-08-ad-invivo-1' # 228Hz, CV=0.67 data_path = Path('data') sims_path = data_path / 'simulations' def load_data(file_path): data = np.load(file_path) ratebase = float(data['ratebase']) cvbase = float(data['cvbase']) beatf1 = float(data['beatf1']) beatf2 = float(data['beatf2']) contrasts = data['contrasts'] powerf1 = data['powerf1'] powerf2 = data['powerf2'] powerfsum = data['powerfsum'] powerfdiff = data['powerfdiff'] return (ratebase, cvbase, beatf1, beatf2, contrasts, powerf1, powerf2, powerfsum, powerfdiff) def load_models(file): """ Load model parameter from csv file. Parameters ---------- file: string Name of file with model parameters. Returns ------- parameters: list of dict For each cell a dictionary with model parameters. """ parameters = [] with open(file, 'r') as file: header_line = file.readline() header_parts = header_line.strip().split(",") keys = header_parts for line in file: line_parts = line.strip().split(",") parameter = {} for i in range(len(keys)): parameter[keys[i]] = float(line_parts[i]) if i > 0 else line_parts[i] parameters.append(parameter) return parameters def cell_parameters(parameters, cell_name): for params in parameters: if params['cell'] == cell_name: return params print('cell', cell_name, 'not found!') exit() return None @jit(nopython=True) def simulate(stimulus, deltat=0.00005, v_zero=0.0, a_zero=2.0, threshold=1.0, v_base=0.0, delta_a=0.08, tau_a=0.1, v_offset=-10.0, mem_tau=0.015, noise_strength=0.05, input_scaling=60.0, dend_tau=0.001, ref_period=0.001): """ Simulate a P-unit. Returns ------- spike_times: 1-D array Simulated spike times in seconds. """ # initial conditions: v_dend = stimulus[0] v_mem = v_zero adapt = a_zero # prepare noise: noise = np.random.randn(len(stimulus)) noise *= noise_strength / np.sqrt(deltat) # rectify stimulus array: stimulus = stimulus.copy() stimulus[stimulus < 0.0] = 0.0 # integrate: spike_times = [] for i in range(len(stimulus)): v_dend += (-v_dend + stimulus[i]) / dend_tau * deltat v_mem += (v_base - v_mem + v_offset + ( v_dend * input_scaling) - adapt + noise[i]) / mem_tau * deltat adapt += -adapt / tau_a * deltat # refractory period: if len(spike_times) > 0 and (deltat * i) - spike_times[-1] < ref_period + deltat/2: v_mem = v_base # threshold crossing: if v_mem > threshold: v_mem = v_base spike_times.append(i * deltat) adapt += delta_a / tau_a return np.array(spike_times) def punit_spikes(parameter, alpha, beatf1, beatf2, tmax, trials): tini = 0.2 model_params = dict(parameter) cell = model_params.pop('cell') eodf0 = model_params.pop('EODf') time = np.arange(-tini, tmax, model_params['deltat']) stimulus = np.sin(2*np.pi*eodf0*time) stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf1)*time) stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf2)*time) spikes = [] for i in range(trials): model_params['v_zero'] = np.random.rand() model_params['a_zero'] += 0.02*parameter['a_zero']*np.random.randn() spiket = simulate(stimulus, **model_params) spikes.append(spiket[spiket > tini] - tini) return spikes def plot_am(ax, s, alpha, beatf1, beatf2, tmax): time = np.arange(0, tmax, 0.0001) am = alpha*np.sin(2*np.pi*beatf1*time) am += alpha*np.sin(2*np.pi*beatf2*time) ax.show_spines('l') ax.plot(1000*time, -100*am, **s.lsAM) ax.set_xlim(0, 1000*tmax) ax.set_ylim(-13, 13) ax.set_yticks_delta(10) #ax.set_xlabel('Time', 'ms') ax.set_ylabel('AM', r'\%') ax.text(1, 1.2, f'Contrast = {100*alpha:g}\\,\\%', transform=ax.transAxes, ha='right') def plot_raster(ax, s, spikes, tmax): spikes_ms = [1000*s[s f - df) & (freqs < f + df)] return np.max(psd_snippet) def amplitude(power): power -= power[0] power[power<0] = 0 return np.sqrt(power) def amplitude_linearfit(contrast, power, max_contrast): power -= power[0] power[power<0] = 0 ampl = np.sqrt(power) a = ampl[contrast <= max_contrast] c = contrast[contrast <= max_contrast] r = linregress(c, a) return r.intercept + r.slope*contrast def amplitude_squarefit(contrast, power, max_contrast): power -= power[0] power[power<0] = 0 ampl = np.sqrt(power) a = np.sqrt(ampl[contrast <= max_contrast]) c = contrast[contrast <= max_contrast] r = linregress(c, a) return (r.intercept + r.slope*contrast)**2 def plot_peaks(ax, s, alphas, contrasts, powerf1, powerf2, powerfsum, powerfdiff): cmax = 10 contrasts *= 100 ax.plot(contrasts, amplitude_linearfit(contrasts, powerf1, 4), **s.lsF01m) ax.plot(contrasts, amplitude_linearfit(contrasts, powerf2, 2), **s.lsF02m) ax.plot(contrasts, amplitude_squarefit(contrasts, powerfsum, 4), **s.lsF012m) ax.plot(contrasts, amplitude_squarefit(contrasts, powerfdiff, 4), **s.lsF01_2m) ax.plot(contrasts, amplitude(powerf1), **s.lsF01) ax.plot(contrasts, amplitude(powerf2), **s.lsF02) mask = contrasts < cmax ax.plot(contrasts[mask], amplitude(powerfsum)[mask], clip_on=False, **s.lsF012) ax.plot(contrasts[mask], amplitude(powerfdiff)[mask], clip_on=False, **s.lsF01_2) ymax = 60 for alpha, tag in zip(alphas, ['A', 'B', 'C', 'D']): ax.plot(100*alpha, ymax*0.95, 'vk', ms=4, clip_on=False) ax.text(100*alpha, ymax, tag, ha='center') #ax.axvline(contrast, **s.lsGrid) #ax.text(contrast, 630, tag, ha='center') cthresh1 = 1.2 cthresh2 = 3.5 print(f'Linear regime ends at a contrast of about {cthresh1:4.1f}%') print(f'Weakly non-linear regime ends at a contrast of about {cthresh2:4.1f}%') ax.axvline(cthresh1, **s.lsLine) ax.axvline(cthresh2, **s.lsLine) yoffs = 35 ax.text(cthresh1/2, yoffs, 'linear\nregime', ha='center', va='center') ax.text((cthresh1 + cthresh2)/2, yoffs, 'weakly\nnonlinear\nregime', ha='center', va='center') ax.text(5.5, yoffs, 'strongly\nnonlinear\nregime', ha='center', va='center') ax.set_xlim(0, cmax) ax.set_ylim(0, ymax) ax.set_xticks_delta(2) ax.set_yticks_delta(20) ax.set_xlabel('Contrast', r'\%') ax.set_ylabel('Amplitude', 'Hz') if __name__ == '__main__': ratebase, cvbase, beatf1, beatf2, \ contrasts, powerf1, powerf2, powerfsum, powerfdiff = \ load_data(sims_path / f'{model_cell}-contrastpeaks.npz') alphas = [0.002, 0.01, 0.03, 0.06] parameters = load_models(data_path / 'punitmodels.csv') cell = cell_parameters(parameters, model_cell) nfft = 2**18 print(f'Loaded data for cell {model_cell}: ') print(f' baseline rate = {ratebase:.0f}Hz, CV = {cvbase:.2f}') print(f' f1 = {beatf1:.0f}Hz, f2 = {beatf2:.0f}Hz') print() s = plot_style() fig, (axes, axa) = plt.subplots(2, 1, height_ratios=[4, 3], cmsize=(s.plot_width, 0.6*s.plot_width)) fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5, hspace=0.6) axe = axes.subplots(3, 4, wspace=0.4, hspace=0.2, height_ratios=[1, 2, 3]) fig.show_spines('lb') # example power spectra: for c, alpha in enumerate(alphas): path = sims_path / f'{model_cell}-contrastspectrum-{1000*alpha:03.0f}.npz' plot_example(axe[0, c], axe[1, c], axe[2, c], s, path, cell, alpha, beatf1, beatf2, nfft, 100) axe[1, 0].xscalebar(1, -0.1, 20, 'ms', ha='right') axe[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8), ncol=5, columnspacing=2) fig.common_yspines(axe[0, :]) fig.common_yticks(axe[2, :]) fig.tag(axe[0, :], xoffs=-3, yoffs=1.6) # contrast dependence: plot_peaks(axa, s, alphas, contrasts, powerf1, powerf2, powerfsum, powerfdiff) fig.tag(axa, yoffs=2) fig.savefig() print()