function [ stavg, stavgtime, spikesnippets, stimsnippets, meanrate ] = sta( stimulus, spikes, left, right )
% STA computes the spike-triggered average of a stimulus.
%
%   [stavg, stavgtime, spikesnippets, stimsnippets, meanrate] = sta(stimulus, spikes, left, right)
%
% Arguments:
%   stimulus: the stimulus as an n x 2 matrix with the first column being
%             time in seconds (sampling interval taken from the first two
%             rows) and the second column being the actual stimulus
%   spikes:   a cell array of vectors of spike times in seconds, one per trial
%   left:     the time to the left of each spike in seconds
%   right:    the time to the right of each spike in seconds
%
% Returns:
%   stavg:         the spike-triggered average, a 1 x nw row vector with
%                  nw = round(left/dt) + round(right/dt) + 1
%   stavgtime:     the corresponding time axis in seconds (1 x nw, zero at
%                  the spike)
%   spikesnippets: the stimulus windows around each spike as a nspikes x nw
%                  matrix; spikes whose window does not fit into the
%                  stimulus are skipped
%   stimsnippets:  all stimulus windows of length nw, sliding sample by
%                  sample over the stimulus ((n-nw+1) x nw)
%   meanrate:      the mean firing rate in Hz: number of used spikes per
%                  trial divided by stimulus(end,1)-stimulus(1,1)
%
% Without output arguments, figure 1 shows the STA, the projections of the
% stimulus onto the STA and onto a random direction orthogonal to it and
% the resulting nonlinearities; figure 2 shows a linear stimulus
% reconstruction, the rate and an LNP rate prediction.

    % time indices:
    dt = stimulus(2,1) - stimulus(1,1);
    wl = round( left/dt );
    wr = round( right/dt );
    nw = wl + wr + 1;
    nsamples = size( stimulus, 1 );

    % stimulus indices of all spikes whose window fits into the stimulus:
    spikeinx = spike_indices( spikes, stimulus(1,1), dt, wl, wr, nsamples );

    % spike-triggered stimulus windows:
    spikesnippets = extract_snippets( stimulus(:,2), spikeinx - wl, nw );
    % built from the sample offsets so that it always has nw entries
    % (a colon range -left:dt:right may have a different length):
    stavgtime = (-wl:wr)*dt;
    meanrate = numel( spikeinx )/numel( spikes )/(stimulus(end,1) - stimulus(1,1));

    % spike-triggered average:
    stavg = mean( spikesnippets, 1 );

    % all stimulus windows:
    nstim = nsamples - nw + 1;
    stimsnippets = extract_snippets( stimulus(:,2), 1:nstim, nw );

    if nargout == 0
        % projection onto sta:
        spikeonsta = spikesnippets * stavg';
        stimonsta = stimsnippets * stavg';
        % projection onto a random direction orthogonal to the sta:
        staortho = random_orthogonal( stavg );
        spikeonstaortho = spikesnippets * staortho;
        stimonstaortho = stimsnippets * staortho;

        % sta plot:
        figure( 1 );
        subplot( 3, 2, 1 )
        plot( 1000.0*stavgtime, staortho, '-r', 'LineWidth', 3 )
        hold on;
        plot( 1000.0*stavgtime, stavg, '-b', 'LineWidth', 3 )
        hold off;
        xlabel( 'time [ms]' );
        ylabel( 'stimulus' );
        title( 'Spike-triggered average' );
        legend( 'ortho', 'STA' );

        % 2D scatter of projections:
        subplot( 3, 2, 2 )
        scatter( stimonsta, stimonstaortho, 40.0, 'b', 'filled', 'MarkerEdgeColor', 'white' )
        hold on;
        scatter( spikeonsta, spikeonstaortho, 20.0, 'r', 'filled', 'MarkerEdgeColor', 'white', 'LineWidth', 1.0 )
        hold off;
        xlabel( 'projection sta' );
        ylabel( 'projection ortho' );
        title( 'Projections' );

        % histogram of projections onto sta:
        subplot( 3, 2, 3 )
        [ n, projections ] = hist( stimonsta, 50 );   % projections are bin centers
        bw = projections(2) - projections(1);
        pstim = n/sum( n )/bw;
        bar( projections, pstim, 'b' );
        hold on;
        n = hist( spikeonsta, projections );
        pstimspike = n/sum( n )/bw;
        bar( projections, pstimspike, 'r' );
        xlabel( 'projection x' );
        title( 'Projection onto STA' );
        xlim( [ projections(1) projections(end) ] );
        hold off;
        legend( 'p(x)', 'p(x|spikes)' );

        % nonlinearity for projections onto the sta:
        subplot( 3, 2, 5 )
        nonlinearity = meanrate*pstimspike./pstim;
        plot( projections(pstim > 0.01), nonlinearity(pstim > 0.01), '-r', 'LineWidth', 3 );
        xlim( [ projections(1) projections(end) ] );
        ylim( [ 0 1000 ] )
        xlabel( 'projection x' );
        ylabel( 'meanrate*p(x|spikes)/p(x) [Hz]' );
        title( 'Nonlinearity' );

        % histogram of projections onto orthogonal sta (same bins as for the sta):
        subplot( 3, 2, 4 )
        n = hist( stimonstaortho, projections );
        pstimortho = n/sum( n )/bw;
        bar( projections, pstimortho, 'b' );
        hold on;
        n = hist( spikeonstaortho, projections );
        pstimspikeortho = n/sum( n )/bw;
        bar( projections, pstimspikeortho, 'r' );
        xlim( [ projections(1) projections(end) ] );
        xlabel( 'projection x' );
        title( 'Projection onto orthogonal to STA' );
        hold off;
        legend( 'p(x)', 'p(x|spikes)' );

        % nonlinearity for orthogonal projections:
        subplot( 3, 2, 6 )
        nonlinearityortho = meanrate*pstimspikeortho./pstimortho;
        plot( projections(pstimortho > 0.01), nonlinearityortho(pstimortho > 0.01), '-r', 'LineWidth', 3 );
        xlim( [ projections(1) projections(end) ] );
        ylim( [ 0 1000 ] )
        xlabel( 'projection x' );
        ylabel( 'meanrate*p(x|spikes)/p(x) [Hz]' );
        title( 'Nonlinearity' );

        % stimulus reconstruction with sta: one sta added at each used spike,
        % averaged over trials:
        stareconstruction = zeros( nsamples, 1 );
        for k = 1:numel( spikeinx )
            w = spikeinx(k)-wl:spikeinx(k)+wr;
            stareconstruction( w ) = stareconstruction( w ) + stavg';
        end
        stareconstruction = stareconstruction/numel( spikes );

        % kernel psth:
        kernelsigma = 0.001;
        [ kernelrate, ratetime ] = kernel_rate( spikes, dt, stimulus(end,1), kernelsigma );

        % time of the spike whose window is the respective stimulus snippet:
        snippettime = stimulus( wl + (1:nstim), 1 );

        % linear stimulus reconstruction:
        figure( 2 )
        ax1 = subplot( 3, 1, 1 );
        plot( 1000.0*stimulus(:,1), stimulus(:,2), 'g', 'LineWidth', 2 );
        hold on;
        plot( 1000.0*stimulus(:,1), stareconstruction, 'r', 'LineWidth', 2 );
        ylabel( 'stimulus' );
        hold off;
        legend( 'stimulus', 'reconstruction' );

        % stimulus projection onto sta:
        ax2 = subplot( 3, 1, 2 );
        plot( 1000.0*ratetime, kernelrate - mean( kernelrate ), '-b', 'LineWidth', 2 )
        hold on;
        plot( 1000.0*snippettime, 2.0*stimonsta*meanrate, 'r', 'LineWidth', 2 );
        ylabel( 'rate [Hz]' );
        hold off;
        legend( 'rate', 'stimulus projection on STA' );

        % sta with nonlinearity (LNP prediction):
        ax3 = subplot( 3, 1, 3 );
        plot( 1000.0*ratetime, kernelrate, '-b', 'LineWidth', 2 )
        hold on;
        % clip to the range where the nonlinearity is estimated reliably:
        xmax = max( projections(pstim > 0.01) );
        stimonstax = min( stimonsta, xmax );
        % bin index of each projection (projections holds bin centers):
        sinx = round( (stimonstax - projections(1))/bw ) + 1;
        sinx = min( max( sinx, 1 ), numel( projections ) );
        stastimulus = nonlinearity( sinx );
        plot( 1000.0*snippettime, stastimulus, 'r', 'LineWidth', 2 );
        legend( 'rate', 'LNP' );
        xlabel( 'time [ms]' );
        ylabel( 'rate [Hz]' );
        hold off;
        linkaxes( [ ax1 ax2 ax3 ], 'x' );
    end
