[regression] gradient descent code
This commit is contained in:
25
regression/code/gradientDescentCubic.m
Normal file
25
regression/code/gradientDescentCubic.m
Normal file
@@ -0,0 +1,25 @@
|
||||
function [c, cs, mses] = gradientDescentCubic(x, y, c0, epsilon, threshold)
|
||||
% Gradient descent for fitting a cubic relation.
|
||||
%
|
||||
% Arguments: x, vector of the x-data values.
|
||||
% y, vector of the corresponding y-data values.
|
||||
% c0, initial value for the parameter c.
|
||||
% epsilon: factor multiplying the gradient.
|
||||
% threshold: minimum value for gradient
|
||||
%
|
||||
% Returns: c, the final value of the c-parameter.
|
||||
% cs: vector with all the c-values traversed.
|
||||
% mses: vector with the corresponding mean squared errors
|
||||
c = c0;
|
||||
gradient = 1000.0;
|
||||
cs = [];
|
||||
mses = [];
|
||||
count = 1;
|
||||
while abs(gradient) > threshold
|
||||
cs(count) = c;
|
||||
mses(count) = meanSquaredErrorCubic(x, y, c);
|
||||
gradient = meanSquaredGradientCubic(x, y, c);
|
||||
c = c - epsilon * gradient;
|
||||
count = count + 1;
|
||||
end
|
||||
end
|
||||
29
regression/code/plotgradientdescentcubic.m
Normal file
29
regression/code/plotgradientdescentcubic.m
Normal file
@@ -0,0 +1,29 @@
|
||||
meansquarederrorline % generate data
|
||||
|
||||
c0 = 2.0;
|
||||
eps = 0.0001;
|
||||
thresh = 0.1;
|
||||
[cest, cs, mses] = gradientDescentCubic(x, y, c0, eps, thresh);
|
||||
|
||||
subplot(2, 2, 1); % top left panel
|
||||
hold on;
|
||||
plot(cs, '-o');
|
||||
plot([1, length(cs)], [c, c], 'k'); % line indicating true c value
|
||||
hold off;
|
||||
xlabel('Iteration');
|
||||
ylabel('C');
|
||||
subplot(2, 2, 3); % bottom left panel
|
||||
plot(mses, '-o');
|
||||
xlabel('Iteration steps');
|
||||
ylabel('MSE');
|
||||
subplot(1, 2, 2); % right panel
|
||||
hold on;
|
||||
% generate x-values for plottig the fit:
|
||||
xx = min(x):0.01:max(x);
|
||||
yy = cest * xx.^3;
|
||||
plot(xx, yy, 'displayname', 'fit');
|
||||
plot(x, y, 'o', 'displayname', 'data'); % plot original data
|
||||
xlabel('Size [m]');
|
||||
ylabel('Weight [kg]');
|
||||
legend("location", "northwest");
|
||||
pause
|
||||
Reference in New Issue
Block a user