import pandas as pd
import h5py
import json
import os
import numpy as np
from ast import literal_eval
# todo: run some and test
# todo: develop consistent file structure
#  todo: Model fI
# %% todo: Model fI
# Planned output: one csv per model with the f-I data of every alteration, e.g.
# | (index) | mag | alt   | type  | F     | I     |
# | 0       | -10 | m     | shift | array | array |
models = ['RS_pyramidal', 'RS_inhib', 'FS', 'RS_pyramidal_Kv', 'RS_inhib_Kv', 'FS_Kv', 'Cb_stellate', 'Cb_stellate_Kv',
          'Cb_stellate_Kv_only', 'STN', 'STN_Kv', 'STN_Kv_only']
# Display labels, parallel to `models`. Raw strings keep the LaTeX "\Delta" without an invalid-escape warning.
model_names = ['RS pyramidal', 'RS inhibitory', 'FS', 'RS pyramidal +Kv1.1', 'RS inhibitory +Kv1.1', 'FS +Kv1.1',
               'Cb stellate', 'Cb stellate +Kv1.1', r'Cb stellate $\Delta$Kv1.1', 'STN', 'STN +Kv1.1',
               r'STN $\Delta$Kv1.1']
# for each model get csv file with all shifts, scale and g
# NOTE(review): the extraction below is still commented out, so this loop currently only builds an empty
# frame and the hdf5 path for each model; nothing is read or written yet.
for model_name in models:
    df = pd.DataFrame(columns=['mag', 'alt', 'type', 'F', 'I'])
    folder = '../Neuron_models/{}'.format(model_name)
    fname = os.path.join(folder, "{}.hdf5".format(model_name))
    # for each alt in model (untested draft)
    #     with h5py.File(fname, "r") as f:
    #         df.loc[model_name, 'alt'] = f['data'].attrs['alteration']
    #         test = f['data'].attrs['alteration_info'].replace(' ', ',')
    #         alt_info = literal_eval(test)
    #         var = alt_info[0]
    #         alt_type = alt_info[1]
    #         df.loc[model_name, 'mag'] = var
    #         df.loc[model_name, 'type'] = alt_type
    #         df.loc[model_name, 'F'] = f['analysis']['F_inf'][:]
    #         I_mag = np.arange(f['data'].attrs['I_low'], f['data'].attrs['I_high'],
    #                           (f['data'].attrs['I_high'] - f['data'].attrs['I_low']) / f['data'].attrs['stim_num']) * 1000
    #         df.loc[model_name, 'I'] = I_mag
    #         df.to_csv('./Model_fI/{}.csv'.format(model_name))
# %% todo: rheo/AUC_{}_corr
# Placeholder: write header-only correlation tables for rheobase and AUC until the analysis exists.
# | (index) | model | corr | p_value | g | color |
corr_columns = ['model', 'corr', 'p_value', 'g', 'color']
rheo_corr = pd.DataFrame(columns=corr_columns)
AUC_corr = pd.DataFrame(columns=corr_columns)
for corr_table, csv_name in ((rheo_corr, 'rheo_corr.csv'), (AUC_corr, 'AUC_corr.csv')):
    corr_table.to_csv(csv_name)


