function [sta, std_sta, n_spikes] = spikeTriggeredAverage(stimulus, spikes, count, deltat)
% Estimate the spike-triggered-average (STA).
%
% [sta, std_sta, n_spikes] = spikeTriggeredAverage(stimulus, spikes, count, deltat)
%
% For every spike time t the stimulus sample index = round(t/deltat) is
% computed and the snippet stimulus(index-count:index+count-1) is cut out,
% i.e. count samples before the spike sample and count samples starting at
% it (2*count samples in total). Spikes whose snippet would reach outside
% the stimulus are skipped. The snippets are then averaged sample-wise.
% NOTE(review): index = round(t/deltat) places sample k at time k*deltat;
% confirm this matches how the stimulus was sampled.
%
% Arguments:
%     stimulus: vector of stimulus intensities as a function of time.
%     spikes  : vector with spike times in seconds.
%     count   : number of datapoints taken before and after each spike;
%               each snippet contains 2*count datapoints.
%     deltat  : the time step of the stimulus in seconds.
%
% Returns: 
%     sta     : row vector (1 x 2*count) with the STA.
%     std_sta : standard deviation of the snippets (same size as sta).
%     n_spikes: number of spikes contained in STA.

    n_samples = numel(stimulus);
    % Preallocate one row per spike; unused rows are removed afterwards.
    snippets = zeros(numel(spikes), 2*count);
    n_spikes = 0;
    for i = 1:numel(spikes)
        index = round(spikes(i)/deltat);
        % The snippet spans index-count .. index+count-1: it must start at
        % sample 1 or later and end at the last sample or earlier.
        if index - count < 1 || index + count - 1 > n_samples
            continue
        end
        % Increment first: MATLAB rows are 1-based, so the first accepted
        % snippet has to go into row 1 (row 0 would be an index error).
        n_spikes = n_spikes + 1;
        snippets(n_spikes, :) = stimulus(index-count:index+count-1);
    end
    % Drop the preallocated rows belonging to skipped spikes.
    snippets(n_spikes+1:end, :) = [];
    sta = mean(snippets, 1);
    std_sta = std(snippets, [], 1);
end