end


function inx = spike_indices( spikes, t0, dt, wl, wr, nsamples )
% returns, trial by trial, the sample indices nearest to the spike times
% (sample i at time t0+(i-1)*dt) of all spikes whose window
% inx-wl:inx+wr lies within 1:nsamples
    inx = zeros( 1, 0 );
    for k = 1:numel( spikes )
        trialinx = round( (spikes{k}(:)' - t0)/dt ) + 1;
        inside = ( trialinx - wl >= 1 ) & ( trialinx + wr <= nsamples );
        inx = [ inx, trialinx(inside) ];
    end
end


function snippets = extract_snippets( signal, starts, nw )
% returns one row signal(s:s+nw-1) for each start index s in starts
    snippets = zeros( numel( starts ), nw );
    for k = 1:numel( starts )
        snippets( k, : ) = signal( starts(k):starts(k)+nw-1 );
    end
end


function staortho = random_orthogonal( stavg )
% returns a random column vector orthogonal to the row vector stavg,
% scaled to the norm of stavg
    orthos = null( stavg );     % orthonormal basis of vectors orthogonal to stavg
    mixcoef = rand( size( orthos, 2 ), 1 );
    % keep only the first basis vectors (about a tenth, at least one);
    % the cut index must be an integer >= 2:
    icut = max( 2, round( numel( mixcoef )/10 ) );
    mixcoef( icut:end ) = 0.0;
    staortho = orthos*mixcoef;
    % same norm as the sta, so projections onto both share one bin range:
    staortho = staortho*(norm( stavg )/norm( staortho ));
end


function [ kernelrate, ratetime ] = kernel_rate( spikes, dt, tmax, sigma )
% returns the trial-averaged rate on ratetime = 0:dt:tmax obtained by
% adding a Gaussian kernel with standard deviation sigma at each spike
    ratetime = 0:dt:tmax;
    nk = ceil( 5.0*sigma/dt );   % kernel half width in samples, odd window length
    windowtime = (-nk:nk)*dt;
    window = exp( -0.5*(windowtime/sigma).^2 )/(sigma*sqrt( 2.0*pi ))/numel( spikes );
    kernelrate = zeros( size( ratetime ) );
    for k = 1:numel( spikes )
        % ratetime(i) = (i-1)*dt, so the nearest index is round(t/dt)+1:
        inx = round( spikes{k}(:)'/dt ) + 1;
        inx = inx( (inx - nk >= 1) & (inx + nk <= numel( kernelrate )) );
        for i = inx
            kernelrate( i-nk:i+nk ) = kernelrate( i-nk:i+nk ) + window;
        end
    end
end