function [ stavg, stavgtime, spikesnippets, stimsnippets, meanrate ] = sta( stimulus, spikes, left, right )
% computes the spike-triggered average (STA); when called without output
% arguments it additionally plots projections onto the STA and onto a random
% direction orthogonal to it, the resulting nonlinearities, and an STA-based
% stimulus reconstruction
% stimulus: the stimulus as a nx2 matrix with the first column being time
%           in seconds (uniformly sampled, row i at stimulus(1,1)+(i-1)*dt)
%           and the second column being the actual stimulus
% spikes: a cell array of vectors of spike times (one cell per trial), on
%         the same time axis as stimulus(:,1)
% left: the time to the left of each spike
% right: the time to the right of each spike
% returns
% stavg: the spike-triggered average (1 x nw, nw = number of samples in the
%        window [-left, right])
% stavgtime: the corresponding time axis (1 x nw)
% spikesnippets: the spike-triggered waveforms as a nspikes x nw matrix;
%                spikes whose window does not fit into the stimulus are
%                skipped
% stimsnippets: every stimulus window of nw consecutive samples as a
%               (n-nw+1) x nw matrix (row j starts at sample j)
% meanrate: the mean firing rate, i.e. the number of spikes within the
%           stimulus time range per trial divided by the stimulus duration
% raises error 'sta:nospikes' if no spike window fits into the stimulus

    nsamples = size( stimulus, 1 );
    t0 = stimulus(1,1);
    dt = stimulus(2,1) - stimulus(1,1);
    ntrials = length( spikes );

    % window in samples to the left and right of each spike:
    wl = round( left/dt );
    wr = round( right/dt );
    nw = wl+wr+1;
    offsets = -wl:wr;

    % sample index of every spike, all trials pooled:
    [ spikeinx, spiketimes ] = spikeindices( spikes, t0, dt );

    % keep only spikes whose whole window lies inside the stimulus:
    snipinx = spikeinx( spikeinx-wl >= 1 & spikeinx+wr <= nsamples );
    nspikes = numel( snipinx );
    if nspikes == 0
        error( 'sta:nospikes', 'no spike window [-left, right] fits into the stimulus' );
    end

    % spike-triggered waveforms, one row per spike:
    spikesnippets = zeros( nspikes, nw );
    for k = 1:nspikes
        spikesnippets( k, : ) = stimulus( snipinx(k)+offsets, 2 );
    end
    % build the time axis from the sample offsets so it always has nw elements:
    stavgtime = offsets*dt;
    % count every spike inside the stimulus, not only those with a full window:
    tend = stimulus(end,1);
    meanrate = sum( spiketimes >= t0 & spiketimes <= tend )/ntrials/(tend-t0);

    % spike-triggered average:
    stavg = mean( spikesnippets, 1 );

    % all stimulus windows, filled column by column: element (j,c) is
    % stimulus sample j+c-1; stimcentertime(j) is the time of the sample at
    % the spike position (offset 0) of window j:
    nstim = max( nsamples-nw+1, 0 );
    stimsnippets = zeros( nstim, nw );
    for c = 1:nw
        stimsnippets( :, c ) = stimulus( c:c+nstim-1, 2 );
    end
    stimcentertime = stimulus( wl+(1:nstim), 1 );

    % projection onto sta:
    spikeonsta = spikesnippets * stavg';
    stimonsta = stimsnippets * stavg';

    % projection onto a random direction orthogonal to the sta, mixed from
    % the first ceil(n/10)-1 (at least one) null-space basis vectors:
    orthos = null( stavg );
    mixcoef = rand( size( orthos, 2 ), 1 );
    nkeep = max( 1, ceil( numel( mixcoef )/10 ) - 1 );
    mixcoef( nkeep+1:end ) = 0.0;
    staortho = orthos*mixcoef;
    % normalize to unit length (stavg*staortho is ~0 by construction):
    orthonorm = norm( staortho );
    if orthonorm > 0
        staortho = staortho/orthonorm;
    end
    spikeonstaortho = spikesnippets * staortho;
    stimonstaortho = stimsnippets * staortho;

    if nargout == 0
        % sta plot:
        figure( 1 );
        subplot( 3, 2, 1 )
        plot( 1000.0*stavgtime, staortho, '-r', 'LineWidth', 3 )
        hold on;
        plot( 1000.0*stavgtime, stavg, '-b', 'LineWidth', 3 )
        hold off;
        xlabel( 'time [ms]' );
        ylabel( 'stimulus' );
        title( 'Spike-triggered average' );
        legend( 'ortho', 'STA' );

        % 2D scatter of projections:
        subplot( 3, 2, 2 )
        scatter( stimonsta, stimonstaortho, 40.0, 'b', 'filled', 'MarkerEdgeColor', 'white' )
        hold on;
        scatter( spikeonsta, spikeonstaortho, 20.0, 'r', 'filled', 'MarkerEdgeColor', 'white', 'LineWidth', 1.0 )
        hold off;
        xlabel( 'projection sta' );
        ylabel( 'projection ortho' );
        title( 'Projections' );

        % histogram of projections onto sta (projections are bin centers):
        subplot( 3, 2, 3 )
        [ n, projections ] = hist( stimonsta, 50 );
        bw = projections(2)-projections(1);
        pstim = n / sum(n)/bw;
        bar( projections, pstim, 'b' );
        hold on;
        [ n, ~ ] = hist( spikeonsta, projections );
        pstimspike = n / sum(n)/bw;
        bar( projections, pstimspike, 'r' );
        xlabel( 'projection x' );
        title( 'Projection onto STA' );
        xlim( [projections(1) projections(end)] );
        hold off;
        legend( 'p(x)', 'p(x|spikes)' );

        % nonlinearity for projections onto sta:
        subplot( 3, 2, 5 )
        nonlinearity = meanrate*pstimspike./pstim;
        plot( projections(pstim>0.01), nonlinearity(pstim>0.01), '-r', 'LineWidth', 3 );
        xlim( [projections(1) projections(end)] );
        ylim( [0 1000 ])
        xlabel( 'projection x' );
        ylabel( 'meanrate*p(x|spikes)/p(x) [Hz]' );
        title( 'Nonlinearity');

        % histogram of projections onto orthogonal sta:
        subplot( 3, 2, 4 )
        [ n, ~] = hist( stimonstaortho, projections );
        pstimortho = n / sum(n)/bw;
        bar( projections, pstimortho, 'b' );
        hold on;
        [ n, ~ ] = hist( spikeonstaortho, projections );
        pstimspikeortho = n / sum(n)/bw;
        bar( projections, pstimspikeortho, 'r' );
        xlim( [projections(1) projections(end)] );
        xlabel( 'projection x' );
        title( 'Projection onto orthogonal to STA' );
        hold off;
        legend( 'p(x)', 'p(x|spikes)' );

        % nonlinearity for orthogonal projections:
        subplot( 3, 2, 6 )
        nonlinearityortho = meanrate*pstimspikeortho./pstimortho;
        plot( projections(pstimortho>0.01), nonlinearityortho(pstimortho>0.01), '-r', 'LineWidth', 3 );
        xlim( [projections(1) projections(end)] );
        ylim( [0 1000 ])
        xlabel( 'projection x' );
        ylabel( 'meanrate*p(x|spikes)/p(x) [Hz]' );
        title( 'Nonlinearity');

        % kernel psth on the stimulus time axis: each spike adds a sampled
        % Gaussian (sigma = kernelsigma) scaled by 1/ntrials; spikes whose
        % kernel would extend beyond the stimulus are skipped. The kernel
        % has exactly 2*w2+1 samples so it always matches its target span.
        ratetime = stimulus( :, 1 )';
        kernelsigma = 0.001;
        w2 = round( 5.0*kernelsigma/dt );
        windowtime = (-w2:w2)*dt;
        window = exp( -0.5*(windowtime/kernelsigma).^2 )/(sqrt(2.0*pi)*kernelsigma)/ntrials;
        kernelrate = zeros( size( ratetime ) );
        rateinx = spikeinx( spikeinx-w2 >= 1 & spikeinx+w2 <= nsamples );
        for k = 1:numel( rateinx )
            span = rateinx(k)-w2:rateinx(k)+w2;
            kernelrate( span ) = kernelrate( span ) + window;
        end

        % stimulus reconstruction: a copy of the STA at every spike,
        % averaged over trials:
        stareconstruction = zeros( nsamples, 1 );
        for k = 1:nspikes
            span = snipinx(k)+offsets;
            stareconstruction( span ) = stareconstruction( span ) + stavg';
        end
        stareconstruction = stareconstruction/ntrials;

        % linear stimulus reconstruction:
        figure( 2 )
        ax1 = subplot( 3, 1, 1 );
        plot( 1000.0*stimulus(:,1), stimulus(:,2), 'g', 'LineWidth', 2 );
        hold on;
        plot( 1000.0*stimulus(:,1), stareconstruction, 'r', 'LineWidth', 2 );
        ylabel( 'stimulus' );
        hold off;
        legend( 'stimulus', 'reconstruction' );

        % stimulus projection onto sta:
        ax2 = subplot( 3, 1, 2 );
        plot( 1000.0*ratetime, kernelrate-mean(kernelrate), '-b', 'LineWidth', 2 )
        hold on;
        plot( 1000.0*stimcentertime, 2.0*stimonsta*meanrate, 'r', 'LineWidth', 2 );
        ylabel( 'rate [Hz]' );
        hold off;
        legend( 'rate', 'stimulus projection on STA' );

        % sta with nonlinearity (LNP prediction): look up the nonlinearity
        % at each window's projection, clipped to the well-sampled range
        % pstim > 0.01 so that empty (NaN) bins are never used:
        ax3 = subplot( 3, 1, 3 );
        plot( 1000.0*ratetime, kernelrate, '-b', 'LineWidth', 2 )
        hold on;
        goodproj = projections( pstim > 0.01 );
        if isempty( goodproj )
            goodproj = projections;
        end
        stimonstax = min( max( stimonsta, min( goodproj ) ), max( goodproj ) );
        % bin k is centered at projections(1)+(k-1)*bw:
        sinx = round( ( stimonstax-projections(1) )/bw ) + 1;
        sinx = min( max( sinx, 1 ), numel( nonlinearity ) );
        stastimulus = nonlinearity( sinx );
        plot( 1000.0*stimcentertime, stastimulus, 'r', 'LineWidth', 2 );
        legend( 'rate', 'LNP' );
        xlabel( 'time [ms]' );
        ylabel( 'rate [Hz]' );
        hold off;
        linkaxes( [ax1 ax2 ax3], 'x' );
    end

end

function [ inx, times ] = spikeindices( spikes, t0, dt )
% pools the spike times of all trials into one column vector and converts
% them to 1-based sample indices of a signal that starts at time t0 and is
% sampled with interval dt (sample i has time t0+(i-1)*dt)
    times = cellfun( @(t) t(:), spikes(:), 'UniformOutput', false );
    times = vertcat( zeros( 0, 1 ), times{:} );
    inx = round( ( times - t0 )/dt ) + 1;
end