This repository has been archived on 2021-05-17. You can view files and clone it, but cannot push or open issues or pull requests.
scientificComputing/linearalgebra/code/sta.m
2014-11-12 18:39:02 +01:00

228 lines
7.9 KiB
Matlab

function [ stavg, stavgtime, spikesnippets, stimsnippets, meanrate ] = sta( stimulus, spikes, left, right )
% STA computes the spike-triggered average of a stimulus.
%
% [ stavg, stavgtime, spikesnippets, stimsnippets, meanrate ] = sta( stimulus, spikes, left, right )
%
% stimulus: the stimulus as a n x 2 matrix with the first column being time
%           in seconds and the second column being the actual stimulus.
%           The sampling interval is taken from the first two time stamps,
%           so the time axis is assumed to be evenly sampled.
% spikes:   a cell array of vectors of spike times (seconds), one per trial
% left:     the time to the left of each spike (seconds)
% right:    the time to the right of each spike (seconds)
%
% returns
% stavg:         the spike-triggered average as a 1 x nw row vector with
%                nw = round(left/dt) + round(right/dt) + 1
% stavgtime:     the corresponding time axis relative to the spike (1 x nw)
% spikesnippets: the stimulus windows around the spikes as a nspikes x nw
%                matrix; spikes whose window does not fit completely into
%                the stimulus are skipped
% stimsnippets:  every window of nw consecutive stimulus samples as a
%                nstim x nw matrix with nstim = n - nw + 1
% meanrate:      the mean firing rate: all spikes per trial and per second
%                of stimulus duration
%
% When called without output arguments, the analysis is also plotted
% (figures 1 and 2).

    % sampling interval and window sizes in samples:
    t0 = stimulus(1,1);
    dt = stimulus(2,1) - t0;
    wl = round( left/dt );
    wr = round( right/dt );
    nw = wl + wr + 1;
    nsamples = size( stimulus, 1 );

    % stimulus snippets around each spike, collected over all trials:
    ntotal = sum( cellfun( @numel, spikes ) );
    spikesnippets = zeros( ntotal, nw );
    nspikes = 0;
    for k = 1:length( spikes )
        inx = spikeindices( spikes{k}, t0, dt, wl, wr, nsamples );
        for j = 1:length( inx )
            nspikes = nspikes + 1;
            spikesnippets( nspikes, : ) = stimulus( inx(j)-wl:inx(j)+wr, 2 );
        end
    end
    % delete not used snippets (spikes too close to the stimulus edges):
    spikesnippets( nspikes+1:end, : ) = [];

    % time axis built from the sample offsets, so that it always has exactly
    % nw elements (a colon range -left:dt:right may have one more or less):
    stavgtime = ( -wl:wr )*dt;
    % mean rate from all spikes, including those skipped at the edges:
    meanrate = ntotal/length( spikes )/( stimulus(end,1) - t0 );
    % spike-triggered average:
    stavg = mean( spikesnippets, 1 );

    % all windows of nw consecutive stimulus samples:
    nstim = nsamples - nw + 1;
    stimsnippets = zeros( nstim, nw );
    for j = 1:nstim
        stimsnippets( j, : ) = stimulus( j:j+nw-1, 2 );
    end

    if nargout == 0
        plotsta( stimulus, spikes, stavg, stavgtime, spikesnippets, ...
                 stimsnippets, meanrate, wl, wr );
    end
end

