import numpy as np
import matplotlib.pyplot as plt
from plotstyle import *

def create_data():
    """Generate reproducible noisy samples of a straight line.

    The underlying line is y = m*x + n with slope m = 0.75 and intercept
    n = -40. Inputs x run from 10.0 to 107.5 in steps of 2.5 (40 samples).
    Gaussian noise with standard deviation 15 is added to every output. The
    noise is drawn from a RandomState with a fixed seed, so every call
    returns identical data.

    Returns
    -------
    x : ndarray
        Input values.
    y : ndarray
        Noisy output values.
    m : float
        True slope used to generate the data.
    n : int
        True intercept used to generate the data.
    """
    slope, intercept = 0.75, -40
    x = np.arange(10.0, 110.0, 2.5)
    # Fixed seed -> the same noise realisation on every run.
    noise = np.random.RandomState(37281).randn(len(x)) * 15
    y = slope * x + intercept + noise
    return x, y, slope, intercept

    
def plot_data(ax, x, y):
    """Plot the raw data points on fixed, labelled axes.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    x, y : array_like
        Inputs and outputs of the data points. They are drawn with the
        `psA` style from plotstyle (presumably a point/marker style).
    """
    ax.plot(x, y, **psA)
    ax.set_xlabel('Input x')
    ax.set_ylabel('Output y')
    # Same limits and tick spacing as the other panels of the figure.
    x_lo, x_hi = 0, 120
    y_lo, y_hi = -80, 80
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.set_xticks(np.arange(x_lo, x_hi + 1, 40))
    ax.set_yticks(np.arange(y_lo, y_hi + 1, 40))
    

def plot_data_slopes(ax, x, y, m, n):
    """Plot the data together with a fan of lines of varying slope.

    Every candidate line shares the intercept `n`. The slopes are five
    evenly spaced values from 0.3*m to 2.0*m.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    x, y : array_like
        Data points, drawn with the plotstyle `psA` style.
    m : float
        Reference slope that scales the range of candidate slopes.
    n : float
        Intercept shared by all candidate lines.
    """
    line_x = np.asarray([2, 118])
    ax.plot(x, y, **psA)
    candidate_slopes = np.linspace(0.3*m, 2.0*m, 5)
    for slope in candidate_slopes:
        ax.plot(line_x, slope*line_x + n, **lsBm)
    ax.set_xlabel('Input x')
    # No y-axis label on this panel; only the first panel's plot_data sets one.
    x_lo, x_hi = 0, 120
    y_lo, y_hi = -80, 80
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.set_xticks(np.arange(x_lo, x_hi + 1, 40))
    ax.set_yticks(np.arange(y_lo, y_hi + 1, 40))


def plot_data_intercepts(ax, x, y, m, n):
    """Plot the data together with a fan of lines of varying intercept.

    Every candidate line shares the slope `m`. The intercepts are five
    evenly spaced values between 0 and 2*n, so the middle line uses `n`.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    x, y : array_like
        Data points, drawn with the plotstyle `psA` style.
    m : float
        Slope shared by all candidate lines.
    n : float
        Reference intercept around which the candidates are centred.
    """
    line_x = np.asarray([2, 118])
    ax.plot(x, y, **psA)
    # Bounds n - n (i.e. 0) and n + n (i.e. 2n) lie symmetrically around n.
    lowest, highest = n - n, n + n
    candidate_intercepts = np.linspace(lowest, highest, 5)
    for offset in candidate_intercepts:
        ax.plot(line_x, m*line_x + offset, **lsBm)
    ax.set_xlabel('Input x')
    # No y-axis label on this panel; only the first panel's plot_data sets one.
    x_lo, x_hi = 0, 120
    y_lo, y_hi = -80, 80
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.set_xticks(np.arange(x_lo, x_hi + 1, 40))
    ax.set_yticks(np.arange(y_lo, y_hi + 1, 40))


if __name__ == "__main__":
    # Build one dataset and show it in three side-by-side panels:
    # raw data, candidate lines with varying slope, and candidate
    # lines with varying intercept.
    x, y, m, n = create_data()
    fig, axs = plt.subplots(1, 3)
    # adjust_fs comes from plotstyle; presumably it converts the given
    # left/right margins into subplots_adjust fractions -- TODO confirm.
    fig.subplots_adjust(wspace=0.5, **adjust_fs(fig, left=6.0, right=1.5))
    plot_data(axs[0], x, y)
    plot_data_slopes(axs[1], x, y, m, n)
    plot_data_intercepts(axs[2], x, y, m, n)
    fig.savefig("lin_regress.pdf")
    # Close exactly the figure we created, not whatever pyplot
    # currently considers the "current" figure.
    plt.close(fig)