84 lines
2.5 KiB
Python
84 lines
2.5 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
def create_data(*, m=0.75, n=-40, seed=37281, noise_scale=15.0):
    """Generate noisy samples of the line ``y = m * x + n``.

    All parameters are keyword-only and default to the values the figure
    was designed around, so ``create_data()`` reproduces the same data set.

    Parameters
    ----------
    m : float, optional
        Slope of the underlying line (default 0.75).
    n : float, optional
        Intercept of the underlying line (default -40).
    seed : int, optional
        Seed for a private ``np.random.RandomState`` so the noise is
        reproducible from run to run (default 37281).
    noise_scale : float, optional
        Factor applied to the standard-normal noise (default 15.0).

    Returns
    -------
    x : ndarray
        40 evenly spaced inputs from 10.0 to 107.5 in steps of 2.5.
    y : ndarray
        ``m * x + n`` plus scaled Gaussian noise, same length as ``x``.
    m, n
        The slope and intercept that were used, echoed back so callers can
        draw the reference line(s).
    """
    x = np.arange(10., 110., 2.5)
    # A dedicated RandomState keeps the data deterministic without touching
    # NumPy's global random state.
    rng = np.random.RandomState(seed)
    noise = rng.randn(len(x)) * noise_scale
    y = m * x + n + noise
    return x, y, m, n
|
|
|
|
|
|
def plot_data(ax, x, y):
    """Scatter the (x, y) samples on ``ax`` and apply the panel styling.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into; modified in place.
    x, y : array-like
        Sample coordinates, drawn as blue circles.

    The top and right spines are hidden, and ticks point outward on the
    left and bottom axes. Both axes are labelled. The view is fixed to
    x in [0, 120] and y in [-80, 80], with ticks every 40 units.
    """
    ax.scatter(x, y, marker='o', color='b', s=40)

    # Minimal "open box" look: keep only the left and bottom spines.
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(direction="out", width=1.25)

    ax.set_xlabel('Input x')
    ax.set_ylabel('Output y')

    ax.set_xlim(0, 120)
    ax.set_ylim(-80, 80)
    ax.set_xticks(np.arange(0, 121, 40))
    ax.set_yticks(np.arange(-80, 81, 40))
|
|
|
|
|
|
def plot_data_slopes(ax, x, y, m, n):
    """Scatter the samples and overlay candidate lines that vary in slope.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into; modified in place.
    x, y : array-like
        Sample coordinates, drawn as blue circles.
    m : float
        Reference slope; five red lines use slopes evenly spaced from
        ``0.3 * m`` to ``2.0 * m``.
    n : float
        Intercept shared by all of the overlaid lines.

    No y-axis label is set. In this script the panel sits to the right of
    the one drawn by ``plot_data``, which already carries it.
    """
    ax.scatter(x, y, marker='o', color='b', s=40)

    # Straight lines only need their two end points, just inside the x-limits.
    xx = np.asarray([2, 118])
    for slope in np.linspace(0.3 * m, 2.0 * m, 5):
        ax.plot(xx, slope * xx + n, color='#CC0000', lw=2)

    # Minimal "open box" look: keep only the left and bottom spines.
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(direction="out", width=1.25)

    ax.set_xlabel('Input x')

    ax.set_xlim(0, 120)
    ax.set_ylim(-80, 80)
    ax.set_xticks(np.arange(0, 121, 40))
    ax.set_yticks(np.arange(-80, 81, 40))
|
|
|
|
|
|
def plot_data_intercepts(ax, x, y, m, n):
    """Scatter the samples and overlay candidate lines that vary in intercept.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into; modified in place.
    x, y : array-like
        Sample coordinates, drawn as blue circles.
    m : float
        Slope shared by all of the overlaid lines.
    n : float
        Reference intercept; five red lines use intercepts evenly spaced
        from ``n - n`` (= 0) to ``n + n`` (= 2n), centred on ``n``.

    No y-axis label is set. In this script the panel sits to the right of
    the one drawn by ``plot_data``, which already carries it.
    """
    ax.scatter(x, y, marker='o', color='b', s=40)

    # Straight lines only need their two end points, just inside the x-limits.
    xx = np.asarray([2, 118])
    for intercept in np.linspace(0, 2 * n, 5):
        ax.plot(xx, m * xx + intercept, color='#CC0000', lw=2)

    # Minimal "open box" look: keep only the left and bottom spines.
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(direction="out", width=1.25)

    ax.set_xlabel('Input x')

    ax.set_xlim(0, 120)
    ax.set_ylim(-80, 80)
    ax.set_xticks(np.arange(0, 121, 40))
    ax.set_yticks(np.arange(-80, 81, 40))
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the three-panel linear-regression illustration and save it as a PDF.
    x, y, m, n = create_data()

    # plt.xkcd() works by overriding global rcParams. Using it as a context
    # manager scopes the hand-drawn style to this figure. Saving happens
    # inside the block so rendering also uses the style.
    with plt.xkcd():
        fig = plt.figure()

        # Panel 1: the raw samples.
        ax = fig.add_subplot(1, 3, 1)
        plot_data(ax, x, y)

        # Panel 2: candidate lines with varying slope.
        ax = fig.add_subplot(1, 3, 2)
        plot_data_slopes(ax, x, y, m, n)

        # Panel 3: candidate lines with varying intercept.
        ax = fig.add_subplot(1, 3, 3)
        plot_data_intercepts(ax, x, y, m, n)

        fig.set_facecolor("white")
        fig.set_size_inches(7., 2.6)
        fig.tight_layout()
        fig.savefig("lin_regress.pdf")

    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
|