import numpy as np
from scipy.io import savemat
import matplotlib.pyplot as plt

def lifadaptspikes(stimulus, gain=10.0, trials=50, duration=200.0, before=200.0, after=400.0):
    """Simulate spike trains of a noisy leaky integrate-and-fire neuron with adaptation.

    The neuron is driven by a constant baseline input.  During the stimulus
    the input is raised by ``gain*stimulus - 1.0``; this extra drive decays
    exponentially with a time constant of ``0.4*duration``.  Each trial uses
    fresh white noise and a random initial membrane potential drawn
    uniformly from ``[0, Vth)``.

    All times share the unit of the integration step ``dt = 0.1``
    (presumably milliseconds -- confirm against the callers).

    Parameters
    ----------
    stimulus : float
        Stimulus strength; scaled by `gain` to obtain the input step.
    gain : float
        Factor that converts `stimulus` into input.
    trials : int
        Number of independent trials to simulate.
    duration : float
        Length of the stimulus.
    before : float
        Simulated time preceding the stimulus.
    after : float
        Simulated time following the stimulus.

    Returns
    -------
    list of numpy.ndarray
        One array of spike times per trial.  Times are relative to the
        nominal stimulus onset, i.e. they start at ``-before``.  The input
        itself changes only ``delay = 10.0`` after that onset.
    """
    dt = 0.1        # integration time step
    s0 = 9.5        # baseline input outside of the stimulus
    tau = 10.0      # membrane time constant
    D = 1.0         # noise strength
    taua = 60.0     # adaptation time constant
    da = 100.0      # adaptation strength; every spike adds da/taua to A
    Vth = 10.0      # firing threshold
    delay = 10.0    # lag between nominal stimulus onset and input change

    n = int((duration+before+after)/dt)
    n1 = int((before+delay)/dt)
    n2 = int((before+duration+delay)/dt)

    # Input time course: baseline plus the decaying stimulus response.
    # The delayed stimulus window is clipped to the simulated range, so that
    # a short `after` (less than `delay`) truncates the stimulus instead of
    # raising a shape-mismatch error.
    sig = np.zeros(n) + s0
    stop = min(n2, n)
    start = min(n1, stop)
    sig[start:stop] += (gain * stimulus - 1.0) * np.exp(-np.arange(stop-start)*dt/0.4/duration)

    noise_scale = np.sqrt(2.0*D/dt)
    spikes = []
    for _ in range(trials):
        # Draw order (noise first, then initial V) is kept so that seeded
        # runs reproduce earlier results.
        noise = noise_scale*np.random.randn(n)
        V = np.random.rand()*Vth
        A = 0.0
        times = []
        for k in range(n):
            # Euler step of membrane potential and adaptation variable.
            V += (-V-A+sig[k]+noise[k])*dt/tau
            A += (-A)*dt/taua
            if V >= Vth:
                # Spike: reset the potential and increment adaptation.
                V = 0.0
                A += da/taua
                times.append(k*dt-before)
        spikes.append(np.array(times))
    return spikes

def firingrate(spikes, tmin, tmax):
    """Mean and standard deviation across trials of the spike rate in a window.

    For every trial the spikes with ``tmin <= t <= tmax`` (both bounds
    inclusive) are counted, divided by the window length ``tmax - tmin``
    and multiplied by 1000 (presumably converting spikes per millisecond
    to Hertz -- confirm the time unit of the spike trains).

    Parameters
    ----------
    spikes : sequence of array-like
        One sequence of spike times per trial.
    tmin, tmax : float or int
        Start and end of the analysis window.

    Returns
    -------
    (mean, std) : tuple of float
        ``np.mean`` and ``np.std`` of the per-trial rates.
    """
    # Force a floating-point window length: with integer bounds a plain `/`
    # would be an integer division under Python 2 and yield a rate of zero.
    window = 1.0*(tmax - tmin)
    rates = []
    for st in spikes:
        st = np.asarray(st)     # also accept plain lists of spike times
        count = np.count_nonzero((st >= tmin) & (st <= tmax))
        rates.append(1000.0*(count/window))
    return np.mean(rates), np.std(rates)

