import numpy as np
import scipy.io as scio
import os
from IPython import embed


def boltzmann(x, y_max, slope, inflection):
    r"""
    Evaluate the Boltzmann (logistic) sigmoid.

    .. math::
        f(x) = \frac{y_{max}}{1 + \exp(-slope \cdot (x - inflection))}

    The curve rises from 0 towards ``y_max`` (for a positive slope) and
    takes the value ``y_max / 2`` at ``x == inflection``.

    :param x: The x values (scalar or numpy array).
    :param y_max: The maximum value, i.e. the upper asymptote.
    :param slope: The slope parameter k.
    :param inflection: The position of the inflection point.
    :return: The y values, same shape as ``x``.
    """
    # Raw docstring above: the formula contains backslashes that would
    # otherwise be parsed as (invalid) string escape sequences.
    return y_max / (1 + np.exp(-slope * (x - inflection)))


class Animal(object):
    """
    A simulated animal whose probability of a correct response follows a
    Boltzmann-shaped learning curve over sessions, rising from chance level
    (0.5) towards 1.0.
    """

    def __init__(self, delay, learning_rate, volatility, responsiveness):
        """
        :param delay: inflection point of the learning curve, in sessions
        :param learning_rate: slope of the learning curve (scaled down per task by its difficulty)
        :param volatility: 0 -> 1 the noise in the decision
        :param responsiveness: 0 -> 1 probability of actually conducting a trial
        """
        self.__delay = delay
        self.__learning_rate = learning_rate
        self.__volatility = volatility
        self.__responsiveness = responsiveness

    def simulate(self, session_count=10, trials=20, task_difficulties=None):
        """
        Simulate the performance of the animal across sessions and tasks.

        :param session_count: number of sessions to simulate
        :param trials: number of trials offered per session and task
        :param task_difficulties: gives a malus on the learning rate, range 0 - 1,
            one entry per task. ``None`` or an empty sequence simulates a single
            task with difficulty 0.
        :return: tuple ``(avg_perf, err_perf, trials_performed)``, each an array of
            shape ``(session_count, number_of_tasks)``. Entries of ``avg_perf`` and
            ``err_perf`` are NaN where no trial was completed.
        """
        # None instead of a mutable [] default; an explicit [] is still accepted.
        if task_difficulties is None or len(task_difficulties) == 0:
            task_difficulties = [0]
        tasks = len(task_difficulties)

        avg_perf = np.zeros((session_count, tasks))
        err_perf = np.zeros((session_count, tasks))
        trials_performed = np.zeros((session_count, tasks))

        for i in range(session_count):
            for j, difficulty in enumerate(task_difficulties):
                ease = 1 - difficulty
                learning_rate = self.__learning_rate * ease
                # Map the 0..1 sigmoid onto 0.5..1, i.e. from chance level to perfect.
                base_performance = boltzmann(i, 1.0, learning_rate, self.__delay) * 0.5 + 0.5
                # The order of the random draws is kept stable so that seeded
                # runs remain reproducible.
                noise = np.random.randn(trials) * (self.__volatility * ease)
                performances = np.random.rand(trials) < (base_performance + noise)
                trials_completed = np.random.rand(trials) < self.__responsiveness

                completed = np.sum(trials_completed)
                trials_performed[i, j] = completed
                if completed == 0:
                    # No completed trial: the performance is undefined. Set NaN
                    # explicitly rather than through a 0/0 division warning.
                    avg_perf[i, j] = np.nan
                    err_perf[i, j] = np.nan
                    continue

                avg_perf[i, j] = np.sum(performances[trials_completed]) / completed
                # NOTE(review): the "/ 100" makes this differ from the binomial
                # standard deviation sqrt(n * p * (1 - p)); kept unchanged --
                # confirm the intended scaling.
                err_perf[i, j] = np.sqrt(completed * (avg_perf[i, j] / 100) * (1 - avg_perf[i, j]))

        return avg_perf, err_perf, trials_performed


def save_performance(avg_perf, err_perf, trials_completed, tasks, animal_id, result_folder="experiment"):
    """
    Write one MATLAB file per session for a single animal.

    Each row of the input arrays is one session. The file for session ``n``
    (1-based) is named ``Animal_<animal_id>_Session_<n>.mat`` and holds the
    variables ``performance``, ``perf_std``, ``trials`` and ``tasks``.

    :param avg_perf: 2-D array, one row per session, one column per task
    :param err_perf: 2-D array of the performance errors, same layout as avg_perf
    :param trials_completed: 2-D array of completed trial counts, same layout as avg_perf
    :param tasks: the task names, stored unchanged in every file
    :param animal_id: integer id used in the file names
    :param result_folder: folder the files are written to; created if missing
    """
    # savemat does not create missing directories, so make sure the target exists.
    if not os.path.isdir(result_folder):
        os.makedirs(result_folder)

    for i in range(avg_perf.shape[0]):
        file_name = "Animal_%i_Session_%i.mat" % (animal_id, i + 1)
        session_data = {
            "performance": avg_perf[i, :],
            "perf_std": err_perf[i, :],
            "trials": trials_completed[i, :],
            "tasks": tasks,
        }
        scio.savemat(os.path.join(result_folder, file_name), session_data)


if __name__ == "__main__":
    # Simulate four animals on three tasks and store every session on disk.
    # Each of the per-animal sequences below holds one entry per animal.
    session_count = [25, 32, 40, 30]
    task_difficulties = [0, 0.75, 0.95]
    delays = [5, 10, 12, 20]
    learning_rates = np.array([0.25, 0.5, 1., 1.5])
    # Random per-animal traits: volatility in [0, 0.25), responsiveness in [0.75, 1).
    volatilities = np.random.rand(4) * 0.25
    responsivness = np.random.rand(4) * 0.25 + 0.75

    animal_parameters = zip(delays, learning_rates, volatilities, responsivness, session_count)
    for animal_id, (delay, rate, volatility, responsiveness, sessions) in enumerate(animal_parameters, start=1):
        animal = Animal(delay, rate, volatility, responsiveness)
        averages, errors, performed = animal.simulate(session_count=sessions,
                                                      task_difficulties=task_difficulties)
        save_performance(averages, errors, performed, ['a', 'b', 'c'], animal_id)