% Linear-regression demo: plots the raw data, the least-squares error
% surface over (slope, intercept), its gradient field, and an animated
% gradient descent on that surface.
% Each %% cell reloads lin_regression.mat so it can be run on its own.
% NOTE(review): the .mat file is assumed to provide x and y, and
% lsq_error(params, x, y) / lsq_gradient(params, x, y) must be on the path,
% with params = [slope, intercept] -- confirm against those helpers.
clear
close all

%% first, plot the raw data
load('lin_regression.mat');
figure()
plot(x, y, 'o')
xlabel('Input')
ylabel('Output')

%% plot the error surface
clear
load('lin_regression.mat')
ms = -5:0.25:5;   % candidate slopes
ns = -30:1:30;    % candidate intercepts
% M(i,j) = ms(i) and N(i,j) = ns(j): rows index the slope, columns the intercept
[N, M] = meshgrid(ns, ms);
error_surf = zeros(size(M));
for k = 1:numel(M)
    error_surf(k) = lsq_error([M(k), N(k)], x, y);
end
% plot the error surface
figure()
s = surface(M, N, error_surf);
xlabel('slope')
ylabel('intercept')
zlabel('error')
view(3)
% rotate(s, [1 1 0], 25 )

%% plot the gradient at different points on the surface
clear
load('lin_regression.mat')
ms = -1:0.5:5;
ns = -10:1:10;
[N, M] = meshgrid(ns, ms);
error_surf = zeros(size(M));
gradient_m = zeros(size(M));
gradient_n = zeros(size(M));
for k = 1:numel(M)
    params = [M(k), N(k)];
    error_surf(k) = lsq_error(params, x, y);
    grad = lsq_gradient(params, x, y);
    gradient_m(k) = grad(1);   % first component (slope direction, assuming it matches params order)
    gradient_n(k) = grad(2);   % second component (intercept direction, same assumption)
end
figure()
hold on
surface(M, N, error_surf, 'FaceAlpha', 0.5);
contour(M, N, error_surf, 50);
quiver(M, N, gradient_m, gradient_n)
view(3)
xlabel('slope')
ylabel('intercept')
zlabel('error')

%% do the gradient descent
clear
close all
load('lin_regression.mat')
ms = -1:0.5:5;
ns = -10:1:10;
% Starting [slope, intercept]; written as an explicit row vector so it matches
% the [m, n] parameter vectors used above.
position = [-2., 10.];
step_size = 0.01;   % learning rate (not named eps, which would shadow the built-in)
tolerance = 0.1;    % stop once the gradient norm is at or below this value
max_iter = 1000;    % safety cap so a diverging or stalled descent cannot loop forever

% calculate the error surface
[N, M] = meshgrid(ns, ms);
error_surf = zeros(size(M));
for k = 1:numel(M)
    error_surf(k) = lsq_error([M(k), N(k)], x, y);
end

figure()
hold on
surface(M, N, error_surf, 'FaceAlpha', 0.5);
view(3)
xlabel('slope')
ylabel('intercept')
zlabel('error')

% do the descent; grad/err are used instead of gradient/error so the
% built-in functions of those names stay usable
converged = false;
for iter = 1:max_iter
    % reshape so the update stays a 1x2 row even if the helper returns a column
    grad = reshape(lsq_gradient(position, x, y), 1, []);
    err = lsq_error(position, x, y);
    plot3(position(1), position(2), err, 'o', 'color', 'red')
    if norm(grad) <= tolerance
        converged = true;
        break
    end
    position = position - step_size .* grad;
    pause(0.25)
end

% Recompute so the reported error always belongs to the reported position.
err = lsq_error(position, x, y);
if converged
    disp('gradient descent done!')
else
    warning('gradient descent stopped after %d iterations without converging', max_iter)
end
% Bracket concatenation keeps the space after the colon (strcat would strip it).
disp(['final position: ', num2str(position)])
disp(['final error: ', num2str(err)])