[regression] finished n-dim minimization

This commit is contained in:
2020-12-19 21:53:19 +01:00
parent 891515caf8
commit bfb2f66de2
4 changed files with 166 additions and 42 deletions

View File

@@ -0,0 +1,41 @@
function [p, ps, mses] = gradientDescentPower(x, y, p0, epsilon, threshold)
% Gradient descent for fitting a power-law.
%
% Arguments: x, vector of the x-data values.
% y, vector of the corresponding y-data values.
% p0, vector with initial values for c and alpha.
% epsilon: factor multiplying the gradient.
% threshold: minimum value for gradient
%
% Returns: p, vector with the final parameter values.
% ps: 2D-vector with all the parameter tuples traversed.
% mses: vector with the corresponding mean squared errors
p = p0;
gradient = ones(1, length(p0)) * 1000.0;
ps = [];
mses = [];
while norm(gradient) > threshold
ps = [ps, p(:)];
mses = [mses, meanSquaredErrorPower(x, y, p)];
gradient = meanSquaredGradientPower(x, y, p);
p = p - epsilon * gradient;
end
end
function mse = meanSquaredErrorPower(x, y, p)
mse = mean((y - p(1)*x.^p(2)).^2);
end
function gradmse = meanSquaredGradientPower(x, y, p)
gradmse = zeros(size(p, 1), size(p, 2));
h = 1e-5; % stepsize for derivatives
mse = meanSquaredErrorPower(x, y, p);
for i = 1:length(p) % for each coordinate ...
pi = p;
pi(i) = pi(i) + h; % displace i-th parameter
msepi = meanSquaredErrorPower(x, y, pi);
gradmse(i) = (msepi - mse)/h;
end
end

View File

@@ -0,0 +1,32 @@
meansquarederrorline; % generate data
p0 = [2.0, 1.0];
eps = 0.00001;
thresh = 50.0;
[pest, ps, mses] = gradientDescentPower(x, y, p0, eps, thresh);
pest
subplot(2, 2, 1); % top left panel
hold on;
plot(ps(1,:), ps(2,:), '.');
plot(ps(1,end), ps(2,end), 'og');
plot(c, 3.0, 'or'); % dot indicating true parameter values
hold off;
xlabel('Iteration');
ylabel('C');
subplot(2, 2, 3); % bottom left panel
plot(mses, '-o');
xlabel('Iteration steps');
ylabel('MSE');
subplot(1, 2, 2); % right panel
hold on;
% generate x-values for plottig the fit:
xx = min(x):0.01:max(x);
yy = pest(1) * xx.^pest(2);
plot(xx, yy);
plot(x, y, 'o'); % plot original data
xlabel('Size [m]');
ylabel('Weight [kg]');
legend('fit', 'data', 'location', 'northwest');
pause