% Listing info: 113 lines, 2.1 KiB, MATLAB
clear
close all

%% first, plot the raw data
% Load the data set. The rest of this cell uses x and y straight after the
% load, so the file must define them (assumed to be equal-length vectors --
% TODO confirm against lin_regression.mat).
load lin_regression.mat

% Scatter the samples as circle markers only (no connecting line).
figure
plot(x, y, 'o')
xlabel('Input'), ylabel('Output')
%% plot the error surface
% Evaluate lsq_error (defined elsewhere) on a grid of [slope, intercept]
% candidates and draw the result as a 3-D surface.
clear
load('lin_regression.mat')

ms = -5:0.25:5;    % candidate slopes     -> rows of error_surf
ns = -30:1:30;     % candidate intercepts -> columns of error_surf

nSlopes = length(ms);
nIntercepts = length(ns);
error_surf = zeros(nSlopes, nIntercepts);
for row = 1:nSlopes
    for col = 1:nIntercepts
        error_surf(row, col) = lsq_error([ms(row), ns(col)], x, y);
    end
end

% plot the error surface
figure()
% meshgrid(ns, ms) yields length(ms)-by-length(ns) grids, the same shape as
% error_surf: M varies down the rows (slope), N along the columns (intercept).
[N, M] = meshgrid(ns, ms);
s = surface(M, N, error_surf);
xlabel('slope')
ylabel('intercept')
zlabel('error')
view(3)

% rotate(s, [1 1 0], 25 )
%% Plot the gradient at different points in the surface
% On a coarser [slope, intercept] grid, record both the error (lsq_error) and
% the gradient (lsq_gradient, defined elsewhere) at every point. Then overlay
% the surface, its contours and the gradient arrows in a single figure.
clear
load('lin_regression.mat')

ms = -1:0.5:5;     % slopes     -> rows
ns = -10:1:10;     % intercepts -> columns
gridSize = [length(ms), length(ns)];

error_surf = zeros(gridSize);
gradient_m = zeros(gridSize);   % 1st gradient component (presumably d error / d slope -- TODO confirm)
gradient_n = zeros(gridSize);   % 2nd gradient component (presumably d error / d intercept -- TODO confirm)

for row = 1:gridSize(1)
    for col = 1:gridSize(2)
        point = [ms(row), ns(col)];
        error_surf(row, col) = lsq_error(point, x, y);
        g = lsq_gradient(point, x, y);
        gradient_m(row, col) = g(1);
        gradient_n(row, col) = g(2);
    end
end

figure()
hold on
[N, M] = meshgrid(ns, ms);   % same shape as error_surf
surface(M, N, error_surf, 'FaceAlpha', 0.5);   % half-transparent surface
contour(M, N, error_surf, 50);                 % 50 contour levels
quiver(M, N, gradient_m, gradient_n)           % 2-D arrows: the raw gradient (ascent direction)
view(3)
xlabel('slope')
ylabel('intercept')
zlabel('error')
%% do the gradient descent
% Fixed-step gradient descent on lsq_error (lsq_error / lsq_gradient are
% defined elsewhere), starting from position = [slope, intercept]. Each
% visited point is drawn as a red circle on top of the error surface.
clear
close all

load('lin_regression.mat')

ms = -1:0.5:5;     % slopes shown on the surface     -> rows
ns = -10:1:10;     % intercepts shown on the surface -> columns

% Starting [slope, intercept]. NOTE: slope -2 lies outside ms, so the first
% steps may be drawn beyond the edge of the plotted surface.
position = [-2. 10.];
step_size = 0.01;   % learning rate (formerly `eps`, which shadowed the built-in eps)
tolerance = 0.1;    % stop once norm(gradient) <= tolerance
max_iter = 10000;   % safety cap so a diverging run cannot loop forever

% calculate error surface
error_surf = zeros(length(ms), length(ns));
for i = 1:length(ms)
    for j = 1:length(ns)
        error_surf(i, j) = lsq_error([ms(i), ns(j)], x, y);
    end
end

figure()
hold on
[N, M] = meshgrid(ns, ms);   % same shape as error_surf
surface(M, N, error_surf, 'FaceAlpha', 0.5);
view(3)
xlabel('slope')
ylabel('intercept')   % was 'intersection'; matches the other cells
zlabel('error')

% do the descent
% Evaluate and plot the current point, and test for convergence BEFORE
% stepping. When the loop ends, `position` and `current_error` then describe
% the same point. The original stepped once more after the final evaluation,
% so the reported position and error disagreed.
status = 'max_iter';
for iter = 1:max_iter
    grad = lsq_gradient(position, x, y);            % was `gradient` (shadowed built-in)
    current_error = lsq_error(position, x, y);      % was `error` (shadowed built-in)
    plot3(position(1), position(2), current_error, 'o', 'color', 'red')
    if ~all(isfinite(grad))
        status = 'diverged';
        break
    end
    if norm(grad) <= tolerance
        status = 'converged';
        break
    end
    position = position - step_size .* grad;
    pause(0.25)
end

if ~strcmp(status, 'converged')
    warning('gradient descent did not converge (%s) after %d iterations', status, iter);
end
disp('gradient descent done!')
% strcat strips trailing whitespace from char arrays, which ate the space
% after the colon. Plain [] concatenation keeps it.
disp(['final position: ', num2str(position)])
disp(['final error: ', num2str(current_error)])