74 lines
2.7 KiB
Python
74 lines
2.7 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import scipy.io as scio
|
|
from IPython import embed
|
|
|
|
def boltzmann(x, y_max, slope, inflection):
    r"""Evaluate a Boltzmann (logistic) sigmoid.

    .. math::

        f(x) = \frac{y_{max}}{1 + \exp(-\mathrm{slope} \cdot (x - \mathrm{inflection}))}

    At ``x == inflection`` the result is ``y_max / 2``. It tends to ``y_max``
    as ``slope * (x - inflection)`` grows, and to ``0`` as it becomes very
    negative.

    :param x: The x values (a scalar or a NumPy array; broadcasting applies).
    :param y_max: The upper asymptote (maximum value).
    :param slope: The slope (steepness) parameter k.
    :param inflection: The position of the inflection point.
    :return: The y values, broadcast to the shape of the inputs.
    """
    # exp() overflows to inf when -slope*(x - inflection) is large. The result
    # y -> 0 is still the correct limit, so the overflow warning is only noise.
    with np.errstate(over="ignore"):
        y = y_max / (1 + np.exp(-slope * (x - inflection)))
    return y
|
|
|
|
|
|
class Animal(object):
    """A toy model of an animal learning one or more tasks across sessions.

    The expected per-session performance follows a Boltzmann learning curve
    (see ``simulate``). Individual trials add Gaussian noise, and each trial
    is only performed with a fixed probability.
    """

    def __init__(self, delay, learning_rate, volatility, responsiveness):
        """
        :param delay: Session index at which the learning curve reaches its
            inflection point, i.e. half of its maximum improvement.
        :param learning_rate: Controls the steepness of the learning curve;
            ``simulate`` uses ``learning_rate / 20`` as the sigmoid slope.
        :param volatility: 0 -> 1 the noise in the decision. Per-trial
            performance has a standard deviation of ``volatility * 100``
            percentage points.
        :param responsiveness: 0 -> 1 probability of actually conducting a
            trial.
        """
        self.__delay = delay
        self.__learning_rate = learning_rate
        self.__volatility = volatility
        self.__responsiveness = responsiveness

    def simulate(self, session_count=10, trials=20, task_difficulties=None, rng=None):
        """Simulate percent-correct performance over a series of sessions.

        For session ``i`` the learning progress is
        ``p = boltzmann(i, 1.0, learning_rate / 20, delay)``, which lies in
        (0, 1). On a task of difficulty ``d`` the expected percent correct is
        ``50 + 50 * p * (1 - d / 2)``. Each of the ``trials`` trials is
        performed with probability ``responsiveness``. Each performed trial
        scores the expected value plus Gaussian noise with standard deviation
        ``volatility * 100``.

        :param session_count: Number of sessions to simulate.
        :param trials: Number of trials offered per task per session.
        :param task_difficulties: Sequence of difficulty values, one per task.
            ``None`` or an empty sequence means a single task of difficulty 0.
        :param rng: Random source providing ``random(size)`` and
            ``standard_normal(size)``, e.g. ``np.random.default_rng(seed)``.
            Defaults to the global ``np.random`` state, so ``np.random.seed``
            keeps controlling the results.
        :return: ``(avg_perf, err_perf, trials_performed)``.
            ``avg_perf`` and ``err_perf`` have shape
            ``(session_count, n_tasks)``. They hold the mean and standard
            deviation of per-trial performance over the trials actually
            performed, or NaN if none were performed.
            ``trials_performed`` has shape ``(session_count,)``. It holds the
            number of trials performed in each session, summed over tasks.
        """
        if task_difficulties is None or len(task_difficulties) == 0:
            task_difficulties = [0]
        n_tasks = len(task_difficulties)

        # NaN marks (session, task) cells in which no trial was performed.
        avg_perf = np.full((session_count, n_tasks), np.nan)
        err_perf = np.full((session_count, n_tasks), np.nan)
        trials_performed = np.zeros(session_count)

        if rng is None:
            rng = np.random
        slope = self.__learning_rate / 20

        for i in range(session_count):
            # Learning progress depends only on the session, not on the task.
            progress = boltzmann(i, 1.0, slope, self.__delay)
            for j, difficulty in enumerate(task_difficulties):
                penalty = progress * difficulty * 0.5
                base_perf = 50 + 50 * (progress - penalty)

                completed = rng.random(trials) < self.__responsiveness
                performances = rng.standard_normal(trials) * self.__volatility * 100 + base_perf
                done = performances[completed]

                # Avoid NumPy's "mean of empty slice" warning; the cell stays NaN.
                if done.size:
                    avg_perf[i, j] = np.mean(done)
                    err_perf[i, j] = np.std(done)
                trials_performed[i] += np.count_nonzero(completed)

        return avg_perf, err_perf, trials_performed
|
|
|
|
|
|
if __name__ == "__main__":
    # Simulate a small cohort of animals with different learning parameters
    # and plot every animal's per-task learning curves on shared axes.
    session_count = 30
    # NOTE(review): an unused variable here previously listed [0, 0.3, 1.],
    # while the simulation actually ran [0, 0.3, 0.6]; the simulated values
    # are kept. Confirm which set was intended.
    task_difficulties = [0, 0.3, 0.6]

    # One entry per animal; delays and learning_rates must have equal length.
    delays = [5, 10, 12, 20]
    learning_rates = np.array([5, 10, 2, 20])
    n_animals = len(delays)
    volatilities = np.random.rand(n_animals) * 0.5
    responsiveness = np.random.rand(n_animals) * 0.5 + 0.5

    for d, lr, v, r in zip(delays, learning_rates, volatilities, responsiveness):
        a = Animal(d, lr, v, r)
        ap, ep, tp = a.simulate(session_count=session_count,
                                task_difficulties=task_difficulties)
        plt.plot(ap)

    plt.xlabel("session")
    plt.ylabel("performance (%)")

    # Drop into an interactive IPython shell to inspect the results.
    embed()