# # AUC_shift = pd.DataFrame(columns=['alteration', 'RS Pyramidal','RS Inhibitory','FS','IB',
# #                                   'RS Pyramidal +$K_V1.1$','RS Inhibitory +$K_V1.1$',
# #                 'FS +$K_V1.1$','IB +$K_V1.1$','Cb stellate','Cb stellate +$K_V1.1$',
# #           'Cb stellate $\Delta$$K_V1.1$','STN','STN +$K_V1.1$',
# #           'STN $\Delta$$K_V1.1$'])
# #
# # AUC_slope = pd.DataFrame(columns=['alteration','RS Pyramidal','RS Inhibitory','FS','IB','RS Pyramidal +$K_V1.1$','RS Inhibitory +$K_V1.1$',
# #                 'FS +$K_V1.1$','IB +$K_V1.1$',
# #                                   'Cb stellate','Cb stellate +$K_V1.1$',
# #           'Cb stellate $\Delta$$K_V1.1$','STN','STN +$K_V1.1$',
# #           'STN $\Delta$$K_V1.1$'])
# #
# # AUC_g = pd.DataFrame(columns=['alteration','RS Pyramidal','RS Inhibitory','FS','IB','RS Pyramidal +$K_V1.1$','RS Inhibitory +$K_V1.1$',
# #                 'FS +$K_V1.1$','IB +$K_V1.1$',
# #                               'Cb stellate','Cb stellate +$K_V1.1$',
# #           'Cb stellate $\Delta$$K_V1.1$','STN','STN +$K_V1.1$',
# #           'STN $\Delta$$K_V1.1$'])
# #
# # script_dir = os.path.dirname(os.path.realpath("__file__"))
# # fname = os.path.join(script_dir, )
# # # f = h5py.File(fname, "r")
# #
# # models = ['RS_pyramidal', 'RS_inhib', 'FS', 'IB','Cb_stellate','Cb_stellate_Kv','Cb_stellate_Kv_only','STN','STN_Kv',
# #           'STN_Kv_only']
# # model_labels = ['RS Pyramidal +$K_V1.1$','RS Inhibitory +$K_V1.1$',
# #                 'FS +$K_V1.1$','IB +$K_V1.1$','Cb stellate','Cb stellate +$K_V1.1$',
# #           'Cb stellate $\Delta$$K_V1.1$','STN','STN +$K_V1.1$',
# #           'STN $\Delta$$K_V1.1$']
# # posp_models = ['RS_pyramidal', 'RS_inhib', 'FS', 'IB']
# # posp_model_labels = ['RS Pyramidal','RS Inhibitory', 'FS','IB']
# #
# #
# # shift_interest = 'n'
# # for i in range(len(models)):
# #     with open('./SA_summary_df/{}_shift_AUC_rel_acc.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['0', :])/ data.loc['0', :] # normalize AUC
# #         data.sort_index(inplace=True)
# #         AUC_shift[model_labels[i]] =data[shift_interest]
# # for i in range(len(posp_models)):
# #     with open('./SA_summary_df_pospischil/{}_shift_AUC_rel_acc_pospischil.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['0', :])/ data.loc['0', :] # normalize AUC
# #         data.sort_index(inplace=True)
# #         AUC_shift[posp_model_labels[i]] =data[shift_interest]
# # AUC_shift['alteration'] = AUC_shift.index
# #
# #
# #
# # slope_interest = 's'
# # for i in range(len(models)):
# #     with open('./SA_summary_df/{}_slope_AUC_rel_acc.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['1.0', :])/ data.loc['1.0', :] # normalize AUC
# #         data.sort_index(inplace=True)
# #         try:
# #             AUC_slope[model_labels[i]] = data[slope_interest]
# #         except:
# #             pass
# # for i in range(len(posp_models)):
# #     with open('./SA_summary_df_pospischil/{}_slope_AUC_rel_acc_pospischil.json'.format(models[i])) as json_file:
# #
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['1.0', :])/ data.loc['1.0', :] # normalize AUC
# #         data.sort_index(inplace=True)
# #         try:
# #             AUC_slope[posp_model_labels[i]] =data[slope_interest]
# #         except:
# #             pass
# # AUC_slope['alteration'] = AUC_slope.index
# #
# # g_interest = 'Kd'
# # for i in range(len(models)):
# #     with open('./SA_summary_df/{}_g_AUC_rel_acc.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['1.0', :])/ data.loc['1.0', :] # normalize AUC
# #         data.sort_index(inplace=True)
# #         AUC_g[model_labels[i]] =data[g_interest]
# # for i in range(len(posp_models)):
# #     with open('./SA_summary_df_pospischil/{}_g_AUC_rel_acc_pospischil.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['1.0', :])/ data.loc['1.0', :] # normalize AUC
# #         data.sort_index(inplace=True)
# #         AUC_g[posp_model_labels[i]] =data[g_interest]
# # AUC_g['alteration'] = AUC_g.index
# #
# # AUC_shift.to_csv('AUC_shift_ex.csv')
# # AUC_slope.to_csv('AUC_slope_ex.csv')
# # AUC_g.to_csv('AUC_g_ex.csv')
#
#
# # rheo_shift = pd.DataFrame(columns=['alteration', 'RS Pyramidal', 'RS Inhibitory', 'FS', 'IB',
# #                                    'RS Pyramidal +$K_V1.1$','RS Inhibitory +$K_V1.1$',
# #                 'FS +$K_V1.1$','IB +$K_V1.1$','Cb stellate',
# #                                   'Cb stellate +$K_V1.1$',
# #                                   'Cb stellate $\Delta$$K_V1.1$', 'STN',
# #                                   'STN +$K_V1.1$',
# #                                   'STN $\Delta$$K_V1.1$'])
# #
# # rheo_slope = pd.DataFrame(columns=['alteration', 'RS Pyramidal', 'RS Inhibitory', 'FS', 'IB','RS Pyramidal +$K_V1.1$',
# #                                    'RS Inhibitory +$K_V1.1$',
# #                 'FS +$K_V1.1$','IB +$K_V1.1$',
# #                                    'Cb stellate',
# #                                   'Cb stellate +$K_V1.1$',
# #                                   'Cb stellate $\Delta$$K_V1.1$', 'STN',
# #                                   'STN +$K_V1.1$',
# #                                   'STN $\Delta$$K_V1.1$'])
# #
# # rheo_g = pd.DataFrame(columns=['alteration', 'RS Pyramidal', 'RS Inhibitory', 'FS', 'IB',
# #                                'RS Pyramidal +$K_V1.1$','RS Inhibitory +$K_V1.1$',
# #                 'FS +$K_V1.1$','IB +$K_V1.1$','Cb stellate',
# #                               'Cb stellate +$K_V1.1$',
# #                               'Cb stellate $\Delta$$K_V1.1$', 'STN',
# #                               'STN +$K_V1.1$',
# #                               'STN $\Delta$$K_V1.1$'])
# #
# # script_dir = os.path.dirname(os.path.realpath("__file__"))
# # fname = os.path.join(script_dir, )
# # # f = h5py.File(fname, "r")
# #
# # models = ['RS_pyramidal', 'RS_inhib', 'FS', 'IB', 'Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Kv_only', 'STN',
# #           'STN_Kv',
# #           'STN_Kv_only']
# # model_labels = ['RS Pyramidal +$K_V1.1$', 'RS Inhibitory +$K_V1.1$', 'FS +$K_V1.1$', 'IB +$K_V1.1$', 'Cb stellate',
# #                 'Cb stellate +$K_V1.1$',
# #                 'Cb stellate $\Delta$$K_V1.1$', 'STN',
# #                 'STN +$K_V1.1$',
# #                 'STN $\Delta$$K_V1.1$']
# # posp_models = ['RS_pyramidal', 'RS_inhib', 'FS', 'IB']
# # posp_model_labels = ['RS Pyramidal','RS Inhibitory', 'FS','IB']
# #
# # shift_interest = 's'
# # for i in range(len(models)):
# #     with open('./SA_summary_df/{}_shift_rheo.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['0', :]) #/ data.loc['0', :]  # normalize AUC
# #         data.sort_index(inplace=True)
# #         try:
# #             rheo_shift[model_labels[i]] = data[shift_interest]
# #         except:
# #             pass
# # for i in range(len(posp_models)):
# #     with open('./SA_summary_df_pospischil/{}_shift_rheo_pospischil.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['0', :]) #/ data.loc['0', :]  # normalize AUC
# #         data.sort_index(inplace=True)
# #         try:
# #             rheo_shift[posp_model_labels[i]] = data[shift_interest]
# #         except:
# #             pass
# # rheo_shift['alteration'] = rheo_shift.index
# #
# # slope_interest = 'u'
# # for i in range(len(models)):
# #     with open('./SA_summary_df/{}_slope_rheo.json'.format(models[i])) as json_file:
# #
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['1.0', :]) #/ data.loc['1.0', :]  # normalize AUC
# #         data.sort_index(inplace=True)
# #         try:
# #             rheo_slope[model_labels[i]] = data[slope_interest]
# #         except:
# #             pass
# # for i in range(len(posp_models)):
# #     with open('./SA_summary_df_pospischil/{}_slope_rheo_pospischil.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['1.0', :]) #/ data.loc['1.0', :]  # normalize AUC
# #         data.sort_index(inplace=True)
# #         # if models[i] == 'STN_Kv_only' or models[i] == 'Cb_stellate_Kv_only':
# #         #     data = data.drop(columns=['A'])
# #         try:
# #             rheo_slope[posp_model_labels[i]] = data[slope_interest]
# #         except:
# #             pass
# # rheo_slope['alteration'] = rheo_slope.index
# #
# # g_interest = 'Leak'
# # for i in range(len(models)):
# #     with open('./SA_summary_df/{}_g_rheo.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['1.0', :]) #/ data.loc['1.0', :]  # normalize AUC
# #         data.sort_index(inplace=True)
# #         rheo_g[model_labels[i]] = data[g_interest]
# # for i in range(len(posp_models)):
# #     with open('./SA_summary_df_pospischil/{}_g_rheo_pospischil.json'.format(models[i])) as json_file:
# #         data = pd.read_json(json_file, convert_dates=False, convert_axes=False)
# #         data.replace(0., np.NaN, inplace=True)
# #         data = (data - data.loc['1.0', :]) #/ data.loc['1.0', :]  # normalize AUC
# #         data.sort_index(inplace=True)
# #         rheo_g[posp_model_labels[i]] = data[g_interest]
# # rheo_g['alteration'] = rheo_g.index
# #
# #
# # rheo_shift.to_csv('rheo_shift_ex.csv')
# # rheo_slope.to_csv('rheo_slope_ex.csv')
# # rheo_g.to_csv('rheo_g_ex.csv')
#
# %% firing_values.csv, model_spiking.csv, model_F_inf.csv
import numpy as np
import pandas as pd

