[regression] gradient descent code

This commit is contained in:
2020-12-19 12:20:18 +01:00
parent 4b624fe981
commit 7518f9dd47
3 changed files with 98 additions and 28 deletions

View File

@@ -0,0 +1,25 @@
function [c, cs, mses] = gradientDescentCubic(x, y, c0, epsilon, threshold)
% Gradient descent for fitting a cubic relation.
%
% Arguments: x, vector of the x-data values.
% y, vector of the corresponding y-data values.
% c0, initial value for the parameter c.
% epsilon: factor multiplying the gradient.
% threshold: minimum value for gradient
%
% Returns: c, the final value of the c-parameter.
% cs: vector with all the c-values traversed.
% mses: vector with the corresponding mean squared errors
c = c0;
gradient = 1000.0;
cs = [];
mses = [];
count = 1;
while abs(gradient) > threshold
cs(count) = c;
mses(count) = meanSquaredErrorCubic(x, y, c);
gradient = meanSquaredGradientCubic(x, y, c);
c = c - epsilon * gradient;
count = count + 1;
end
end

View File

@@ -0,0 +1,29 @@
meansquarederrorline % generate data
c0 = 2.0;
eps = 0.0001;
thresh = 0.1;
[cest, cs, mses] = gradientDescentCubic(x, y, c0, eps, thresh);
subplot(2, 2, 1); % top left panel
hold on;
plot(cs, '-o');
plot([1, length(cs)], [c, c], 'k'); % line indicating true c value
hold off;
xlabel('Iteration');
ylabel('C');
subplot(2, 2, 3); % bottom left panel
plot(mses, '-o');
xlabel('Iteration steps');
ylabel('MSE');
subplot(1, 2, 2); % right panel
hold on;
% generate x-values for plottig the fit:
xx = min(x):0.01:max(x);
yy = cest * xx.^3;
plot(xx, yy, 'displayname', 'fit');
plot(x, y, 'o', 'displayname', 'data'); % plot original data
xlabel('Size [m]');
ylabel('Weight [kg]');
legend("location", "northwest");
pause