import numpy as np import matplotlib.pyplot as plt from pathlib import Path from scipy.stats import linregress from numba import jit from thunderlab.tabledata import TableData from plotstyle import plot_style, lighter, darker data_path = Path('data') sims_path = data_path / 'simulations' def load_models(file): """ Load model parameter from csv file. Parameters ---------- file: string Name of file with model parameters. Returns ------- parameters: list of dict For each cell a dictionary with model parameters. """ parameters = [] with file.open('r') as file: header_line = file.readline() header_parts = header_line.strip().split(",") keys = header_parts for line in file: line_parts = line.strip().split(",") parameter = {} for i in range(len(keys)): parameter[keys[i]] = float(line_parts[i]) if i > 0 else line_parts[i] parameters.append(parameter) return parameters def cell_parameters(parameters, cell_name): for params in parameters: if params['cell'] == cell_name: return params print('cell', cell_name, 'not found!') exit() return None @jit(nopython=True) def simulate(stimulus, deltat=0.00005, v_zero=0.0, a_zero=2.0, threshold=1.0, v_base=0.0, delta_a=0.08, tau_a=0.1, v_offset=-10.0, mem_tau=0.015, noise_strength=0.05, input_scaling=60.0, dend_tau=0.001, ref_period=0.001): """ Simulate a P-unit. Returns ------- spike_times: 1-D array Simulated spike times in seconds. """ # initial conditions: v_dend = stimulus[0] v_mem = v_zero adapt = a_zero # prepare noise: noise = np.random.randn(len(stimulus)) noise *= noise_strength / np.sqrt(deltat) # rectify stimulus array: stimulus = stimulus.copy() stimulus[stimulus < 0.0] = 0.0 # integrate: spike_times = [] for i in range(len(stimulus)): v_dend += (-v_dend + stimulus[i]) / dend_tau * deltat v_mem += (v_base - v_mem + v_offset + ( v_dend * input_scaling) - adapt + noise[i]) / mem_tau * deltat adapt += -adapt / tau_a * deltat # refractory period: if len(spike_times) > 0 and (deltat * i) - spike_times[-1] < ref_period + deltat/2: v_mem = v_base # threshold crossing: if v_mem > threshold: v_mem = v_base spike_times.append(i * deltat) adapt += delta_a / tau_a return np.array(spike_times) def punit_spikes(parameter, alpha, beatf1, beatf2, tmax, trials): tini = 0.2 model_params = dict(parameter) cell = model_params.pop('cell') eodf0 = model_params.pop('EODf') time = np.arange(-tini, tmax, model_params['deltat']) stimulus = np.sin(2*np.pi*eodf0*time) stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf1)*time) stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf2)*time) spikes = [] for i in range(trials): model_params['v_zero'] = np.random.rand() model_params['a_zero'] += 0.02*parameter['a_zero']*np.random.randn() spiket = simulate(stimulus, **model_params) spikes.append(spiket[spiket > tini] - tini) return spikes def plot_am(ax, s, alpha, beatf1, beatf2, tmax): time = np.arange(0, tmax, 0.0001) am = alpha*np.sin(2*np.pi*beatf1*time) am += alpha*np.sin(2*np.pi*beatf2*time) ax.show_spines('l') ax.plot(1000*time, -100*am, **s.lsAM) ax.set_xlim(0, 1000*tmax) ax.set_ylim(-50, 50) #ax.set_xlabel('Time', 'ms') ax.set_ylabel('AM', r'\%') ax.text(1, 1.2, f'Contrast = {100*alpha:g}\\,\\%', transform=ax.transAxes, ha='right') def plot_raster(ax, s, spikes, tmax): spikes_ms = [1000*s[s f - df) & (freqs < f + df)] return np.max(psd_snippet) def compute_peaks(name, cell, alpha_max, beatf1, beatf2, nfft, trials): data_file = sims_path / f'{name}-contrastpeaks.csv' data = TableData(data_file) return data """ if data_file.exists(): data = TableData(data_file) return data dt = 0.0001 tmax = nfft*dt alphas = np.linspace(0, alpha_max, 200) ampl_f1 = np.zeros(len(alphas)) ampl_f2 = np.zeros(len(alphas)) ampl_sum = np.zeros(len(alphas)) ampl_diff = np.zeros(len(alphas)) for k, alpha in enumerate(alphas): print(alpha) spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials) freqs, psd = compute_power(spikes, nfft, dt) ampl_f1[k] = peak_ampl(freqs, psd, beatf1) ampl_f2[k] = peak_ampl(freqs, psd, beatf2) ampl_sum[k] = peak_ampl(freqs, psd, beatf1 + beatf2) ampl_diff[k] = peak_ampl(freqs, psd, beatf2 - beatf1) data = TableData() data.append('contrast', '%', '%.1f', 100*alphas) data.append('f1', 'Hz', '%g', ampl_f1) data.append('f2', 'Hz', '%g', ampl_f2) data.append('f1+f2', 'Hz', '%g', ampl_sum) data.append('f2-f1', 'Hz', '%g', ampl_diff) data.write(data_file) return data """ def amplitude(power): power -= power[0] power[power<0] = 0 return np.sqrt(power) def amplitude_linearfit(contrast, power, max_contrast): power -= power[0] power[power<0] = 0 ampl = np.sqrt(power) a = ampl[contrast <= max_contrast] c = contrast[contrast <= max_contrast] r = linregress(c, a) return r.intercept + r.slope*contrast def amplitude_squarefit(contrast, power, max_contrast): power -= power[0] power[power<0] = 0 ampl = np.sqrt(power) a = np.sqrt(ampl[contrast <= max_contrast]) c = contrast[contrast <= max_contrast] r = linregress(c, a) return (r.intercept + r.slope*contrast)**2 def plot_peaks(ax, s, data, alphas): contrast = data[:, 'contrast'] ax.plot(contrast, amplitude_linearfit(contrast, data[:, 'f1'], 4), **s.lsF01m) ax.plot(contrast, amplitude_linearfit(contrast, data[:, 'f2'], 2), **s.lsF02m) ax.plot(contrast, amplitude_squarefit(contrast, data[:, 'f1+f2'], 4), **s.lsF012m) ax.plot(contrast, amplitude_squarefit(contrast, data[:, 'f2-f1'], 4), **s.lsF01_2m) ax.plot(contrast, amplitude(data[:, 'f1']), **s.lsF01) ax.plot(contrast, amplitude(data[:, 'f2']), **s.lsF02) ax.plot(contrast, amplitude(data[:, 'f1+f2']), **s.lsF012) ax.plot(contrast, amplitude(data[:, 'f2-f1']), **s.lsF01_2) for alpha, tag in zip(alphas, ['A', 'B', 'C', 'D']): contrast = 100*alpha ax.plot(contrast, 630, 'vk', ms=4, clip_on=False) ax.text(contrast, 660, tag, ha='center') #ax.axvline(contrast, **s.lsGrid) #ax.text(contrast, 630, tag, ha='center') ax.axvline(1.5, **s.lsLine) ax.axvline(4, **s.lsLine) yoffs = 340 ax.text(1.5/2, yoffs, 'linear\nregime', ha='center', va='center') ax.text((1.5 + 4)/2, yoffs, 'weakly\nnonlinear\nregime', ha='center', va='center') ax.text(10, yoffs, 'strongly\nnonlinear\nregime', ha='center', va='center') ax.set_xlim(0, 16.5) ax.set_ylim(0, 600) ax.set_xticks_delta(5) ax.set_yticks_delta(300) ax.set_xlabel('Contrast', r'\%') ax.set_ylabel('Amplitude', 'Hz') if __name__ == '__main__': parameters = load_models(data_path / 'punitmodels.csv') cell_name = '2013-01-08-aa-invivo-1' # 132Hz, CV=0.16: perfect! beatf1 = 40 beatf2 = 132 # cell_name = '2012-07-03-ak-invivo-1' # 128Hz, CV=0.24 # cell_name = '2018-05-08-ae-invivo-1' # 142Hz, CV=0.48 cell = cell_parameters(parameters, cell_name) s = plot_style() nfft = 2**18 fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.8*s.plot_width), height_ratios=[1, 1.5, 2, 1.5, 4]) fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5, wspace=0.3, hspace=0.3) ax0 = fig.merge(axs[3, :]) ax0.set_visible(False) axa = fig.merge(axs[4, :]) fig.show_spines('lb') alphas = [0.01, 0.03, 0.05, 0.16] #alphas = [0.002, 0.01, 0.05, 0.1] for c, alpha in enumerate(alphas): plot_example(axs[0, c], axs[1, c], axs[2, c], s, cell, alpha, beatf1, beatf2, nfft, 100) axs[1, 0].xscalebar(1, -0.1, 30, 'ms', ha='right') axs[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8), ncol=5, columnspacing=2) data = compute_peaks(cell_name, cell, 0.2, beatf1, beatf2, nfft, 1000) plot_peaks(axa, s, data, alphas) fig.common_yspines(axs[0, :]) fig.common_yticks(axs[2, :]) #fig.common_xlabels(axs[2, :]) fig.tag(axs[0, :], xoffs=-2, yoffs=1.6) fig.tag(axa) fig.savefig()