# Model identifiers: each names both its folder and its hdf5 file under ../Neuron_models.
models = ['RS_pyramidal', 'RS_inhib', 'FS', 'RS_pyramidal_Kv', 'RS_inhib_Kv', 'FS_Kv', 'Cb_stellate', 'Cb_stellate_Kv',
          'Cb_stellate_Kv_only', 'STN', 'STN_Kv', 'STN_Kv_only']
model_names = ['RS pyramidal', 'RS inhibitory', 'FS', 'RS pyramidal +Kv1.1', 'RS inhibitory +Kv1.1', 'FS +Kv1.1',
               'Cb stellate', 'Cb stellate +Kv1.1', r'Cb stellate $\Delta$Kv1.1', 'STN', 'STN +Kv1.1',
               r'STN $\Delta$Kv1.1']
# firing_values.csv is keyed by the model identifiers.
firing_values = pd.DataFrame(columns=models, index=['spike_ind', 'ramp_up', 'ramp_down'])
# firing_values.loc['spike_ind', :] = np.array(
#     [0.3, 0.04375, 0.25, 0.3, 0.0875, 0.25, 0.3, 0.65, 0.375, 0.125, 0.475, 0.4])

# model_spiking.csv / model_F_inf.csv use shorter column labels. The loop below iterates the model
# identifiers (needed for the file paths and the spike_ind lookup) and maps each to its label here.
column_labels = {'RS_pyramidal': 'RS_pyr', 'RS_pyramidal_Kv': 'RS_pyr_Kv',
                 'RS_inhib': 'RS_inhib', 'RS_inhib_Kv': 'RS_inhib_Kv',
                 'FS': 'FS', 'FS_Kv': 'FS_Kv',
                 'Cb_stellate': 'Cb_stellate', 'Cb_stellate_Kv': 'Cb_stellate_Kv',
                 'Cb_stellate_Kv_only': 'Cb_stellate_Delta_Kv',
                 'STN': 'STN', 'STN_Kv': 'STN_Kv', 'STN_Kv_only': 'STN_Delta_Kv'}
