From 40dbc3cfb93821bcf48b00bc7aa698821ecd1dff Mon Sep 17 00:00:00 2001
From: Jan Benda <jan.benda@uni-tuebingen.de>
Date: Tue, 30 Jan 2018 14:37:28 +0100
Subject: [PATCH] added maximum likelihood solution

---
 .../solution/maxlikelihoodangle.m             |  23 ++++
 .../solution/populationvector.m               | 124 +++++++++++++-----
 2 files changed, 116 insertions(+), 31 deletions(-)
 create mode 100644 projects/project_populationvector/solution/maxlikelihoodangle.m

diff --git a/projects/project_populationvector/solution/maxlikelihoodangle.m b/projects/project_populationvector/solution/maxlikelihoodangle.m
new file mode 100644
index 0000000..339e561
--- /dev/null
+++ b/projects/project_populationvector/solution/maxlikelihoodangle.m
@@ -0,0 +1,23 @@
+function angle = maxlikelihoodangle(phases, gains, rates)
+% maximum likelihood estimation of orientation
+cosine = @(g,p,xdata)0.5*g.*(1.0+cos(2.0*pi*(xdata-p)/180.0));
+angels = 0:1.0:180.0;
+loglikelihoods = zeros(length(phases), length(angels));
+for i=1:length(phases)
+    r = cosine(gains(i), phases(i), angels);
+    %%loglikelihoods(i, :) = exp(-0.5*((rates(i)-r)./(0.25*r)).^2.0)./sqrt(2.0*pi*(0.25*r).^2.0);
+    %loglikelihoods(i, :) = log(exp(-0.5*((rates(i)-r)./(0.25*r)).^2.0)./sqrt(2.0*pi*(0.25*r).^2.0));
+    loglikelihoods(i, :) = -0.5*((rates(i)-r)./(0.25*r)).^2.0 - 0.5*log(2.0*pi*(0.25*r).^2.0);
+end
+loglikelihood = sum(loglikelihoods, 1);
+[m i] = max(loglikelihood);
+angle = angels(i);
+% plot(angels, loglikelihood);
+% hold on;
+% plot([angle angle], [-500 0], 'k')
+% hold off;
+% xlabel('angle');
+% ylabel('likelihood');
+% ylim([-500, 0]);
+% pause( 0.2 );
+end
diff --git a/projects/project_populationvector/solution/populationvector.m b/projects/project_populationvector/solution/populationvector.m
index 9e8d475..afcbcef 100644
--- a/projects/project_populationvector/solution/populationvector.m
+++ b/projects/project_populationvector/solution/populationvector.m
@@ -1,23 +1,24 @@
 close all
 datapath = '../data/';
 files = dir(strcat(datapath, 'unit*.mat'));
-for file = files'
-	a = load(strcat(datapath, file.name));
-	spikes = a.spikes;
-	angles = a.angles;
-	figure()
-	for k = 1:size(spikes, 1)
-		subplot(3, 4, k)
-		spikeraster(spikes(k,:), -0.2, 0.6);
-	end
-end
+% for file = files'
+% 	a = load(strcat(datapath, file.name));
+% 	spikes = a.spikes;
+% 	angles = a.angles;
+% 	figure()
+% 	for k = 1:size(spikes, 1)
+% 		subplot(3, 4, k)
+% 		spikeraster(spikes(k,:), -0.2, 0.6);
+% 	end
+% end
 
 
 %% tuning curves:
 close all
-cosine = @(p,xdata)0.5*p(1).*(1.0+cos(2.0*pi*(xdata/180.0-p(2))));
+cosine = @(p,xdata)0.5*p(1).*(1.0+cos(2.0*pi*(xdata-p(2))/180.0));
 files = dir(strcat(datapath, 'unit*.mat'));
 phases = zeros(length(files), 1);
+gains = zeros(length(files), 1);
 figure()
 for j = 1:length(files)
 	file = files(j);
@@ -30,10 +31,10 @@ for j = 1:length(files)
 		rates(k) = r;
 	end
 	[mr, maxi] = max(rates);
