import numpy as np
import matplotlib.pyplot as plt
from plotstyle import *

def random_walk(n, p, rng):
    """Generate a single one-dimensional random walk starting at zero.

    Each of the `n` steps is +1 with probability `p` and -1 otherwise.

    Parameters
    ----------
    n: int
        Number of steps.
    p: float
        Probability of a step in the positive direction.
    rng: numpy.random.RandomState
        Random number generator; its `rand()` method supplies uniform
        numbers in [0, 1).

    Returns
    -------
    positions: ndarray of float
        The n+1 positions x_0 = 0, x_1, ..., x_n.
    """
    uniform = rng.rand(n)
    # A uniform number falls into [1-p, 1) with probability p -> step +1:
    steps = np.where(uniform >= 1.0 - p, 1.0, -1.0)
    positions = np.concatenate(([0.0], np.cumsum(steps)))
    return positions

def plot_random_walk(ax, nmax, p, rng, ymin=-20.0, ymax=20.0,
                     nwalks=12, ystep=10.0):
    """Plot several realizations of a random walk together with its mean and spread.

    The shaded band marks the expected position m*n plus/minus one
    standard deviation sqrt(n*v), where m = 2p - 1 is the mean and
    v = 4p(1 - p) is the variance of a single +1/-1 step.  The straight
    line marks the expected position m*n.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into.
    nmax: int
        Number of steps of each random walk.
    p: float
        Probability of a step in the positive direction.
    rng: numpy.random.RandomState
        Random number generator passed on to `random_walk()`.
    ymin: float
        Lower limit of the y-axis.
    ymax: float
        Upper limit of the y-axis.
    nwalks: int
        Number of random walks to be drawn.
    ystep: float
        Spacing of the y-axis tick marks, starting at `ymin`.
    """
    nn = np.linspace(0.0, nmax, 200)
    m = 2*p - 1          # mean of a single step
    v = 4*p*(1 - p)      # variance of a single step (= 1 - m^2)
    sd = np.sqrt(nn*v)   # standard deviation after nn steps
    ax.fill_between(nn, m*nn + sd, m*nn - sd, **fsAa)
    ax.plot([0.0, nmax], [0.0, m*nmax], **lsAm)
    lcs = [colors['red'], colors['orange'],
           colors['yellow'], colors['green']]
    n = np.arange(0.0, nmax + 1, 1.0)
    for k in range(nwalks):
        x = random_walk(nmax, p, rng)
        # same line style as the mean, but cycle through the colors:
        ls = dict(lsAm, color=lcs[k % len(lcs)])
        ax.plot(n, x, **ls)
    ax.set_xlabel('Iteration $n$')
    ax.set_ylabel('Position $x_n$')
    ax.set_xlim(0, nmax)
    ax.set_ylim(ymin, ymax)
    ax.set_yticks(np.arange(ymin, ymax + 1.0, ystep))

    
if __name__ == "__main__":
    fig, (ax1, ax2) = plt.subplots(2, 1,
                                   figsize=cm_size(figure_width, 2.0*figure_height))
    fig.subplots_adjust(**adjust_fs(fig, right=1.0))
    rng = np.random.RandomState(52281)  # fixed seed for a reproducible figure
    # symmetric random walk (p = 0.5):
    plot_random_walk(ax1, 80, 0.5, rng)
    # place the label at the upper y-limit set by plot_random_walk():
    ax1.text(5.0, ax1.get_ylim()[1], 'symmetric', va='center')
    # random walk with drift (p = 0.6):
    plot_random_walk(ax2, 80, 0.6, rng, -10.0, 30.0)
    ax2.text(5.0, ax2.get_ylim()[1], 'with drift', va='center')
    fig.savefig("randomwalkone.pdf")
    # close exactly this figure rather than whatever is "current":
    plt.close(fig)