import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import scipy.io as scio
from IPython import embed


def create_food_blotch(size_x=51, size_y=51, noise=0.01, threshold=0.008, sigma=10.0):
    """Return one Gaussian "food" patch, optionally with additive uniform noise.

    Parameters
    ----------
    size_x, size_y : int
        Number of columns and rows of the patch.
    noise : float
        Amplitude of the uniform noise added to every cell
        (``np.random.rand(...) * noise``).
    threshold : float
        Unused; kept only for backward compatibility of the signature.
    sigma : float
        Standard deviation (in cells) of the isotropic Gaussian.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(size_y, size_x)`` whose noise-free peak value
        is ``1 / (2 * pi * sigma**2)`` at the centre.
    """
    # Exactly size_x / size_y samples centred on zero. Deriving the bounds
    # from np.round(size / 2) gave size + 2 samples under Python 3 division,
    # which does not match the noise array below.
    x = np.arange(size_x) - (size_x - 1) / 2.0
    y = np.arange(size_y) - (size_y - 1) / 2.0
    X, Y = np.meshgrid(x, y)  # default 'xy' indexing -> shape (size_y, size_x)
    # Isotropic bivariate normal density. This is what
    # matplotlib.mlab.bivariate_normal(X, Y, sigma, sigma) computed; that
    # helper was removed in Matplotlib 3.1.
    Z = np.exp(-(X ** 2 + Y ** 2) / (2.0 * sigma ** 2)) / (2.0 * np.pi * sigma ** 2)
    # Noise must share Z's (rows, cols) layout.
    food = (np.random.rand(size_y, size_x) * noise) + Z
    return food


def create_world(width=100, height=100, d=0.1, food_sources=100, noise=0.01):
    """Build a food landscape by adding Gaussian food blotches at random spots.

    Parameters
    ----------
    width, height : float
        Extent of the world. The grid has ``round(height / d)`` rows and
        ``round(width / d)`` columns, i.e. it is indexed ``world[y, x]``.
    d : float
        Size of one grid cell, in the same units as ``width`` / ``height``.
    food_sources : int
        Number of blotches to place. Overlapping blotches add up.
    noise : float
        Uniform-noise amplitude forwarded to ``create_food_blotch``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(rows, cols)``.
    """
    # round() rather than int(): e.g. 0.3 / 0.1 == 2.9999999999999996, which
    # int() would truncate to 2.
    n_rows = int(round(height / d))
    n_cols = int(round(width / d))
    # Axis 0 spans the height (y) and axis 1 the width (x), so the x/y draws
    # below are compared against the axes they actually index. This keeps
    # non-square worlds from producing mis-sized slices.
    my_world = np.zeros((n_rows, n_cols))
    print("placing food sources ...")
    for _ in range(food_sources):
        x = np.random.randint(0, n_cols)
        y = np.random.randint(0, n_rows)
        food = create_food_blotch(noise=noise)
        food_rows, food_cols = food.shape
        # Pull blotches that would overhang the far edge back inside. Clamp at
        # 0 so a small world never gets a negative (wrap-around) start index.
        if x + food_cols > n_cols:
            x = max(x - food_cols, 0)
        if y + food_rows > n_rows:
            y = max(y - food_rows, 0)
        # Basic slicing returns a view, so the in-place add updates my_world.
        # Crop the blotch when the world is smaller than it.
        region = my_world[y:y + food_rows, x:x + food_cols]
        region += food[:region.shape[0], :region.shape[1]]
    return my_world


def get_step():
    """Return a random diagonal unit step as a length-2 integer array.

    Each component is independently -1 or +1. Exactly two scalar draws of
    ``np.random.randint(0, 2)`` are made, so the result is one of
    (-1, -1), (-1, 1), (1, -1) or (1, 1). A zero component never occurs.
    """
    # Map each draw from {0, 1} onto {-1, +1}. The comprehension draws in
    # order, first for component 0 and then for component 1.
    return np.asarray([2 * np.random.randint(0, 2) - 1 for _ in range(2)])


def is_valid_step(pos, step, max_x, max_y):
    """Tell whether ``pos + step`` stays inside the grid.

    Positions are ``[row, col]``. The row must lie in ``[0, max_y)`` and the
    column in ``[0, max_x)``. Returns a plain ``bool``.
    """
    candidate = pos + step
    row, col = candidate[0], candidate[1]
    inside_rows = 0 <= row < max_y
    inside_cols = 0 <= col < max_x
    return bool(inside_rows and inside_cols)