-	p0 = [mr, angles(maxi)/180.0-0.5];
+	p0 = [mr, angles(maxi)];
 	%p = p0;
 	p = lsqcurvefit(cosine, p0, angles, rates');
-	phase = p(2)*180.0;
+	phase = p(2);
     if phase > 180.0
         phase = phase - 180.0;
     end
@@ -41,10 +42,12 @@ for j = 1:length(files)
         phase = phase + 180.0;
     end
     phases(j) = phase;
+    gains(j) = p(1);
 	subplot(2, 3, j);
 	plot(angles, rates, 'b');
 	hold on;
-	plot(angles, cosine(p, angles), 'r');
+    a = 0:0.1:180;
+	plot(a, cosine(p, a), 'r');
 	hold off;
 	xlim([0.0 180.0])
 	ylim([0.0 50.0])
@@ -56,12 +59,13 @@ end
 a = load(strcat(datapath, 'population04.mat'));
 spikes = a.spikes;
 angle = a.angle;
-unitphases = a.phases*180.0;
-unitphases(unitphases>180.0) = unitphases(unitphases>180.0) - 180.0;
+% unitphases = a.phases*180.0;
+% unitphases(unitphases>180.0) = unitphases(unitphases>180.0) - 180.0;
 figure();
-subplot(1, 3, 1);
+subplot(2, 2, 1);
 angleestimates1 = zeros(size(spikes, 2), 1);
 angleestimates2 = zeros(size(spikes, 2), 1);
+angleestimates3 = zeros(size(spikes, 2), 1);
 [x, inx] = sort(phases);
 % loop over trials:
 for j = 1:size(spikes, 2)
@@ -75,35 +79,60 @@ for j = 1:size(spikes, 2)
     angleestimates1(j) = popvecangle(phases, rates);
     [m, i] = max(rates);
     angleestimates2(j) = phases(i);
+    angleestimates3(j) = maxlikelihoodangle(phases, gains, rates);
 end
 xlabel('preferred angle')
 ylabel('firing rate')
 hold off;
-subplot(1, 3, 2);
+subplot(2, 2, 2);
 hist(angleestimates1);
 xlabel('population vector angle')
-subplot(1, 3, 3);
+subplot(2, 2, 3);
 hist(angleestimates2);
 xlabel('max. rate angle')
+subplot(2, 2, 4);
+hist(angleestimates3);
+xlabel('max. likelihood angle')
 angle
 mean(angleestimates1)
 mean(angleestimates2)
+mean(angleestimates3)
 
 
 %% read out robustness:
 files = dir(strcat(datapath, 'population*.mat'));
 angles = zeros(length(files), 1);
-e1m = zeros(length(files), 1);
-e1s = zeros(length(files), 1);
-e2m = zeros(length(files), 1);
-e2s = zeros(length(files), 1);
+e1mm = zeros(length(files), 1);
+e2mm = zeros(length(files), 1);
+e3mm = zeros(length(files), 1);
+e1sm = zeros(length(files), 1);
+e1ss = zeros(length(files), 1);
+e2sm = zeros(length(files), 1);
+e2ss = zeros(length(files), 1);
+e3sm = zeros(length(files), 1);
+e3ss = zeros(length(files), 1);
 for i = 1:length(files)
 	file = files(i);
 	a = load(strcat(datapath, file.name));
     spikes = a.spikes;
     angle = a.angle;
+    % multi trial estimates:
+    rates = zeros(size(spikes, 1), 1);
+    for k = 1:size(spikes, 1)
+        r = zeros(size(spikes, 2), 1);
+        for j = 1:size(spikes, 2)
+            r(j) = firingrate(spikes(k, j), 0.0, 0.2);
+        end
+        rates(k) = mean(r);
+    end
+    e1mm(i) = popvecangle(phases, rates);
+    [m, inx] = max(rates);
+    e2mm(i) = phases(inx);
+    e3mm(i) = maxlikelihoodangle(phases, gains, rates);
+    % single trial estimates:
     angleestimates1 = zeros(size(spikes, 2), 1);
     angleestimates2 = zeros(size(spikes, 2), 1);
+    angleestimates3 = zeros(size(spikes, 2), 1);
     for j = 1:size(spikes, 2)
         rates = zeros(size(spikes, 1), 1);
         for k = 1:size(spikes, 1)
@@ -113,19 +142,52 @@ for i = 1:length(files)
         angleestimates1(j) = popvecangle(phases, rates);
         [m, inx] = max(rates);
         angleestimates2(j) = phases(inx);
+        angleestimates3(j) = maxlikelihoodangle(phases, gains, rates);
     end
     angles(i) = angle;
-    e1m(i) = mean(angleestimates1);
-    e1s(i) = std(angleestimates1);
-    e2m(i) = mean(angleestimates2);
-    e2s(i) = std(angleestimates2);
+    e1sm(i) = mean(angleestimates1);
+    e1ss(i) = std(angleestimates1);
+    e2sm(i) = mean(angleestimates2);
+    e2ss(i) = std(angleestimates2);
+    e3sm(i) = mean(angleestimates3);
+    e3ss(i) = std(angleestimates3);
 end
+x = 0:180.0;
 figure();
-subplot(1, 2, 1);
-scatter(angles, e1m);
+subplot(1, 3, 1);
+hold on;
+plot(x, x, 'k');
+scatter(angles, e1mm);
 xlabel('stimulus angle')
 ylabel('estimated angle (population vector)')
-subplot(1, 2, 2);
-scatter(angles, e2m);
+subplot(1, 3, 2);
+hold on;
+plot(x, x, 'k');
+scatter(angles, e2mm);
 xlabel('stimulus angle')
 ylabel('estimated angle (maximum firing rate)')
+subplot(1, 3, 3);
+hold on;
+plot(x, x, 'k');
+scatter(angles, e3mm);
+xlabel('stimulus angle')
+ylabel('estimated angle (maximum likelihood)')
+figure();
+subplot(1, 3, 1);
+hold on;
+plot(x, x, 'k');
+scatter(angles, e1sm);
+xlabel('stimulus angle')
+ylabel('estimated angle (population vector)')
+subplot(1, 3, 2);
+hold on;
+plot(x, x, 'k');
+scatter(angles, e2sm);
+xlabel('stimulus angle')
+ylabel('estimated angle (maximum firing rate)')
+subplot(1, 3, 3);
+hold on;
+plot(x, x, 'k');
+scatter(angles, e3sm);
+xlabel('stimulus angle')
+ylabel('estimated angle (maximum likelihood)')