model_F_inf = pd.DataFrame(columns=['I', 'I_inhib'] + list(column_labels.values()))
spiking = pd.DataFrame(columns=['t'] + list(column_labels.values()))

# index (into the stimulus sweep) of the example trace for each model
spike_ind = {'RS_pyramidal': 60, 'RS_inhib': 25, 'FS': 50, 'RS_pyramidal_Kv': 60,
             'RS_inhib_Kv': 50, 'FS_Kv': 50, 'Cb_stellate': 60, 'Cb_stellate_Kv': 130,
             'Cb_stellate_Kv_only': 75, 'STN': 25, 'STN_Kv': 95, 'STN_Kv_only': 80}

for model_name in models:
    label = column_labels[model_name]
    trace_ind = spike_ind[model_name]
    folder = '../Neuron_models/{}'.format(model_name)
    fname = os.path.join(folder, "{}.hdf5".format(model_name))
    # The simulation files are only read here, so open them read-only.
    with h5py.File(fname, "r") as f:
        attrs = f['data'].attrs
        # Stimulus amplitudes of the sweep, scaled by 1000 (presumably a unit conversion — TODO confirm).
        I_mag = np.arange(attrs['I_low'], attrs['I_high'],
                          (attrs['I_high'] - attrs['I_low']) / attrs['stim_num']) * 1000
        # Number of samples in the initial period; the exported trace starts after it.
        start = int(attrs['initial_period'] / attrs['dt'])
        # The first column assignment on these empty frames also sets their row index.
        spiking[label] = f['data']['V_m'][trace_ind][start:]
        model_F_inf[label] = f['analysis']['F_inf'][:]
        firing_values.loc['spike_ind', model_name] = I_mag[trace_ind]
        firing_values.loc['ramp_down', model_name] = f['analysis']['ramp_I_down'][()]
        firing_values.loc['ramp_up', model_name] = f['analysis']['ramp_I_up'][()]
