import matplotlib.pyplot as plt
import numpy as np
import pylab
from IPython import embed
from scipy.optimize import curve_fit
from scipy.optimize import curve_fit
from matplotlib.mlab import specgram
import os

from jar_functions import import_data
from jar_functions import import_amfreq
from jar_functions import sin_response
from jar_functions import mean_noise_cut
from jar_functions import gain_curve_fit

#plt.rcParams.update({'font.size': 10})

def take_second(elem):
    """Return the item at index 1 of *elem*.

    Used as the ``key`` when sorting the file-list entries, so they are
    ordered by their second field (the AM frequency string).
    """
    second_field = elem[1]
    return second_field

# Recording identifiers to process; each needs a '<identifier> files.npy'
# file list in the working directory.
identifier = [
    # 2018...
    '2018lepto1', '2018lepto4', '2018lepto5', '2018lepto76', '2018lepto98',
    # 2019...
    '2019lepto03', '2019lepto24', '2019lepto27', '2019lepto30',
    # 2020...
    '2020lepto04', '2020lepto06', '2020lepto16', '2020lepto19', '2020lepto20',
]
# For each recording identifier:
#   1. load every recording whose AM frequency is one of the plotted values;
#   2. keep the first recording per frequency;
#   3. show the mean-subtracted traces as a stacked figure with one panel
#      per frequency.
# Identifiers missing any of the plotted frequencies are skipped.
plotted_amfreqs = ('1', '0.2', '0.05', '0.01', '0.005', '0.001')

for ident in identifier:

    # First recording found for each AM frequency; these are the ones plotted.
    times = []
    jars = []
    jms = []
    amfreq = []

    # Secondary lists (not used by the plot below). They collect repeat
    # recordings of an already-seen AM frequency plus every 0.001 recording.
    # Each recording is added at most once.
    times1 = []
    jars1 = []
    jms1 = []
    amfreq1 = []

    amf = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1]

    # Entries whose index 1 is the AM frequency string. The sort is
    # lexicographic, which matches numeric order for the plotted values.
    data = sorted(np.load('%s files.npy' %ident), key = take_second)      # list with filenames in it

    for d in data:
        dd = list(d)
        freq = dd[1]
        if freq not in plotted_amfreqs:
            continue

        # NOTE(review): dd is a list, so '%s' formats str(dd) as the file
        # name; presumably this matches how the files were saved — confirm.
        jar = np.load('%s.npy' %dd)     # load data for every file name
        jm = jar - np.mean(jar)         # remove the DC offset (subtract the mean)

        time = np.load('%s time.npy' %dd)       # time file
        dt = time[1] - time[0]

        n = int(1/float(freq)/dt)       # number of samples spanning one period 1/freq
        cutf = mean_noise_cut(jm, n = n)

        is_first = freq not in amfreq
        if is_first:
            print(dd)
            amfreq.append(freq)
            jars.append(jm - cutf)
            jms.append(jm)
            times.append(time)
        else:
            print('1:', dd)

        # Repeats and all 0.001 recordings go to the secondary lists, once each.
        if not is_first or freq == '0.001':
            amfreq1.append(freq)
            jars1.append(jm - cutf)
            jms1.append(jm)
            times1.append(time)

    n_panels = len(plotted_amfreqs)
    if len(jars) != n_panels:       # some plotted frequency is missing for this identifier
        continue

    fig = plt.figure(figsize=(8.27,11.69))      # A4 portrait, in inches
    fig.suptitle(ident)
    fig.text(0.06, 0.5, 'frequency [Hz]', ha='center', va='center', rotation='vertical')
    fig.text(0.5, 0.04, 'time [s]', ha='center', va='center')

    # One stacked panel per AM frequency, in ascending frequency order.
    for panel, (t, trace) in enumerate(zip(times, jms)):
        ax = fig.add_subplot(n_panels, 1, panel + 1)
        ax.plot(t, trace)
        #ax.plot(t, jars[panel])        # optionally overlay jm - cutf
        ax.set_ylim(-12, 12)
        #ax.text(-0.1, 1.05, "ABCDEF"[panel] + ")", fontweight=550, transform=ax.transAxes)

    plt.subplots_adjust(left=0.125,
                        bottom=0.1,
                        right=0.9,
                        top=0.9,
                        wspace=0.2,
                        hspace=0.35)
    plt.show()