#  ---------------------------------------------------------------------------------------------------------------------
# Name:        Firing Rate and Fourier Script (moving comb repro)
# Purpose:     Takes nixio spike data from moving comb repro and plots firing rate and power spectrum density graph
# Usage:       python3 analysis_graphs.py <cell_name>/<nix_file>
# Author:      Carolin Sachgau, University of Tuebingen
# Created:     20/09/2018
#  ---------------------------------------------------------------------------------------------------------------------

import matplotlib.pyplot as plt
from IPython import embed
import sys
from icr_analysis import *
from open_nixio_new import *

#  Parameters
sampling_rate = 20000  # samples per unit time (presumably Hz); also fixes the kernel time step below
sigma = 0.01  # width (standard deviation) of the Gaussian kernel, on the same time axis as 1/sampling_rate
delay = 1.5  # delay in seconds after comb reaches one end, before commencing movement again

#  The single command-line argument is the path handed to open_nixio_new(). Its second-to-last
#  '/'-separated component is used as the cell name in the output file names, so the path needs at
#  least one directory component. Fail with a readable message instead of a bare IndexError.
if len(sys.argv) < 2:
    sys.exit('usage: python3 ' + sys.argv[0] + ' <cell_name>/<nix_file>')
path_parts = sys.argv[1].split('/')
if len(path_parts) < 2:
    sys.exit("expected a path of the form <cell_name>/<nix_file>, got '" + sys.argv[1] + "'")
cell_name = path_parts[-2]

#  Open Nixio File
#  The loop further down iterates this mapping; its keys are unpacked there as 5-tuples.
intervals_dict = open_nixio_new(sys.argv[1])


#  Kernel Density estimator: gaussian fit
#  Time axis from -4*sigma up to (but not including) +4*sigma, one point every 1/sampling_rate.
#  NOTE(review): np.arange with a non-integer step can vary in length by one element because of
#  floating point rounding - confirm that gaussian_convolve does not rely on an exact kernel length.
t = np.arange(-sigma*4, sigma*4, 1/sampling_rate)
fxn = np.exp(-0.5*(t/sigma)**2) / np.sqrt(2*np.pi) / sigma  # gaussian function


# for (rep, speed, direct, pos, comb) in intervals_dict:
#     spike_train = intervals_dict[rep, speed, direct, pos, comb]
#     avg_convolve_spikes = gaussian_convolve(spike_train, fxn, sampling_rate)
#     p, freq, std_four, mn_four = fourier_psd(avg_convolve_spikes, sampling_rate)
#
#     #  Graphing
#     fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)
#
#     #  Firing Rate Graph
#     firing_times = np.arange(0, len(avg_convolve_spikes))
#     ax1.plot((firing_times / sampling_rate), avg_convolve_spikes)
#     ax1.set_title('Firing Rate of trial ' + str((speed, direct)) + ' comb = ' + str(comb) + '\n')
#     ax1.set_xlabel('Time (s)')
#     ax1.set_ylabel('Firing rate (Hz)')
#
#     #  Fourier Graph
#     ax2.semilogy(freq[freq < 200], p[freq < 200])
#     ax2.axhline(y=(mn_four+std_four), xmin=0, xmax=1, linestyle='--', color='red')
#     # ax2.axvline(x=max_four,linestyle='--', color='green')
#
#     plt.savefig(('nonavg_' + '_' + str(rep) + '_' + str(cell_name) + '_' + str(speed) + '_' + str(pos)
#                   + '_' + str(comb) + '_' + str(direct) + '.png'))
#     plt.close(fig)

#  One figure per entry of intervals_dict: smoothed firing rate on top, power spectrum underneath.
for trial_key, spike_train in intervals_dict.items():
    rep, trial_time, speed, direct, comb = trial_key

    #  Smooth the spike train with the Gaussian kernel, then compute its spectrum
    rate = gaussian_convolve(spike_train, fxn, sampling_rate, trial_time)
    power, freqs, four_std, four_mean = fourier_psd(rate, sampling_rate)

    #  Graphing
    fig, (rate_ax, psd_ax) = plt.subplots(2, 1)

    #  Firing Rate Graph: sample index divided by the sampling rate gives the time axis
    time_axis = np.arange(len(rate)) / sampling_rate
    rate_ax.plot(time_axis, rate)
    rate_ax.set_title('Firing Rate of trial {!s} comb = {!s}\n'.format((speed, direct), comb))
    rate_ax.set_xlabel('Time (s)')
    rate_ax.set_ylabel('Firing rate (Hz)')

    #  Fourier Graph: only frequencies below 400 are drawn, on a logarithmic y axis
    shown = freqs < 400
    psd_ax.semilogy(freqs[shown], power[shown])
    #  Dashed red reference line at the sum of the last two values returned by fourier_psd
    threshold = four_mean + four_std
    psd_ax.axhline(y=threshold, xmin=0, xmax=1, linestyle='--', color='red')
    plt.tight_layout()

    #  NOTE(review): the file name is built from rep, cell name, speed, comb and direction only, so
    #  entries that differ solely in the second key element write to the same file - confirm intended.
    out_name = '{!s}_avg__{!s}_{!s}_{!s}_{!s}.png'.format(rep, cell_name, speed, comb, direct)
    plt.savefig(out_name)
    plt.close(fig)

#  ---------------------------------------------------------------------------------------------------------------------