import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
from plotstyle import *

# Left panel: standard normal density. The curve is split at x = 0 into two
# shaded halves, each labelled '50%', with arrows pointing at the centre.
xs_norm = np.arange(-3.0, 3.0, 0.01)
pdf_norm = np.exp(-0.5*xs_norm*xs_norm)/np.sqrt(2.0*np.pi)

# One figure with two side-by-side axes; ax2 is filled in further below.
fig, axes = plt.subplots(nrows=1, ncols=2)
ax1, ax2 = axes
fig.subplots_adjust(**adjust_fs(bottom=2.7, top=0.1))

ax1.set(xlabel='x',
        ylabel='Prob. density p(x)',
        ylim=(0.0, 0.46),
        yticks=np.arange(0.0, 0.45, 0.1))

# Area labels, one on each side of the centre line.
for label_x in (-1.0, +1.0):
    ax1.text(label_x, 0.06, '50%', ha='center')

# Arrow annotations: (label, arrow tip, text position, connection style).
normal_notes = (
    ('Median\n= mean', (0.1, 0.3), (1.2, 0.35),
     "angle3,angleA=10,angleB=40"),
    ('Mode', (-0.1, 0.4), (-2.5, 0.43),
     "angle3,angleA=10,angleB=120"),
)
for note, tip, text_pos, connection in normal_notes:
    ax1.annotate(note,
                 xy=tip, xycoords='data',
                 xytext=text_pos, textcoords='data', ha='left',
                 arrowprops=dict(arrowstyle="->", relpos=(0.0, 0.2),
                                 connectionstyle=connection))

# Shade the two halves of the density, then draw the curve on top and a
# vertical marker line at x = 0.
below = xs_norm < 0
above = xs_norm > 0
ax1.fill_between(xs_norm[below], 0.0, pdf_norm[below], **fsCs)
ax1.fill_between(xs_norm[above], 0.0, pdf_norm[above], **fsFs)
ax1.plot(xs_norm, pdf_norm, **lsA)
ax1.plot([0.0, 0.0], [0.0, 0.45], **lsMarker)

# Right panel: gamma density with shape parameter 2. The curve is split at
# the median into two shaded parts, each labelled '50%'; median, mean and
# mode are marked with separate arrows.
shape = 2.0
gamma_dist = st.gamma(shape)  # frozen distribution with the shape fixed
xs_gamma = np.arange(0.0, 6.0, 0.01)
pdf_gamma = gamma_dist.pdf(xs_gamma)
median = gamma_dist.median()
mean = gamma_dist.mean()

ax2.set(xlabel='x',
        ylabel='Prob. density p(x)',
        ylim=(0.0, 0.46),
        yticks=np.arange(0.0, 0.45, 0.1))

# Area labels, placed relative to the median.
for offset in (-0.8, +1.2):
    ax2.text(median + offset, 0.06, '50%', ha='center')

# Arrow annotations: (label, arrow tip, text position, connection style).
gamma_notes = (
    ('Median', (median + 0.1, 0.2), (median + 1.6, 0.25),
     "angle3,angleA=30,angleB=70"),
    ('Mean', (mean, 0.01), (mean + 1.8, 0.15),
     "angle3,angleA=0,angleB=90"),
    ('Mode', (1.0, 0.38), (1.8, 0.42),
     "angle3,angleA=0,angleB=70"),
)
for note, tip, text_pos, connection in gamma_notes:
    ax2.annotate(note,
                 xy=tip, xycoords='data',
                 xytext=text_pos, textcoords='data', ha='left',
                 arrowprops=dict(arrowstyle="->", relpos=(0.0, 0.5),
                                 connectionstyle=connection))

# Shade the parts below and above the median, then draw the curve on top
# and a vertical marker line at the median.
below_median = xs_gamma < median
above_median = xs_gamma > median
ax2.fill_between(xs_gamma[below_median], 0.0, pdf_gamma[below_median], **fsCs)
ax2.fill_between(xs_gamma[above_median], 0.0, pdf_gamma[above_median], **fsFs)
ax2.plot(xs_gamma, pdf_gamma, **lsA)
ax2.plot([median, median], [0.0, 0.38], **lsMarker)
# No marker line is drawn at the mean; only the arrow annotation points to it.

fig.savefig('median.pdf')