326 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			326 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import numpy as np
 | |
| import matplotlib.pyplot as plt
 | |
| from pathlib import Path
 | |
| from scipy.stats import linregress
 | |
| from numba import jit
 | |
| from plotstyle import plot_style, lighter, darker
 | |
| 
 | |
| 
 | |
# Model cell whose simulations are shown (228Hz, CV=0.67):
model_cell = '2018-05-08-ad-invivo-1'

# Directory with the model parameter table and its subdirectory holding
# the cached simulation results:
data_path = Path('data')
sims_path = data_path.joinpath('simulations')
 | |
| 
 | |
| 
 | |
def load_data(file_path):
    """ Load the precomputed contrast dependence of a model cell.

    Parameters
    ----------
    file_path: str or Path
        Path of the `.npz` file with the simulation results.

    Returns
    -------
    ratebase: float
        Baseline firing rate in Hz.
    cvbase: float
        Baseline coefficient of variation.
    beatf1: float
        First beat frequency.
    beatf2: float
        Second beat frequency.
    contrasts: ndarray
        Stimulus contrasts.
    powerf1, powerf2, powerfsum, powerfdiff: ndarray
        Response power at the first and second beat frequency, at their
        sum, and at their difference, one value per contrast.
    """
    # np.load() on an .npz archive returns an NpzFile that keeps the file
    # open; use it as a context manager so the handle is released.
    # Indexing the NpzFile reads each array into memory, so the returned
    # arrays stay valid after the file is closed.
    with np.load(file_path) as data:
        ratebase = float(data['ratebase'])
        cvbase = float(data['cvbase'])
        beatf1 = float(data['beatf1'])
        beatf2 = float(data['beatf2'])
        contrasts = data['contrasts']
        powerf1 = data['powerf1']
        powerf2 = data['powerf2']
        powerfsum = data['powerfsum']
        powerfdiff = data['powerfdiff']
    return (ratebase, cvbase, beatf1, beatf2,
            contrasts, powerf1, powerf2, powerfsum, powerfdiff)
 | |
| 
 | |
| 
 | |
def load_models(file):
    """ Load model parameters from csv file.

    The first line of the file is a comma-separated header with the
    parameter names. In every following line the first column is kept
    as a string (the cell name) and all other columns are converted to
    float. Blank lines are skipped; columns beyond the header are ignored.

    Parameters
    ----------
    file: string or Path
        Name of file with model parameters.

    Returns
    -------
    parameters: list of dict
        For each cell a dictionary with model parameters, with keys in
        header order.

    Raises
    ------
    ValueError
        If a line has fewer columns than the header, or a parameter
        value cannot be converted to float.
    """
    parameters = []
    # Use a distinct name for the file object; reusing `file` shadowed the
    # parameter holding the file name.
    with open(file, 'r') as sf:
        keys = sf.readline().strip().split(',')
        for lineno, line in enumerate(sf, start=2):
            line = line.strip()
            if not line:
                continue
            values = line.split(',')
            if len(values) < len(keys):
                raise ValueError(f'{file}: line {lineno} has {len(values)} '
                                 f'columns, expected {len(keys)}')
            parameter = {keys[0]: values[0]}
            parameter.update((key, float(value))
                             for key, value in zip(keys[1:], values[1:]))
            parameters.append(parameter)
    return parameters
 | |
| 
 | |
| 
 | |
def cell_parameters(parameters, cell_name):
    """ Return the model parameters of a specific cell.

    Parameters
    ----------
    parameters: list of dict
        Model parameters, e.g. as returned by `load_models()`. Each
        dictionary needs a 'cell' entry.
    cell_name: str
        Name of the cell to look up.

    Returns
    -------
    params: dict
        The first dictionary in `parameters` whose 'cell' entry equals
        `cell_name`.

    Raises
    ------
    SystemExit
        If the cell is not found. The message is written to stderr and
        the exit status is 1.
    """
    import sys
    for params in parameters:
        if params['cell'] == cell_name:
            return params
    # The builtin exit() is only provided by the site module for
    # interactive use and exited with status 0. sys.exit() with a message
    # is always available and reports the failure with a non-zero status.
    sys.exit(f'cell {cell_name} not found!')
 | |