def do_valid_step(pos, max_x, max_y, max_tries=10000):
    """Take one random step from ``pos`` that stays on the grid.

    Steps are drawn with ``get_step`` and rejected until ``is_valid_step``
    accepts one.

    Parameters
    ----------
    pos : array-like
        Current ``[row, col]`` position (callers pass a NumPy array).
    max_x, max_y : int
        Number of columns and rows of the grid (exclusive upper bounds).
    max_tries : int, optional
        Maximum number of draws before giving up.

    Returns
    -------
    tuple
        ``(pos + step, step)`` for the accepted step.

    Raises
    ------
    RuntimeError
        If no valid step was drawn within ``max_tries`` attempts. This happens,
        for example, when the grid is a single row or column and no diagonal
        move fits. An unbounded loop would spin forever in that case.
    """
    for _ in range(max_tries):
        step = get_step()
        if is_valid_step(pos, step, max_x, max_y):
            return pos + step, step
    raise RuntimeError(
        "no valid step found from position %s within %d tries (grid %d x %d)"
        % (pos, max_tries, max_y, max_x))


def random_walk(world, steps=10000):
    """Let a walker take ``steps`` unbiased random steps and eat what it visits.

    The start cell is chosen uniformly at random and is not eaten: the first
    recorded position is the one after the first step. ``world`` is modified
    in place. Each visited cell's value is added to the total and the cell is
    then set to 0.0, so pass a copy if the original grid must be preserved.

    Returns
    -------
    x_positions, y_positions : numpy.ndarray
        Float arrays of length ``steps`` holding the column / row index after
        each step.
    eaten_food
        Sum of the values collected (the integer 0 when ``steps`` is 0).
    """
    n_rows, n_cols = world.shape[0], world.shape[1]
    cols = np.zeros(steps)
    rows = np.zeros(steps)
    # Column is drawn before row, preserving the random-number order.
    start_col = np.random.randint(0, n_cols)
    start_row = np.random.randint(0, n_rows)
    position = np.asarray([start_row, start_col])
    total = 0
    for k in range(steps):
        position, _ = do_valid_step(position, n_cols, n_rows)
        row, col = position[0], position[1]
        cols[k] = col
        rows[k] = row
        total += world[row, col]
        world[row, col] = 0.0
    return cols, rows, total


def not_so_random_walk(world, steps=10000):
    """Walk that keeps its heading while the food found keeps increasing.

    On each step the food eaten at the current cell is compared with the food
    eaten at the previous cell:

    - If it increased and repeating the last step stays on the grid, the
      walker repeats that step.
    - Otherwise it takes a random valid step.

    As in ``random_walk``, the start cell is not eaten. ``world`` is modified
    in place: each visited cell has its own value subtracted, leaving 0 for
    finite values.

    Returns
    -------
    x_positions, y_positions : numpy.ndarray
        Float arrays of length ``steps`` holding the column / row index after
        each step.
    eaten_food
        Sum of the values collected (the integer 0 when ``steps`` is 0).
    """
    n_rows, n_cols = world.shape[0], world.shape[1]
    cols = np.zeros(steps)
    rows = np.zeros(steps)
    # Column is drawn before row, preserving the random-number order.
    start_col = np.random.randint(0, n_cols)
    start_row = np.random.randint(0, n_rows)
    last_meal = 0.0
    this_meal = 0.0
    total = 0
    pos = np.asarray([start_row, start_col])
    heading = None
    for k in range(steps):
        # Keep going straight only if the food gradient is positive and the
        # repeated step is still on the grid. Written as `not (... <= 0)` so
        # a NaN gradient behaves exactly like the `gradient <= 0` test.
        keep_heading = not (this_meal - last_meal <= 0) and (
            heading is None or is_valid_step(pos, heading, n_cols, n_rows))
        if keep_heading:
            pos += heading
        else:
            pos, heading = do_valid_step(pos, n_cols, n_rows)
        row, col = pos[0], pos[1]
        cols[k] = col
        rows[k] = row
        last_meal = this_meal
        this_meal = world[row, col]
        total += this_meal
        world[row, col] -= this_meal
    return cols, rows, total


if __name__ == '__main__':
    # One walk length, used both for the runs and in the report text so the
    # two cannot drift apart. The printed output is unchanged.
    n_steps = 100000
    trials = 10
    print("create world... ")
    world = create_world(noise=0.0, food_sources=50)
    gain_random = np.zeros(trials)
    gain_nsr = np.zeros(trials)
    print("run")
    for i in range(trials):
        # Both walkers eat (mutate) the grid, so each gets its own copy.
        x, y, gain_random[i] = random_walk(world.copy(), n_steps)
        x2, y2, gain_nsr[i] = not_so_random_walk(world.copy(), n_steps)

    print("random walk yields: %.2f +- %.2f food per %d steps"
          % (np.mean(gain_random), np.std(gain_random), n_steps))
    print("not so random walk yields: %.2f +- %.2f food per %d steps"
          % (np.mean(gain_nsr), np.std(gain_nsr), n_steps))

    scio.savemat('random_world.mat', {'world': world})
    # Overlay every second point of the last trial's paths on the uneaten
    # world: red = random walk, green = not-so-random walk.
    plt.imshow(world)
    plt.scatter(x[::2], y[::2], s=0.5, color='red')
    plt.scatter(x2[::2], y2[::2], s=0.5, color='green')
    plt.show()