clear; close all;

%% first, plot the raw data
% Load the sample data set; the plot below expects it to provide x and y.
load('lin_regression.mat');

% Scatter the raw input/output pairs as unconnected circle markers.
figure;
plot(x, y, 'o');
xlabel('Input');
ylabel('Output');

%% plot the error surface
clear
load('lin_regression.mat')

% Candidate slopes (grid rows) and intercepts (grid columns).
ms = -5:0.25:5;
ns = -30:1:30;

% Coordinate matrices: M(i,j) = ms(i), N(i,j) = ns(j).
[N, M] = meshgrid(ns, ms);

% Evaluate lsq_error (defined elsewhere) once per grid point, walking the
% grid by linear index instead of nested row/column loops.
error_surf = zeros(size(M));
for k = 1:numel(M)
    error_surf(k) = lsq_error([M(k), N(k)], x, y);
end

% plot the error surface
figure()
s = surface(M, N, error_surf);
xlabel('slope')
ylabel('intercept')
zlabel('error')
view(3)
% rotate(s, [1 1 0], 25 )

%% Plot the gradient at different points in the surface
clear
load('lin_regression.mat')

% Candidate slopes (grid rows) and intercepts (grid columns).
ms = -1:0.5:5;
ns = -10:1:10;

% Coordinate matrices: M(i,j) = ms(i), N(i,j) = ns(j).
[N, M] = meshgrid(ns, ms);

error_surf = zeros(size(M));
gradient_m = zeros(size(M));
gradient_n = zeros(size(M));

% At every grid point, record the error and both gradient components.
% Component 1 of lsq_gradient's result is plotted along the slope axis and
% component 2 along the intercept axis.
for k = 1:numel(M)
    params = [M(k), N(k)];
    error_surf(k) = lsq_error(params, x, y);
    g = lsq_gradient(params, x, y);
    gradient_m(k) = g(1);
    gradient_n(k) = g(2);
end

% Semi-transparent surface, with contour lines and gradient arrows drawn
% in the slope/intercept plane underneath it.
figure()
hold on
surface(M, N, error_surf, 'FaceAlpha', 0.5);
contour(M, N, error_surf, 50);
quiver(M, N, gradient_m, gradient_n)
view(3)
xlabel('slope')
ylabel('intercept')
zlabel('error')

%% do the gradient descent
clear
close all

load('lin_regression.mat')

% Parameter grid used only to draw the error surface. The slope range starts
% at -2 so that the descent's starting point (below) lies on the drawn surface.
ms = -2:0.5:5;
ns = -10:1:10;

% Descent settings. These names deliberately avoid eps, error and gradient,
% which are MATLAB built-ins that the previous variable names shadowed.
position = [-2. 10.];   % starting [slope, intercept]
step_size = 0.01;       % learning rate applied to the gradient
grad_tol = 0.1;         % stop once norm(gradient) <= grad_tol
max_iter = 10000;       % safety cap so a non-converging run still terminates
pause_s = 0.25;         % delay between steps so the path can be watched

% calculate error surface
error_surf = zeros(length(ms), length(ns));
for i = 1:length(ms)
    for j = 1:length(ns)
        error_surf(i,j) = lsq_error([ms(i), ns(j)], x, y);
    end
end
figure()
hold on
[N, M] = meshgrid(ns, ms);
surface(M,N, error_surf, 'FaceAlpha', 0.5);
view(3)
xlabel('slope')
ylabel('intercept')
zlabel('error')

% Do the descent. Each iteration plots the current point, then stops if the
% gradient is small (or non-finite); otherwise it steps against the gradient.
status = 'max_iter';
for iter = 1:max_iter
    % Reshape so the update stays a 2-element row even if lsq_gradient returns
    % a column vector. A column would broadcast to 2x2 against the row `position`.
    grad = reshape(lsq_gradient(position, x, y), size(position));
    err = lsq_error(position, x, y);
    plot3(position(1), position(2), err, 'o', 'color', 'red')

    % norm(NaN) > tol is false. Without this check, a diverging run would
    % silently fall out of the loop and be reported as done.
    if ~all(isfinite(grad))
        status = 'diverged';
        break
    end
    if norm(grad) <= grad_tol
        status = 'converged';
        break
    end

    position = position - step_size .* grad;
    pause(pause_s)
end

% Re-evaluate so the reported error belongs to the reported position.
err = lsq_error(position, x, y);

switch status
    case 'converged'
        disp('gradient descent done!')
    case 'diverged'
        warning('gradient descent diverged (non-finite gradient); try a smaller step size.')
    otherwise
        warning('gradient descent stopped after %d iterations without converging.', max_iter)
end
% Concatenate with [] rather than strcat, which strips the trailing space
% from 'final position: ' when given character arrays.
disp(['final position: ', num2str(position)])
disp(['final error: ', num2str(err)])