"""
nangles = 12
angles = 180.0*np.arange(nangles)/nangles
rates = np.zeros(nangles)
ratessd = np.zeros(nangles)
allspikes = []
for k, angle in enumerate(angles):
    spikes = lifadaptspikes(0.5*(1.0-np.cos(2.0*np.pi*angle/180.0)))
    rm, rsd = firingrate(spikes, 0.0, 200.0)
    rates[k] = rm
    ratessd[k] = rsd
    allspikes.append(spikes)
    #plt.subplot(2, 5, k+1)
    #plt.title('%g' % (angle/2.0/np.pi))
    #plt.eventplot(spikes, colors=['r'])
plt.plot(angles, rates, 'r', lw=2)
plt.plot(angles, rates+ratessd, 'b')
plt.plot(angles, rates-ratessd, 'b')
plt.show()
"""

# Tuning curves: simulate every unit on a fixed grid of stimulus angles and
# save one MAT file per unit ('unit1.mat', 'unit2.mat', ...).
nunits = 6
# Preferred phases spread over [0, 1] with a small random jitter; phase*180
# is the preferred angle in degrees.
# NOTE(review): linspace includes both endpoints and the stimulus below is
# periodic in the phase with period 1, so the first and the last unit get
# (almost) the same preferred angle -- confirm that this is intended.
unitphases = np.linspace(0.0, 1.0, nunits) + 0.05*np.random.randn(nunits)/float(nunits)
# Gains drawn uniformly from [10, 20).
unitgains = 15.0 + 5.0*(2.0*np.random.rand(nunits)-1.0)
nangles = 12
angles = 180.0*np.arange(nangles)/nangles
for unit, (phase, gain) in enumerate(zip(unitphases, unitgains)):
    # Parenthesised single-argument print works on Python 2 and Python 3.
    print('%.1f %.0f' % (gain, phase*180.0))
    allspikes = []
    for k, angle in enumerate(angles):
        # Stimulus strength in [0, 1], maximal half a period away from `phase`.
        spikes = lifadaptspikes(0.5*(1.0-np.cos(2.0*np.pi*(angle/180.0-phase))), gain)
        allspikes.append(spikes)
    # Object array (angles x trials) of spike-time arrays; becomes a cell
    # array in the MAT file.  The builtin `object` replaces the `np.object`
    # alias, which was removed from NumPy.
    spikesobj = np.zeros((len(allspikes), len(allspikes[0])), dtype=object)
    for k, trains in enumerate(allspikes):
        for j, times in enumerate(trains):
            # Scale spike times by 0.001 (presumably ms -> s).
            spikesobj[k, j] = 0.001*times
    savemat('unit%d.mat'%(unit+1), mdict={'angles': angles, 'spikes': spikesobj})

# Population activity: present randomly drawn stimulus angles to all units
# and save one MAT file per stimulus ('population01.mat', ...).
nangles = 50
# Angles drawn uniformly from [0, 180).
angles = 180.0*np.random.rand(nangles)
for k, angle in enumerate(angles):
    # Parenthesised single-argument print works on Python 2 and Python 3.
    print('%.0f' % angle)
    allspikes = []
    for unit, (phase, gain) in enumerate(zip(unitphases, unitgains)):
        # Same stimulus transformation as used for the tuning curves.
        spikes = lifadaptspikes(0.5*(1.0-np.cos(2.0*np.pi*(angle/180.0-phase))), gain)
        allspikes.append(spikes)
    # Object array (units x trials) of spike-time arrays; becomes a cell
    # array in the MAT file.  The builtin `object` replaces the `np.object`
    # alias, which was removed from NumPy.
    spikesobj = np.zeros((len(allspikes), len(allspikes[0])), dtype=object)
    for i, trains in enumerate(allspikes):
        for j, times in enumerate(trains):
            # Scale spike times by 0.001 (presumably ms -> s).
            spikesobj[i, j] = 0.001*times
    savemat('population%02d.mat'%(k+1), mdict={'spikes': spikesobj,
                                               'angle': angle,
                                               'phases': unitphases,
                                               'gains': unitgains})