| 
 | |
| 
 | |
@jit(nopython=True)
def simulate(stimulus, deltat=0.00005, v_zero=0.0, a_zero=2.0,
             threshold=1.0, v_base=0.0, delta_a=0.08, tau_a=0.1,
             v_offset=-10.0, mem_tau=0.015, noise_strength=0.05,
             input_scaling=60.0, dend_tau=0.001, ref_period=0.001):
    """ Simulate a P-unit.

    Leaky integrate-and-fire neuron with a dendritic low-pass filter,
    spike-frequency adaptation, an absolute refractory period and
    Gaussian white noise. The model is integrated with the forward Euler
    method at time step `deltat`. It is compiled by numba in nopython mode.

    The stimulus is rectified (negative values set to zero) and low-pass
    filtered with time constant `dend_tau`. The filtered stimulus, scaled
    by `input_scaling`, drives the membrane voltage together with
    `v_offset` and the noise, minus the adaptation variable. Whenever the
    voltage exceeds `threshold`, a spike time is recorded, the voltage is
    reset to `v_base`, and the adaptation is incremented by
    `delta_a / tau_a`.

    Parameters
    ----------
    stimulus: 1-D array
        Stimulus sampled with time step `deltat`. The array itself is not
        modified (rectification works on a copy).
    deltat: float
        Integration time step in seconds.
    v_zero: float
        Initial membrane voltage.
    a_zero: float
        Initial value of the adaptation variable.
    threshold: float
        Spike threshold of the membrane voltage.
    v_base: float
        Reset voltage after a spike and during the refractory period. It
        is also the voltage the leak term relaxes towards.
    delta_a: float
        Adaptation increment per spike (divided by `tau_a` when applied).
    tau_a: float
        Adaptation time constant in seconds.
    v_offset: float
        Constant input added in the membrane equation.
    mem_tau: float
        Membrane time constant in seconds.
    noise_strength: float
        Scale of the white noise.
    input_scaling: float
        Gain of the dendritic voltage onto the membrane.
    dend_tau: float
        Time constant of the dendritic low-pass filter in seconds.
    ref_period: float
        Absolute refractory period in seconds.

    Returns
    -------
    spike_times: 1-D array
        Simulated spike times in seconds, relative to the first sample of
        `stimulus`.
    """
    # initial conditions:
    # NOTE: v_dend starts from the *unrectified* first stimulus sample,
    # because rectification only happens further below.
    v_dend = stimulus[0]
    v_mem = v_zero
    adapt = a_zero

    # prepare noise:
    # Scaling by 1/sqrt(deltat) makes the noise contribution per Euler
    # step (after the multiplication by deltat below) scale with
    # sqrt(deltat).
    # NOTE(review): inside numba-compiled code np.random presumably uses
    # numba's own generator state, so seeding from plain Python may not
    # make runs reproducible. Confirm if that matters.
    noise = np.random.randn(len(stimulus))
    noise *= noise_strength / np.sqrt(deltat)

    # rectify stimulus array:
    stimulus = stimulus.copy()
    stimulus[stimulus < 0.0] = 0.0

    # integrate:
    spike_times = []
    for i in range(len(stimulus)):
        # dendritic low-pass filter of the rectified stimulus:
        v_dend += (-v_dend + stimulus[i]) / dend_tau * deltat
        # leaky integration of the membrane voltage:
        v_mem += (v_base - v_mem + v_offset + (
                    v_dend * input_scaling) - adapt + noise[i]) / mem_tau * deltat
        # exponential decay of the adaptation:
        adapt += -adapt / tau_a * deltat

        # refractory period:
        # Clamp the voltage for ref_period after the last spike. The extra
        # deltat/2 presumably guards against rounding of the time difference.
        if len(spike_times) > 0 and (deltat * i) - spike_times[-1] < ref_period + deltat/2:
            v_mem = v_base

        # threshold crossing:
        if v_mem > threshold:
            v_mem = v_base
            spike_times.append(i * deltat)
            adapt += delta_a / tau_a

    return np.array(spike_times)
 | |
