function [st_avg, std_sta, valid_spikes]= sta(stimulus, spike_times, count, sampling_rate)
% STA  Mean and standard deviation of stimulus windows around spike times.
%
%   [st_avg, std_sta, valid_spikes] = sta(stimulus, spike_times, count, sampling_rate)
%
%   Each spike time t is converted to the sample index round(t*sampling_rate).
%   The 2*count stimulus samples index-count .. index+count-1 are then
%   collected. The sample at the spike index sits at position count+1 of the
%   window. Spikes whose window would not lie entirely inside the stimulus
%   (including spikes whose index is NaN) are skipped.
%
%   Inputs:
%     stimulus      - stimulus samples, accessed by linear indexing
%                     (presumably a vector - confirm with callers)
%     spike_times   - spike times; multiplied by sampling_rate and rounded
%                     to obtain sample indices
%     count         - number of samples taken before the spike index; the
%                     window is 2*count samples long
%     sampling_rate - factor converting spike_times into sample indices
%
%   Outputs:
%     st_avg        - 1-by-2*count column-wise mean of the collected windows
%     std_sta       - 1-by-2*count column-wise standard deviation (std with
%                     default normalization) of the collected windows
%     valid_spikes  - number of spikes whose window was collected
%
%   If no spike is valid, mean/std are taken over an empty 0-by-2*count
%   matrix.

% Work on a row view so every snippet has the same orientation as a row of
% `snippets`; linear indices into the stimulus are unaffected.
stimulus = reshape(stimulus, 1, []);
n_samples = numel(stimulus);

% Preallocate one row per spike; rows for skipped spikes are removed below.
snippets = zeros(numel(spike_times), 2*count);
valid_spikes = 0;
for k = 1:numel(spike_times)
    index = round(spike_times(k)*sampling_rate);
    first = index - count;
    last = index + count - 1;
    % Keep the spike only if the whole window lies inside 1..n_samples.
    % Written as a positive test so that a NaN index is skipped, not used.
    if first >= 1 && last <= n_samples
        valid_spikes = valid_spikes + 1;
        snippets(valid_spikes, :) = stimulus(first:last);
    end
end

% Drop the preallocated rows that belong to skipped spikes.
snippets(valid_spikes+1:end, :) = [];

st_avg = mean(snippets, 1);
std_sta = std(snippets, [], 1);