function plotsta( stimulus, spikes, stavg, stavgtime, spikesnippets, stimsnippets, meanrate, wl, wr )
% PLOTSTA plots the results of sta().
% Figure 1: the STA together with a random direction orthogonal to it, the
% projections of all stimulus windows (blue) and of the spike-triggered
% windows (red) onto both directions, and the nonlinearities
% meanrate*p(x|spikes)/p(x) estimated from these projections.
% Figure 2: the stimulus with its STA reconstruction, and a kernel-smoothed
% firing rate compared with the linear and the LNP prediction.
    t0 = stimulus(1,1);
    dt = stimulus(2,1) - t0;
    nsamples = size( stimulus, 1 );
    ntrials = length( spikes );

    % projections onto the sta:
    spikeonsta = spikesnippets * stavg';
    stimonsta = stimsnippets * stavg';
    % a random direction orthogonal to the sta, mixed from the first few
    % (about a tenth, but at least one) orthonormal basis vectors of the
    % null space of the sta; floor() keeps the index a valid integer:
    orthos = null( stavg );
    mixcoef = rand( size( orthos, 2 ), 1 );
    mixcoef( max( 2, floor( length( mixcoef )/10 ) ):end ) = 0.0;
    staortho = orthos*mixcoef;
    % same length as the sta, so that projections onto both directions are
    % on the same scale and can share the same histogram bins below:
    staortho = staortho*norm( stavg )/norm( staortho );
    spikeonstaortho = spikesnippets * staortho;
    stimonstaortho = stimsnippets * staortho;

    % sta plot:
    figure( 1 );
    subplot( 3, 2, 1 )
    plot( 1000.0*stavgtime, staortho, '-r', 'LineWidth', 3 )
    hold on;
    plot( 1000.0*stavgtime, stavg, '-b', 'LineWidth', 3 )
    hold off;
    xlabel( 'time [ms]' );
    ylabel( 'stimulus' );
    title( 'Spike-triggered average' );
    legend( 'ortho', 'STA' );
    % 2D scatter of projections:
    subplot( 3, 2, 2 )
    scatter( stimonsta, stimonstaortho, 40.0, 'b', 'filled', 'MarkerEdgeColor', 'white' )
    hold on;
    scatter( spikeonsta, spikeonstaortho, 20.0, 'r', 'filled', 'MarkerEdgeColor', 'white', 'LineWidth', 1.0 )
    hold off;
    xlabel( 'projection sta' );
    ylabel( 'projection ortho' );
    title( 'Projections' );
    % histogram of projections onto sta (projections are bin centers):
    subplot( 3, 2, 3 )
    [ n, projections ] = hist( stimonsta, 50 );
    bw = projections(2) - projections(1);
    pstim = n/sum( n )/bw;
    bar( projections, pstim, 'b' );
    hold on;
    n = hist( spikeonsta, projections );
    pstimspike = n/sum( n )/bw;
    bar( projections, pstimspike, 'r' );
    xlabel( 'projection x' );
    title( 'Projection onto STA' );
    xlim( [projections(1) projections(end)] );
    hold off;
    legend( 'p(x)', 'p(x|spikes)' );
    % nonlinearity for projections onto the sta:
    subplot( 3, 2, 5 )
    nonlinearity = meanrate*pstimspike./pstim;
    plot( projections(pstim>0.01), nonlinearity(pstim>0.01), '-r', 'LineWidth', 3 );
    xlim( [projections(1) projections(end)] );
    ylim( [0 1000 ])
    xlabel( 'projection x' );
    ylabel( 'meanrate*p(x|spikes)/p(x) [Hz]' );
    title( 'Nonlinearity');
    % histogram of projections onto orthogonal sta:
    subplot( 3, 2, 4 )
    n = hist( stimonstaortho, projections );
    pstimortho = n/sum( n )/bw;
    bar( projections, pstimortho, 'b' );
    hold on;
    n = hist( spikeonstaortho, projections );
    pstimspikeortho = n/sum( n )/bw;
    bar( projections, pstimspikeortho, 'r' );
    xlim( [projections(1) projections(end)] );
    xlabel( 'projection x' );
    title( 'Projection onto orthogonal to STA' );
    hold off;
    legend( 'p(x)', 'p(x|spikes)' );
    % nonlinearity for orthogonal projections:
    subplot( 3, 2, 6 )
    nonlinearityortho = meanrate*pstimspikeortho./pstimortho;
    plot( projections(pstimortho>0.01), nonlinearityortho(pstimortho>0.01), '-r', 'LineWidth', 3 );
    xlim( [projections(1) projections(end)] );
    ylim( [0 1000 ])
    xlabel( 'projection x' );
    ylabel( 'meanrate*p(x|spikes)/p(x) [Hz]' );
    title( 'Nonlinearity');

    % firing rate: Gaussian-kernel PSTH averaged over trials, on a time axis
    % starting at 0:
    ratetime = 0:dt:stimulus(end,1);
    kernelsigma = 0.001;
    % symmetric kernel of exactly 2*w2+1 samples, so it always matches the
    % index range inx-w2:inx+w2; the Gaussian density is written out to
    % avoid depending on the Statistics Toolbox normpdf():
    w2 = round( 5.0*kernelsigma/dt );
    windowtime = ( -w2:w2 )*dt;
    window = exp( -0.5*( windowtime/kernelsigma ).^2 )/( sqrt( 2.0*pi )*kernelsigma )/ntrials;
    kernelrate = zeros( size( ratetime ) );
    for k = 1:ntrials
        inx = spikeindices( spikes{k}, 0.0, dt, w2, w2, length( kernelrate ) );
        for j = 1:length( inx )
            kernelrate( inx(j)-w2:inx(j)+w2 ) = kernelrate( inx(j)-w2:inx(j)+w2 ) + window;
        end
    end

    % stimulus reconstruction: the sta added at every spike, averaged over trials:
    stareconstruction = zeros( nsamples, 1 );
    for k = 1:ntrials
        inx = spikeindices( spikes{k}, t0, dt, wl, wr, nsamples );
        for j = 1:length( inx )
            stareconstruction( inx(j)-wl:inx(j)+wr ) = stareconstruction( inx(j)-wl:inx(j)+wr ) + stavg';
        end
    end
    stareconstruction = stareconstruction/ntrials;

    % time of the sample at the spike position (offset wl) of each window:
    stimsnippetstime = stimulus( 1:size( stimsnippets, 1 ), 1 ) + wl*dt;

    % linear stimulus reconstruction:
    figure( 2 )
    ax1 = subplot( 3, 1, 1 );
    plot( 1000.0*stimulus(:,1), stimulus(:,2), 'g', 'LineWidth', 2 );
    hold on;
    plot( 1000.0*stimulus(:,1), stareconstruction, 'r', 'LineWidth', 2 );
    ylabel( 'stimulus' );
    hold off;
    legend( 'stimulus', 'reconstruction' );
    % firing rate and stimulus projection onto sta:
    ax2 = subplot( 3, 1, 2 );
    plot( 1000.0*ratetime, kernelrate-mean(kernelrate), '-b', 'LineWidth', 2 )
    hold on;
    plot( 1000.0*stimsnippetstime, 2.0*stimonsta*meanrate, 'r', 'LineWidth', 2 );
    ylabel( 'rate [Hz]' );
    hold off;
    legend( 'rate', 'stimulus projection on STA' );
    % sta with nonlinearity (LNP): look up the nonlinearity of each
    % projection, clipped to the range plotted in figure 1:
    ax3 = subplot( 3, 1, 3 );
    plot( 1000.0*ratetime, kernelrate, '-b', 'LineWidth', 2 )
    hold on;
    plotted = pstim > 0.01;
    stimonstax = min( max( stimonsta, min( projections(plotted) ) ), max( projections(plotted) ) );
    % bin k is centered at projections(1) + (k-1)*bw:
    sinx = round( ( stimonstax - projections(1) )/bw ) + 1;
    sinx = min( max( sinx, 1 ), length( nonlinearity ) );
    stastimulus = nonlinearity( sinx );
    plot( 1000.0*stimsnippetstime, stastimulus, 'r', 'LineWidth', 2 );
    legend( 'rate', 'LNP' );
    xlabel( 'time [ms]' );
    ylabel( 'rate [Hz]' );
    hold off;
    linkaxes( [ax1 ax2 ax3], 'x' );
end

function inx = spikeindices( times, t0, dt, before, after, nmax )
% SPIKEINDICES converts spike times into 1-based sample indices of a time
% axis that starts at t0 and is sampled every dt, keeping only the indices
% whose window inx-before:inx+after lies completely within 1:nmax.
    inx = round( ( times(:)' - t0 )/dt ) + 1;
    inx = inx( inx - before >= 1 & inx + after <= nmax );
end