function [sta, std_sta, valid_spikes] = spikeTriggeredAverage(stimulus, spike_times, count, sampling_rate)
% Function estimates the Spike-Triggered-Average (sta).
%
% For every spike, a snippet of 2*count stimulus samples is cut out:
% count samples before the spike's sample index and count - 1 samples
% after it (the spike's own sample is included). Spikes whose snippet
% would extend past either end of the stimulus are skipped. The snippets
% are then averaged sample by sample.
%
% Arguments:
%           stimulus, a vector containing stimulus intensities
%           as a function of time.
%           spike_times, a vector containing the spike times
%           in seconds.
%           count, the number of datapoints taken on each side of
%           the spike times (each snippet has 2*count samples).
%           sampling_rate, the sampling rate of the stimulus.
%
% Returns:
%           sta, a 1 x 2*count vector with the spike-triggered average.
%           std_sta, a 1 x 2*count vector with the standard deviation
%           of the snippets at each sample.
%           valid_spikes, the number of spikes taken into account.
%
% If no spike has a complete snippet, sta and std_sta are mean/std of an
% empty 0 x 2*count matrix (NaN in MATLAB).

n_samples = length(stimulus);
snippets = zeros(numel(spike_times), 2*count);
valid_spikes = 0;
for i = 1:numel(spike_times)
    % NOTE(review): the sample index is round(t*sampling_rate) with no +1
    % offset for MATLAB's 1-based indexing; confirm the intended
    % time-to-sample convention.
    index = round(spike_times(i) * sampling_rate);
    % The snippet is stimulus(index-count : index+count-1); skip the spike
    % if that window does not lie entirely inside the stimulus.
    if index - count < 1 || index + count - 1 > n_samples
        continue
    end
    valid_spikes = valid_spikes + 1;
    snippets(valid_spikes, :) = stimulus(index-count:index+count-1);
end

% Drop the preallocated rows that belong to skipped spikes.
snippets(valid_spikes+1:end, :) = [];

sta = mean(snippets, 1);
std_sta = std(snippets, [], 1);