# NOTE(review): spiking['t'] and model_F_inf['I'] / ['I_inhib'] are never filled by this cell.
firing_values.to_csv('firing_values.csv')
spiking.to_csv('model_spiking.csv')
model_F_inf.to_csv('model_F_inf.csv')
# %% model_ramp.csv
# | (index) | t | models .... |
import numpy as np
import pandas as pd

models = ['RS_pyramidal_Kv', 'RS_inhib_Kv', 'FS_Kv', 'Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Kv_only', 'STN',
          'STN_Kv', 'STN_Kv_only']
model_names = ['RS pyramidal', 'RS inhibitory', 'FS', 'Cb stellate', 'Cb stellate +Kv1.1', r'Cb stellate $\Delta$Kv1.1',
               'STN', 'STN +Kv1.1', r'STN $\Delta$Kv1.1']
model_ramp = pd.DataFrame(columns=['t'] + models)
sec = 4  # presumably the ramp duration in s; sec * 1000 / dt gives the number of samples
dt = 0.01
ramp_len = int(sec * 1000 * 1 / dt)
t_ramp = np.arange(0, ramp_len) * dt
# Plain column assignment on the zero-row frame creates its index (length ramp_len);
# .loc[:, 't'] = ... cannot enlarge an empty frame.
model_ramp['t'] = t_ramp

for model_name in models:
    folder = '../Neuron_models/{}'.format(model_name)
    fname = os.path.join(folder, "{}.hdf5".format(model_name))
    # read-only: the ramp voltage trace is only read
    with h5py.File(fname, "r") as f:
        model_ramp[model_name] = f['analysis']['V_m_ramp'][()]

model_ramp.to_csv('model_ramp.csv')

# %% sim_mut_AUC.csv, sim_mut_rheobase.csv
# Generate mutation plot data: rheobase and AUC of every (mutation, model) simulation, expressed as the
# relative change from the same model's 'WT' entry.

with open("../mutations_effects_dict.json") as json_file:
    mutations = json.load(json_file)
# mutations excluded from the plots
keys_to_remove = ['V408L', 'T226R', 'R239S', 'R324T']
for key in keys_to_remove:
    del mutations[key]
mutations_n = list(mutations)  # mutation names as stored in the json
mutations_f = [mut.replace(" ", "_") for mut in mutations_n]  # file-name form: spaces -> underscores

models = ['RS_pyramidal_Kv', 'RS_inhib_Kv', 'FS_Kv', 'Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Kv_only', 'STN',
          'STN_Kv', 'STN_Kv_only']
