43 lines
1.4 KiB
Matlab
43 lines
1.4 KiB
Matlab
function [p, ps, mses] = gradientDescent(x, y, func, p0, epsilon, threshold, maxiter)
% Gradient descent for fitting a function to data pairs.
%
% Minimizes the mean squared error between y and func(x, p) by
% repeatedly stepping the parameters against a numerically estimated
% gradient of that error.
%
% Arguments: x, vector of the x-data values.
%            y, vector of the corresponding y-data values.
%            func, function handle func(x, p).
%            p0, vector with initial parameter values.
%            epsilon, factor multiplying the gradient (step size).
%            threshold, iteration stops once the norm of the gradient
%                       is no longer larger than this value.
%            maxiter, optional maximum number of iterations
%                     (default: Inf, i.e. no limit). A warning is issued
%                     when the limit is reached before convergence.
%
% Returns: p, vector with the final parameter values.
%          ps, matrix with one column per visited parameter vector.
%          mses, row vector with the mean squared error of each column
%                of ps.

    if nargin < 7 || isempty(maxiter)
        maxiter = Inf;
    end

    p = p0;
    ps = [];
    mses = [];
    % Start with an infinite gradient so the loop body always runs at
    % least once, whatever the threshold.
    grad = Inf(size(p0));
    iter = 0;
    while norm(grad) > threshold
        if iter >= maxiter
            warning('gradientDescent:maxIterations', ...
                    'No convergence after %d iterations.', iter);
            break;
        end
        ps = [ps, p(:)];
        mses = [mses, meanSquaredError(x, y, func, p)];
        grad = meanSquaredGradient(x, y, func, p);
        if ~all(isfinite(grad))
            % Stepping with a NaN/Inf gradient would corrupt p; stop and
            % keep the last finite parameters instead.
            warning('gradientDescent:nonFiniteGradient', ...
                    'Gradient is not finite; epsilon may be too large.');
            break;
        end
        p = p - epsilon * grad;
        iter = iter + 1;
    end
end
|
|
|
|
function mse = meanSquaredError(x, y, func, p)
% Mean squared error between the data y and the model func(x, p).
%
% Arguments: x, y, data vectors; func, function handle func(x, p);
%            p, parameter vector passed on to func.
% Returns:   mse, mean of the squared residuals y - func(x, p).
    residuals = y - func(x, p);
    mse = mean(residuals .^ 2);
end
|
|
|
|
function gradmse = meanSquaredGradient(x, y, func, p, h)
% Forward-difference estimate of the gradient of the mean squared error
% with respect to the parameters p.
%
% Arguments: x, y, data vectors; func, function handle func(x, p);
%            p, parameter vector at which the gradient is evaluated;
%            h, optional base step size (default 1e-7). The step for
%               parameter i is h * max(1, abs(p(i))), so that it stays
%               resolvable in floating point for large parameters.
% Returns:   gradmse, gradient estimate with the same shape as p.
    if nargin < 5 || isempty(h)
        h = 1e-7;
    end
    gradmse = zeros(size(p));
    mse = meanSquaredError(x, y, func, p);
    for i = 1:numel(p)            % for each coordinate ...
        pshift = p;
        pshift(i) = p(i) + h * max(1, abs(p(i)));   % displace i-th parameter
        % Divide by the step actually representable in floating point
        % rather than the nominal one; this removes the rounding error
        % in the step from the difference quotient.
        dp = pshift(i) - p(i);
        gradmse(i) = (meanSquaredError(x, y, func, pshift) - mse) / dp;
    end
end
|
|
|