| 
 | |
| 
 | |
def punit_spikes(parameter, alpha, beatf1, beatf2, tmax, trials, tini=0.2):
    """ Simulate P-unit responses to a stimulus with two beats.

    The stimulus is the sum of three sine waves:
    - the EOD at frequency `parameter['EODf']` with unit amplitude, and
    - two sine waves at EODf + beatf1 and EODf + beatf2, each with
      amplitude `alpha`.

    Parameters
    ----------
    parameter: dict
        Model parameters of a single cell, e.g. as returned by
        `cell_parameters()`. Must contain 'cell', 'EODf', 'deltat' and
        'a_zero'. All entries except 'cell' and 'EODf' are passed on to
        `simulate()`. The dictionary itself is not modified.
    alpha: float
        Amplitude of each of the two additional sine waves relative to
        the EOD.
    beatf1, beatf2: float
        Frequency differences of the two sine waves to the EOD in Hz.
    tmax: float
        Duration of the returned response in seconds.
    trials: int
        Number of trials to simulate.
    tini: float
        Initial transient in seconds. It is simulated before time zero and
        discarded from the returned spike times.

    Returns
    -------
    spikes: list of 1-D arrays
        For each trial the spike times in seconds, relative to the end of
        the initial transient.
    """
    model_params = dict(parameter)
    # 'cell' and 'EODf' are not arguments of simulate():
    model_params.pop('cell')
    eodf0 = model_params.pop('EODf')
    time = np.arange(-tini, tmax, model_params['deltat'])
    stimulus = np.sin(2*np.pi*eodf0*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf1)*time)
    stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf2)*time)
    a_zero = parameter['a_zero']
    spikes = []
    for _ in range(trials):
        # Randomize the initial conditions of each trial independently.
        # The initial adaptation is jittered by 2% around the model's
        # value. Adding the jitter to the previous trial's value instead
        # would turn it into a random walk over trials.
        model_params['v_zero'] = np.random.rand()
        model_params['a_zero'] = a_zero + 0.02*a_zero*np.random.randn()
        spiket = simulate(stimulus, **model_params)
        spikes.append(spiket[spiket > tini] - tini)
    return spikes
 | |
| 
 | |
|     
 | |
def plot_am(ax, s, alpha, beatf1, beatf2, tmax):
    """ Plot the amplitude modulation generated by two beats.

    Parameters
    ----------
    ax: axes
        Axes to plot into. Presumably a plotstyle-enhanced matplotlib axes,
        since `show_spines()` and `set_yticks_delta()` are used.
    s: namespace
        Plot styles; `s.lsAM` is used for the AM trace.
    alpha: float
        Amplitude of each of the two beats.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    tmax: float
        Duration of the plotted AM in seconds.
    """
    t = np.arange(0, tmax, 0.0001)
    am = sum(alpha*np.sin(2*np.pi*freq*t) for freq in (beatf1, beatf2))
    ax.show_spines('l')
    ax.plot(1000*t, -100*am, **s.lsAM)
    ax.set_xlim(0, 1000*tmax)
    ax.set_ylim(-13, 13)
    ax.set_yticks_delta(10)
    ax.set_ylabel('AM', r'\%')
    label = f'Contrast = {100*alpha:g}\\,\\%'
    ax.text(1, 1.2, label, transform=ax.transAxes, ha='right')
 | |
| 
 | |
|     
 | |
def plot_raster(ax, s, spikes, tmax, max_trials=16):
    """ Plot a spike raster of the first trials.

    Parameters
    ----------
    ax: axes
        Axes to plot into. Presumably plotstyle-enhanced, since
        `show_spines()` is used.
    s: namespace
        Plot styles; `s.lsRaster` is used for the spikes.
    spikes: list of 1-D arrays
        Spike times in seconds for each trial.
    tmax: float
        Only spikes before `tmax` seconds are shown.
    max_trials: int
        Maximum number of trials shown (default 16).
    """
    # The loop variable must not be called `s`: that is the style parameter.
    spikes_ms = [1000*trial[trial < tmax] for trial in spikes[:max_trials]]
    ax.show_spines('')
    ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster)
    ax.set_xlim(0, 1000*tmax)
 | |