model_names = ['RS pyramidal', 'RS inhibitory', 'FS', 'Cb stellate', 'Cb stellate +Kv1.1', r'Cb stellate $\Delta$Kv1.1',
               'STN', 'STN +Kv1.1', r'STN $\Delta$Kv1.1']
save_folder = '../KCNA1_mutations'
os.makedirs(save_folder, exist_ok=True)
# rows: mutation (file-name form), columns: model
AUC = pd.DataFrame(index=mutations_f, columns=models, dtype=float)
rheobase = pd.DataFrame(index=mutations_f, columns=models, dtype=float)
for model_name in models:
    folder = os.path.join(save_folder, model_name)
    for mut in mutations_f:
        fname = os.path.join(folder, "{}.hdf5".format(mut))
        # read-only: results are only read from the simulation files
        with h5py.File(fname, "r") as f:
            rheobase.loc[mut, model_name] = f['analysis']['rheobase'][()]
            AUC.loc[mut, model_name] = f['analysis']['AUC'][()]
# Exact zeros are treated as missing (presumably runs without firing — TODO confirm) so they are not
# turned into a -100 % change below.
AUC = AUC.replace(0., np.nan)
rheobase = rheobase.replace(0., np.nan)
# relative change with respect to the wild-type row, per model (column); requires a 'WT' entry
rheobase = (rheobase - rheobase.loc['WT', :]) / rheobase.loc['WT', :]
AUC = (AUC - AUC.loc['WT', :]) / AUC.loc['WT', :]
AUC.to_csv(os.path.join(save_folder, 'sim_mut_AUC.csv'))
rheobase.to_csv(os.path.join(save_folder, 'sim_mut_rheobase.csv'))


#########################################################################################################################
#########################################################################################################################
#########################################################################################################################
#########################################################################################################################
#########################################################################################################################
#########################################################################################################################
#########################################################################################################################
#########################################################################################################################
#%% sim_mut_rheo.csv, sim_mut_AUC.csv
# models = ['RS_pyramidal', 'RS_inhib', 'FS', 'IB','Cb_stellate','Cb_stellate_Kv','Cb_stellate_Kv_only','STN','STN_Kv','STN_Kv_only']
# mutations = json.load(open("./mutations_effects_dict.json"))
# mutations2 = []
# for mut in mutations:
#     mutations2.append(mut.replace(" ", "_"))
# AUC_total = pd.DataFrame(columns=list(mutations2), index=models)
# AUC_rel_total = pd.DataFrame(columns=list(mutations2), index=models)
# rheo_total = pd.DataFrame(columns=list(mutations2), index=models)
# rheo_fit_total = pd.DataFrame(columns=list(mutations2), index=models)
# for mod in models:
#     print(mod)
#     with open('./mut_summary_df/{}_AUC.json'.format(mod)) as json_file:
#         df = pd.read_json(json_file, convert_dates=False, convert_axes=False)
#         # df[mutations2].to_json('./mut_summary_df/{}_AUC.json'.format(mod))
#         AUC_total.loc[mod, :] = df.loc['0',:]
#     with open('./mut_summary_df/{}_AUC_rel.json'.format(mod)) as json_file:
#         df = pd.read_json(json_file, convert_dates=False, convert_axes=False)
#         # df[mutations2].to_json('./mut_summary_df/{}_AUC.json'.format(mod))
#         AUC_rel_total.loc[mod, :] = df.loc['0',:]
#     with open('./mut_summary_df/{}_rheobase.json'.format(mod)) as json_file:
#         df = pd.read_json(json_file, convert_dates=False, convert_axes=False)
#         # df[mutations2].to_json('./mut_summary_df/{}_AUC.json'.format(mod))
#         rheo_total.loc[mod, :] = df.loc['0',:]
#     with open('./mut_summary_df/{}_rheobase_fit.json'.format(mod)) as json_file:
#         df = pd.read_json(json_file, convert_dates=False, convert_axes=False)
#         # df[mutations2].to_json('./mut_summary_df/{}_AUC.json'.format(mod))
#         rheo_fit_total.loc[mod, :] = df.loc['0',:]
#
# # AUC_diff = (AUC_score.subtract(AUC_score['wt'], axis =0))#.divide(AUC_score['wt'], axis=0)
# AUC_total.to_json('mutation_AUC_summary.json')
# AUC_rel_total.to_json('mutation_AUC_rel_summary.json')
# rheo_total.to_json('mutation_rheo_summary.json')
# rheo_fit_total.to_json('mutation_rheo_fit_summary.json')

