% Shared setup: close any open figures, set the data folder and list the
% single-unit recordings (unit*.mat) found there.
close('all');
datapath = '../data/';
files = dir([datapath 'unit*.mat']);
% Spike rasters for every unit file (disabled; uncomment to plot them):
% for file = files'
% 	a = load(strcat(datapath, file.name));
% 	spikes = a.spikes;
% 	angles = a.angles;
% 	figure()
% 	for k = 1:size(spikes, 1)
% 		subplot(3, 4, k)
% 		spikeraster(spikes(k,:), -0.2, 0.6);
% 	end
% end


%% tuning curves:
% Fit a cosine tuning curve to every unit*.mat file and collect each unit's
% preferred angle (phases) and peak rate (gains). The read-out sections
% below use phases and gains.
close all
% Tuning model with a period of 180 degrees:
%   rate(x) = 0.5*p(1)*(1 + cos(2*pi*(x - p(2))/180))
% Its maximum p(1) is reached at x = p(2); its minimum is 0.
cosine = @(p,xdata)0.5*p(1).*(1.0+cos(2.0*pi*(xdata-p(2))/180.0));
files = dir(strcat(datapath, 'unit*.mat'));
nfiles = length(files);
phases = zeros(nfiles, 1);
gains = zeros(nfiles, 1);
% Subplot grid: 3 columns and as many rows as needed (at least 2, so up to
% six units keep the original 2x3 layout).
ncols = 3;
nrows = max(2, ceil(nfiles/ncols));
% Fine angle grid for drawing the fitted curves (loop invariant).
xfit = 0:0.1:180;
figure()
for j = 1:nfiles
    a = load(strcat(datapath, files(j).name));
    spikes = a.spikes;
    angles = a.angles;
    % One firing rate (window 0.0 to 0.2) per row of spikes. Row k is paired
    % with angles(k) below.
    rates = zeros(size(spikes, 1), 1);
    for k = 1:size(spikes, 1)
        rates(k) = firingrate(spikes(k,:), 0.0, 0.2);
    end
    % Initial guess: the largest measured rate, at the angle where it occurs.
    [mr, maxi] = max(rates);
    p0 = [mr, angles(maxi)];
    % lsqcurvefit (Optimization Toolbox) needs YDATA shaped like
    % cosine(p, angles), so match the orientation of angles.
    p = lsqcurvefit(cosine, p0, angles, reshape(rates, size(angles)));
    % The model is 180-degree periodic, so map the fitted phase into
    % [0, 180) no matter how many periods the fit drifted.
    phases(j) = mod(p(2), 180.0);
    gains(j) = p(1);
    subplot(nrows, ncols, j);
    plot(angles, rates, 'b');
    hold on;
    plot(xfit, cosine(p, xfit), 'r');
    hold off;
    xlim([0.0 180.0])
    % NOTE(review): fixed y-range; rates above 50 would be clipped.
    ylim([0.0 50.0])
    title(sprintf('unit %d', j))
end


%% read out:
% Decode the stimulus angle of population04.mat trial by trial with three
% read-outs:
%   - population vector (popvecangle)
%   - preferred angle of the most active unit
%   - maximum likelihood (maxlikelihoodangle)
% Requires phases and gains from the tuning-curve section.
a = load(strcat(datapath, 'population04.mat'));
spikes = a.spikes;
% Named stimangle (not angle) so the builtin angle() stays usable.
stimangle = a.angle;
% unitphases = a.phases*180.0;
% unitphases(unitphases>180.0) = unitphases(unitphases>180.0) - 180.0;
ntrials = size(spikes, 2);
angleestimates1 = zeros(ntrials, 1);
angleestimates2 = zeros(ntrials, 1);
angleestimates3 = zeros(ntrials, 1);
% Order units by preferred angle so each trial's population profile plots
% as a curve.
[~, inx] = sort(phases);
figure();
subplot(2, 2, 1);
% Loop over trials (columns of spikes). Rows are units, matched to phases.
for j = 1:ntrials
    rates = zeros(size(spikes, 1), 1);
    for k = 1:size(spikes, 1)
        rates(k) = firingrate(spikes(k, j), 0.0, 0.2);
    end
    plot(phases(inx), rates(inx), '-o');
    hold on;
    angleestimates1(j) = popvecangle(phases, rates);
    [~, imax] = max(rates);
    angleestimates2(j) = phases(imax);
    angleestimates3(j) = maxlikelihoodangle(phases, gains, rates);