| 
 | |
| 
 | |
def compute_power(path, contrast, spikes, nfft, dt):
    """ Trial-averaged power spectrum of spike trains, cached in a file.

    If `path` exists, the spectrum is loaded from it, and `spikes`,
    `nfft` and `dt` are ignored. Otherwise each spike train is binned
    into `nfft` bins of width `dt` and converted to a firing rate. The
    power spectra of the mean-free rates are averaged over trials and the
    result is saved to `path`.

    Parameters
    ----------
    path: Path or str
        Cache file. It should end with '.npz': np.savez() appends that
        suffix otherwise, and the cache would never be found.
    contrast: float
        Stimulus contrast, only used for the progress message.
    spikes: list of 1-D arrays
        Spike times in seconds for each trial.
    nfft: int
        Number of time bins, which is also the length of the FFT.
    dt: float
        Bin width in seconds.

    Returns
    -------
    freqs: 1-D array
        Frequencies of the spectrum in Hz.
    prr: 1-D array
        Trial-averaged power spectrum of the firing rate.

    Raises
    ------
    ValueError
        If the spectrum has to be computed but `spikes` is empty.
    """
    path = Path(path)
    if path.exists():
        print(f'    Load power spectrum for contrast = {100*contrast:4.1f}%')
        # Close the archive after reading the arrays out of it.
        with np.load(path) as data:
            return data['freqs'], data['prr']
    print(f'    Compute power spectrum for contrast = {100*contrast:4.1f}%')
    if len(spikes) == 0:
        raise ValueError('need at least one spike train to compute a power spectrum')
    bins = np.arange(nfft + 1)*dt
    psds = []
    for trial in spikes:
        counts, _ = np.histogram(trial, bins)
        rate = counts / dt
        fourier = np.fft.rfft(rate - np.mean(rate))
        psds.append(np.abs(fourier)**2)
    # Frequencies and the trial average are needed only once, after the loop.
    freqs = np.fft.rfftfreq(nfft, dt)
    prr = np.mean(psds, 0)*dt/nfft
    np.savez(path, nfft=nfft, deltat=dt, nsegs=len(spikes),
             freqs=freqs, prr=prr)
    return freqs, prr
 | |
| 
 | |
| 
 | |
def decibel(x, ref=1e8):
    """ Convert power to decibel relative to a reference power.

    Parameters
    ----------
    x: float or array
        Power values.
    ref: float
        Reference power that maps to 0 dB (default 1e8).

    Returns
    -------
    db: float or array
        `10*log10(x/ref)`. Zero power gives -inf (numpy also emits a
        RuntimeWarning).
    """
    return 10*np.log10(x/ref)
 | |
| 
 | |
| 
 | |
