import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
from plotstyle import *

rng = np.random.RandomState(637281)

# Two samples with a weak linear dependence: y is its own standard-normal
# noise plus a fraction `a` of x.
n = 200   # number of (x, y) pairs
a = 0.2   # coupling of y on x
x, noise = rng.randn(n), rng.randn(n)
y = noise + a * x
# Alternative (non-Gaussian) data set:
# x = rng.exponential(1.0, n)
# y = rng.exponential(2.0, n) + a * x

# Measured correlation coefficient = off-diagonal element of the
# 2x2 correlation matrix of x and y.
corr_matrix = np.corrcoef(x, y)
rd = corr_matrix[0, 1]

# Null distribution via permutation: each round shuffles x and then y
# independently, which destroys any pairing between the samples while
# keeping their marginal distributions. The correlation coefficient of
# every shuffled pair of arrays is collected in the list `rs`.
nperm = 1000
rs = [np.corrcoef(rng.permutation(x), rng.permutation(y))[0, 1]
      for _ in range(nperm)]

# Density-normalized histogram of the permuted correlation coefficients:
# `h` holds the 20 density values, `b` the 21 bin edges.
h, b = np.histogram(rs, bins=20, density=True)

# Significance: the measured correlation is significant at the 5% level
# (one-sided) if it exceeds the 95th percentile of the permutation
# (null) distribution.
rq = np.percentile(rs, 95.0)
print('Measured correlation coefficient = %.2f, correlation coefficient at 95th percentile of permutation distribution = %.2f' % (rd, rq))
# One-sided permutation p-value: (approximately) the fraction of permuted
# coefficients larger than the measured one. percentileofscore returns a
# percentage, hence the factor 0.01.
ra = 1.0-0.01*st.percentileofscore(rs, rd)
print('Measured correlation coefficient %.2f has a one-sided permutation p-value of %.4f' % (rd, ra))

# Analytical Pearson test for comparison. Note: pearsonr reports a
# TWO-sided p-value, while the permutation p-value above is one-sided,
# so the two numbers differ by roughly a factor of two.
rp, ra = st.pearsonr(x, y)
print('Measured correlation coefficient %.2f has a two-sided Pearson-test p-value of %.4f' % (rp, ra))

fig, ax = plt.subplots(figsize=cm_size(figure_width, 1.2*figure_height))
fig.subplots_adjust(**adjust_fs(left=4.0, bottom=2.7, right=0.5, top=1.0))
# Label the measured correlation according to the actual test outcome
# rather than always claiming significance.
if rd >= rq:
    rd_label = 'Measured\ncorrelation\nis significant!'
else:
    rd_label = 'Measured\ncorrelation\nis not significant!'
# The annotation positions below are fixed data coordinates, presumably
# tuned by hand for this random seed; revisit them if the data change.
ax.annotate(rd_label,
            xy=(rd, 1.1), xycoords='data',
            xytext=(rd, 2.2), textcoords='data', ha='left',
            arrowprops=dict(arrowstyle="->", relpos=(0.2,0.0),
                connectionstyle="angle3,angleA=10,angleB=80") )
ax.annotate('95% percentile',
            xy=(0.14, 0.9), xycoords='data',
            xytext=(0.2, 4.0), textcoords='data', ha='left',
            arrowprops=dict(arrowstyle="->", relpos=(0.1,0.0),
                connectionstyle="angle3,angleA=30,angleB=70") )
ax.annotate('Distribution of\nuncorrelated\nsamples',
            xy=(-0.08, 3.6), xycoords='data',
            xytext=(-0.22, 5.0), textcoords='data', ha='left',
            arrowprops=dict(arrowstyle="->", relpos=(0.5,0.0),
                connectionstyle="angle3,angleA=150,angleB=100") )
# np.histogram returns bin EDGES, while ax.bar centers each bar on its x
# value by default; plot at the bin centers so the bars cover their bins.
bw = b[1] - b[0]
bc = 0.5*(b[:-1] + b[1:])
# Highlight the bins whose left edge lies at or beyond the 95th percentile.
tail = b[:-1] >= rq
ax.bar(bc, h, width=bw, **fsC)
ax.bar(bc[tail], h[tail], width=bw, **fsB)
# Vertical marker at the measured correlation coefficient.
ax.plot( [rd, rd], [0, 1], **lsA)
ax.set_xlim(-0.25, 0.35)
ax.set_xlabel('Correlation coefficient')
ax.set_ylabel('Probability density of H0')

plt.savefig('permutecorrelation.pdf')