end
xlabel('preferred angle')
ylabel('firing rate')
hold off;
subplot(2, 2, 2);
hist(angleestimates1);
xlabel('population vector angle')
subplot(2, 2, 3);
hist(angleestimates2);
xlabel('max. rate angle')
subplot(2, 2, 4);
hist(angleestimates3);
xlabel('max. likelihood angle')
% NOTE(review): arithmetic means of orientation estimates are biased when
% the estimates straddle the 0/180 wrap; consider a circular mean there.
fprintf('stimulus angle:               %g\n', stimangle);
fprintf('mean population vector angle: %g\n', mean(angleestimates1));
fprintf('mean max. rate angle:         %g\n', mean(angleestimates2));
fprintf('mean max. likelihood angle:   %g\n', mean(angleestimates3));


%% read out robustness:
% For every population*.mat file, decode the stimulus angle with three
% read-outs (population vector, max. firing rate, max. likelihood) in two ways:
%   e*mm      - one estimate from the trial-averaged rates
%   e*sm/e*ss - mean/std across trials of the single-trial estimates
% Requires phases and gains from the tuning-curve section.
files = dir(strcat(datapath, 'population*.mat'));
nfiles = length(files);
angles = zeros(nfiles, 1);
e1mm = zeros(nfiles, 1);
e2mm = zeros(nfiles, 1);
e3mm = zeros(nfiles, 1);
e1sm = zeros(nfiles, 1);
e1ss = zeros(nfiles, 1);
e2sm = zeros(nfiles, 1);
e2ss = zeros(nfiles, 1);
e3sm = zeros(nfiles, 1);
e3ss = zeros(nfiles, 1);
for f = 1:nfiles
    a = load(strcat(datapath, files(f).name));
    spikes = a.spikes;
    angles(f) = a.angle;
    % Single-trial firing rates: one row per unit, one column per trial.
    % Computed once; both read-out variants below reuse them.
    nunits = size(spikes, 1);
    ntrials = size(spikes, 2);
    trialrates = zeros(nunits, ntrials);
    for k = 1:nunits
        for j = 1:ntrials
            trialrates(k, j) = firingrate(spikes(k, j), 0.0, 0.2);
        end
    end
    % multi trial estimates (from rates averaged over trials):
    rates = mean(trialrates, 2);
    e1mm(f) = popvecangle(phases, rates);
    [~, imax] = max(rates);
    e2mm(f) = phases(imax);
    e3mm(f) = maxlikelihoodangle(phases, gains, rates);
    % single trial estimates:
    angleestimates1 = zeros(ntrials, 1);
    angleestimates2 = zeros(ntrials, 1);
    angleestimates3 = zeros(ntrials, 1);
    for j = 1:ntrials
        rates = trialrates(:, j);
        angleestimates1(j) = popvecangle(phases, rates);
        [~, imax] = max(rates);
        angleestimates2(j) = phases(imax);
        angleestimates3(j) = maxlikelihoodangle(phases, gains, rates);
    end
    % NOTE(review): plain mean/std ignore the 0/180 wrap of orientations.
    e1sm(f) = mean(angleestimates1);
    e1ss(f) = std(angleestimates1);
    e2sm(f) = mean(angleestimates2);
    e2ss(f) = std(angleestimates2);
    e3sm(f) = mean(angleestimates3);
    e3ss(f) = std(angleestimates3);
end
% Estimated vs. stimulus angle; the black diagonal marks perfect decoding.
labels = {'population vector', 'maximum firing rate', 'maximum likelihood'};
multimeans = {e1mm, e2mm, e3mm};
singlemeans = {e1sm, e2sm, e3sm};
singlestds = {e1ss, e2ss, e3ss};
x = 0:180.0;
figure('Name', 'read out from trial-averaged rates');
for est = 1:numel(labels)
    subplot(1, 3, est);
    hold on;
    plot(x, x, 'k');
    scatter(angles, multimeans{est});
    xlabel('stimulus angle')
    ylabel(sprintf('estimated angle (%s)', labels{est}))
end
figure('Name', 'single-trial read out (mean and std across trials)');
for est = 1:numel(labels)
    subplot(1, 3, est);
    hold on;
    plot(x, x, 'k');
    % Markers are the trial means; error bars are the across-trial std.
    errorbar(angles, singlemeans{est}, singlestds{est}, 'o');
    xlabel('stimulus angle')
    ylabel(sprintf('estimated angle (%s)', labels{est}))
end