def plot_psd(ax, s, path, contrast, spikes, nfft, dt, beatf1, beatf2):
    """ Plot the response power spectrum and mark the beat-related peaks.

    Parameters
    ----------
    ax: axes
        Axes to plot into. Presumably plotstyle-enhanced, since
        `set_xlabel()` is called with a unit argument.
    s: namespace
        Plot styles: `lsPower` for the spectrum and `psF0`, `psF01`,
        `psF02`, `psF01_2` and `psF012` for the peak markers.
    path: Path
        Cache file of the power spectrum, see `compute_power()`.
    contrast: float
        Stimulus contrast.
    spikes: list of 1-D arrays
        Spike times in seconds for each trial.
    nfft: int
        Number of time bins for the spectrum.
    dt: float
        Bin width in seconds.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    """
    offs = 4
    freqs, psd = compute_power(path, contrast, spikes, nfft, dt)
    psd /= freqs[1]

    def marker_y(f):
        # height of a marker: peak power near f in dB, shifted up by offs
        return decibel(peak_ampl(freqs, psd, f)) + offs

    ax.plot(freqs, decibel(psd), **s.lsPower)
    # (frequency, marker height, legend label, style) of each peak marker.
    # NOTE(review): the r marker is drawn at beatf2 and the Delta f_2 marker
    # is stacked 5.5 dB above it, presumably because beatf2 was chosen to
    # match the baseline rate r. Confirm.
    markers = (
        (beatf2, marker_y(beatf2), r'$r$', s.psF0),
        (beatf1, marker_y(beatf1), r'$\Delta f_1$', s.psF01),
        (beatf2, marker_y(beatf2) + 5.5, r'$\Delta f_2$', s.psF02),
        (beatf2 - beatf1, marker_y(beatf2 - beatf1),
         r'$\Delta f_2 - \Delta f_1$', s.psF01_2),
        (beatf1 + beatf2, marker_y(beatf1 + beatf2),
         r'$\Delta f_1 + \Delta f_2$', s.psF012),
    )
    for freq, height, label, style in markers:
        ax.plot(freq, height, label=label, clip_on=False, **style)
    ax.set_xlim(0, 300)
    ax.set_ylim(-60, 0)
    ax.set_xlabel('Frequency', 'Hz')
    ax.set_ylabel('Power [dB]')
 | |
| 
 | |
| 
 | |
def plot_example(axs, axr, axp, s, path, cell, alpha, beatf1, beatf2,
                 nfft, trials):
    """ Simulate a cell for one contrast and plot AM, raster and spectrum.

    Parameters
    ----------
    axs, axr, axp: axes
        Axes for the AM, the spike raster, and the power spectrum.
    s: namespace
        Plot styles.
    path: Path
        Cache file of the power spectrum.
    cell: dict
        Model parameters of the cell.
    alpha: float
        Contrast of each of the two beats.
    beatf1, beatf2: float
        Beat frequencies in Hz.
    nfft: int
        Number of time bins for the spectrum. The simulation lasts
        nfft*0.0001 seconds.
    trials: int
        Number of simulated trials.
    """
    bin_width = 0.0001   # time bins of the power spectrum in seconds
    tshow = 0.1          # duration shown by the AM and raster plots in seconds
    spikes = punit_spikes(cell, alpha, beatf1, beatf2, nfft*bin_width, trials)
    plot_am(axs, s, alpha, beatf1, beatf2, tshow)
    plot_raster(axr, s, spikes, tshow)
    plot_psd(axp, s, path, alpha, spikes, nfft, bin_width, beatf1, beatf2)
 | |
| 
 | |
| 
 | |
def peak_ampl(freqs, psd, f, df=2):
    """ Largest power in a frequency window around `f`.

    Parameters
    ----------
    freqs: 1-D array
        Frequencies of the power spectrum.
    psd: 1-D array
        Power spectrum, same length as `freqs`.
    f: float
        Center frequency of the window.
    df: float
        Half width of the window, in the units of `freqs` (default 2).

    Returns
    -------
    peak: float
        Maximum of `psd` for frequencies strictly between `f - df` and
        `f + df`. If no frequency falls into this window (frequency
        resolution coarser than the window), the power at the frequency
        closest to `f` is returned instead of failing on an empty array.
    """
    window = (freqs > f - df) & (freqs < f + df)
    if not np.any(window):
        return psd[np.argmin(np.abs(freqs - f))]
    return np.max(psd[window])
 | |
| 
 | |
| 
 | |
def amplitude(power):
    """ Amplitude of a response component from its power.

    The first entry of `power` is taken as the baseline and subtracted
    from all entries. Negative differences are clipped to zero, and the
    square root is returned. The input array is left unchanged.

    Parameters
    ----------
    power: 1-D array-like
        Power values. The first entry is the baseline (presumably the
        zero-contrast value; confirm against the caller).

    Returns
    -------
    ampl: 1-D array
        Square root of the baseline-corrected power. Empty input gives an
        empty array.
    """
    power = np.asarray(power)
    if power.size == 0:
        return np.sqrt(power)
    # Subtract into a new array; in-place subtraction would modify the
    # caller's data.
    excess = np.maximum(power - power[0], 0)
    return np.sqrt(excess)
 | |