# models = ['RS_pyramidal', 'RS_inhib', 'FS', 'IB','Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Kv_only', 'STN',
#           'STN_Kv', 'STN_Kv_only']
# model_names = ['RS pyramidal', 'RS inhibitory', 'FS', 'IB', 'Cb stellate', 'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
#                'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN', 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']
#
# all_mut_rheo = pd.DataFrame(columns=[model_names])
# all_mut_AUC = pd.DataFrame(columns=[model_names])
# for mod in range(len(models)):
#     AUC = pd.read_json('./CBZ_summary_df/{}_CBZ_AUC_rel_acc.json'.format(models[mod]), convert_axes=False,
#                        convert_dates=False)
#     rheo = pd.read_json('./CBZ_summary_df/{}_CBZ_rheobase.json'.format(models[mod]), convert_axes=False,
#                         convert_dates=False)
#     AUC.index = AUC.index.astype(float)
#
#     rheo.index = rheo.index.astype(float)
#
#     conc = np.array(AUC.index)
#     mut_names = AUC.columns
#     for m in mut_names:
#         rheo[m] = rheo[m].map(lambda x: x[0])
#         AUC[m] = AUC[m].map(lambda x: x[0])
#     AUC.replace(0., np.NaN, inplace=True)
#     rheo.replace(0., np.NaN, inplace=True)
#     AUC = (AUC - AUC.loc[0, 'wt']) / AUC.loc[0, :]  # normalize AUC
#     AUC.sort_index(inplace=True)
#     rheo = (rheo - rheo.loc[0, 'wt'])#/ rheo.loc[0, :]  # normalize AUC
#     rheo.sort_index(inplace=True)
#
#
#     for mut in mut_names:  # for each mutation
#         x_mut = rheo.loc[:, mut.replace(" ", "_")]
#         y_mut = AUC.loc[:, mut.replace(" ", "_")]
#
#     all_mut_rheo[model_names[mod]] = rheo.loc[0.0, mut_names]
#     all_mut_AUC[model_names[mod]] = AUC.loc[0.0, mut_names]
#
#
#
# all_mut_rheo.to_csv('sim_mut_rheo.csv')
# all_mut_AUC.to_csv('sim_mut_AUC.csv')




# # %% model_F_inf.csv
# # | (index) | I | I_inhib | models ....
#
# # models
# models = ['RS_pyr', 'RS_pyr_Kv', 'RS_inhib', 'RS_inhib_Kv', 'FS', 'FS_Kv',
#           'Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Delta_Kv',
#           'STN', 'STN_Kv', 'STN_Delta_Kv']
# col_names = ['I', 'I_inhib']
# for mod in models: col_names.append(mod)
# model_F_inf = pd.DataFrame(columns=col_names)
# folder = '../Neuron_models'
# with h5py.File(os.path.join(folder, "RS_pyr.hdf5"), "r+") as f:
#     model_F_inf['I'] = np.arange(f['data'].attrs['low'][()], f['data'].attrs['high'][()],
#                                  (f['data'].attrs['high'][()] - f['data'].attrs['low'][()]) /
#                                  f['data'].attrs['stim_num'][()])
# with h5py.File(os.path.join(folder, "RS_inhib.hdf5"), "r+") as f:
#     model_F_inf['I_inhib'] = np.arange(f['data'].attrs['low'][()], f['data'].attrs['high'][()],
#                                        (f['data'].attrs['high'][()] - f['data'].attrs['low'][()]) /
#                                        f['data'].attrs['stim_num'][()])
#
# for model_name in models:
#     fname = os.path.join(folder, "{}.hdf5".format(model_name))  # RS_inhib
#     with h5py.File(fname, "r+") as f:
#         model_F_inf[model_name] = f['analysis']['F_inf'][()]
#
# # save df
# model_F_inf.to_csv('model_F_inf.csv')