| 
 | |
| 
 | |
def amplitude_linearfit(contrast, power, max_contrast):
    """ Straight line fitted to the low-contrast response amplitudes.

    The amplitude is computed as the square root of `power` minus its
    first entry, with negative values clipped to zero. A line is fitted
    to the amplitudes at contrasts up to `max_contrast`. The input arrays
    are left unchanged.

    Parameters
    ----------
    contrast: 1-D array
        Stimulus contrasts.
    power: 1-D array
        Response power for each contrast; the first entry is the baseline.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: 1-D array
        The fitted line evaluated at all `contrast` values.
    """
    contrast = np.asarray(contrast)
    power = np.asarray(power)
    # Build new arrays; in-place updates would modify the caller's data.
    ampl = np.sqrt(np.maximum(power - power[0], 0))
    selected = contrast <= max_contrast
    r = linregress(contrast[selected], ampl[selected])
    return r.intercept + r.slope*contrast
 | |
| 
 | |
| 
 | |
def amplitude_squarefit(contrast, power, max_contrast):
    """ Quadratic dependence fitted to the low-contrast response amplitudes.

    The amplitude is computed as the square root of `power` minus its
    first entry, with negative values clipped to zero. A line is fitted
    to the square root of the amplitudes at contrasts up to
    `max_contrast`, and the squared line is returned. The input arrays
    are left unchanged.

    Parameters
    ----------
    contrast: 1-D array
        Stimulus contrasts.
    power: 1-D array
        Response power for each contrast; the first entry is the baseline.
    max_contrast: float
        Largest contrast included in the fit.

    Returns
    -------
    fit: 1-D array
        The squared fitted line evaluated at all `contrast` values.
    """
    contrast = np.asarray(contrast)
    power = np.asarray(power)
    # Build new arrays; in-place updates would modify the caller's data.
    ampl = np.sqrt(np.maximum(power - power[0], 0))
    selected = contrast <= max_contrast
    r = linregress(contrast[selected], np.sqrt(ampl[selected]))
    return (r.intercept + r.slope*contrast)**2
 | |
| 
 | |
| 
 | |
def plot_peaks(ax, s, alphas, contrasts, powerf1, powerf2, powerfsum,
               powerfdiff):
    """ Plot amplitudes of the response peaks as a function of contrast.

    Measured amplitudes are plotted together with their low-contrast fits:
    - linear fits at the two beat frequencies, and
    - quadratic fits at their sum and difference.
    The contrasts of the example panels are marked with triangles
    labeled A-D. The plot is divided into a linear, a weakly nonlinear
    and a strongly nonlinear regime. The caller's arrays are not modified.

    Parameters
    ----------
    ax: axes
        Axes to plot into. Presumably plotstyle-enhanced, since
        `set_xticks_delta()` is used.
    s: namespace
        Plot styles.
    alphas: list of float
        Contrasts of the example panels, as fractions.
    contrasts: 1-D array
        Stimulus contrasts as fractions; they are plotted in percent.
    powerf1, powerf2, powerfsum, powerfdiff: 1-D array
        Response power at the first and second beat frequency, at their
        sum, and at their difference, one value per contrast.
    """
    cmax = 10
    ymax = 60
    # Work on copies so the caller's arrays stay untouched. The contrasts
    # are rescaled here, and the amplitude helpers may modify the power
    # arrays passed to them in place.
    contrasts = 100*np.asarray(contrasts)
    powerf1 = np.array(powerf1)
    powerf2 = np.array(powerf2)
    powerfsum = np.array(powerfsum)
    powerfdiff = np.array(powerfdiff)
    # fits to the low-contrast amplitudes (contrasts in percent):
    ax.plot(contrasts, amplitude_linearfit(contrasts, powerf1, 4),
            **s.lsF01m)
    ax.plot(contrasts, amplitude_linearfit(contrasts, powerf2, 2),
            **s.lsF02m)
    ax.plot(contrasts, amplitude_squarefit(contrasts, powerfsum, 4),
            **s.lsF012m)
    ax.plot(contrasts, amplitude_squarefit(contrasts, powerfdiff, 4),
            **s.lsF01_2m)
    # measured amplitudes:
    ax.plot(contrasts, amplitude(powerf1), **s.lsF01)
    ax.plot(contrasts, amplitude(powerf2), **s.lsF02)
    mask = contrasts < cmax
    ax.plot(contrasts[mask], amplitude(powerfsum)[mask],
            clip_on=False, **s.lsF012)
    ax.plot(contrasts[mask], amplitude(powerfdiff)[mask],
            clip_on=False, **s.lsF01_2)
    # contrasts of the example panels:
    for alpha, tag in zip(alphas, ['A', 'B', 'C', 'D']):
        ax.plot(100*alpha, ymax*0.95, 'vk', ms=4, clip_on=False)
        ax.text(100*alpha, ymax, tag, ha='center')
    # boundaries and labels of the response regimes (contrast in percent):
    linear_max = 1.2
    weakly_max = 3.5
    ax.axvline(linear_max, **s.lsLine)
    ax.axvline(weakly_max, **s.lsLine)
    yoffs = 35
    ax.text(linear_max/2, yoffs, 'linear\nregime',
            ha='center', va='center')
    ax.text((linear_max + weakly_max)/2, yoffs, 'weakly\nnonlinear\nregime',
            ha='center', va='center')
    ax.text(5.5, yoffs, 'strongly\nnonlinear\nregime',
            ha='center', va='center')
    ax.set_xlim(0, cmax)
    ax.set_ylim(0, ymax)
    ax.set_xticks_delta(2)
    ax.set_yticks_delta(20)
    ax.set_xlabel('Contrast', r'\%')
    ax.set_ylabel('Amplitude', 'Hz')
 | |
| 
 | |
|     
 | |
if __name__ == '__main__':
    # simulated contrast dependence of the response peaks:
    (ratebase, cvbase, beatf1, beatf2, contrasts,
     powerf1, powerf2, powerfsum, powerfdiff) = load_data(
         sims_path / f'{model_cell}-contrastpeaks.npz')
    alphas = [0.002, 0.01, 0.03, 0.06]   # contrasts of the example columns

    # model parameters of the example cell:
    cell = cell_parameters(load_models(data_path / 'punitmodels.csv'),
                           model_cell)
    nfft = 2**18

    print(f'Loaded data for cell {model_cell}: '
          f'baseline rate = {ratebase:.0f}Hz, CV = {cvbase:.2f}')

    # layout: a 3x4 grid of example panels on top, contrast dependence below
    s = plot_style()
    fig, (axes, axa) = plt.subplots(2, 1, height_ratios=[4, 3],
                                    cmsize=(s.plot_width, 0.6*s.plot_width))
    fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5, hspace=0.6)
    axe = axes.subplots(3, 4, wspace=0.4, hspace=0.2,
                        height_ratios=[1, 2, 3])
    fig.show_spines('lb')

    # example AM, raster and power spectrum, one column per contrast:
    ntrials = 100
    for col, alpha in enumerate(alphas):
        spectrum_file = sims_path / f'{model_cell}-contrastspectrum-{1000*alpha:03.0f}.npz'
        plot_example(axe[0, col], axe[1, col], axe[2, col], s, spectrum_file,
                     cell, alpha, beatf1, beatf2, nfft, ntrials)
    axe[1, 0].xscalebar(1, -0.1, 20, 'ms', ha='right')
    axe[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8),
                     ncol=5, columnspacing=2)
    fig.common_yspines(axe[0, :])
    fig.common_yticks(axe[2, :])
    fig.tag(axe[0, :], xoffs=-3, yoffs=1.6)

    # contrast dependence:
    plot_peaks(axa, s, alphas, contrasts, powerf1, powerf2,
               powerfsum, powerfdiff)
    fig.tag(axa, yoffs=2)
    fig.savefig()
    print()
 |