import argparse
import os

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

from ModelFit import get_best_fit
from Baseline import BaselineModel
from FiCurve import FICurveModel
from CellData import CellData
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus


def main():
    """Summarize all model fits found in a results directory.

    Loads every cell's best fit, then creates cell-vs-model overview plots,
    %-error boxplots, behaviour/parameter correlation matrices and
    value-distribution histograms.
    """
    parser = argparse.ArgumentParser()
    # Optional positional argument; defaults to the previously hard-coded
    # path so that running the script without arguments behaves as before.
    parser.add_argument("dir", nargs="?", default="results/final_1/",
                        help="folder containing the cell folders with the fit results")
    args = parser.parse_args()
    dir_path = args.dir

    if not os.path.isdir(dir_path):
        print("Argument dir is not a directory.")
        parser.print_usage()
        exit(0)

    fits_info = get_fit_info(dir_path)
    total_fits = len(fits_info)

    plot_overview_plus_hist(fits_info)
    print("'good' fits of total fits: {} / {}".format(len(fits_info), total_fits))

    errors = calculate_percent_errors(fits_info)
    create_boxplots(errors)

    labels, corr_values, corrected_p_values = behaviour_correlations(fits_info, model_values=False)
    create_correlation_plot(labels, corr_values, corrected_p_values)

    labels, corr_values, corrected_p_values = parameter_correlations(fits_info)
    create_correlation_plot(labels, corr_values, corrected_p_values)

    create_parameter_distributions(get_parameter_values(fits_info))

    cell_b, model_b = get_behaviour_values(fits_info)
    create_behaviour_distributions(cell_b, model_b)


def get_fit_info(folder):
    """Collect the best fit of every cell folder inside *folder*.

    Returns a dict mapping cell name ->
    [final_parameters, model_behaviour, cell_behaviour], where the two
    behaviour entries are dicts keyed by behaviour name.
    """
    fits_info = {}
    for item in sorted(os.listdir(folder)):
        cell_folder = os.path.join(folder, item)
        results = get_best_fit(cell_folder, use_comparable_error=False)
        cell_behaviour, model_behaviour = results.get_behaviour_values()
        fits_info[item] = [results.get_final_parameters(), model_behaviour, cell_behaviour]
    return fits_info


def calculate_percent_errors(fits_info):
    """Relative error (model - cell) / cell per behaviour, pooled over cells.

    A zero cell (reference) value yields an error of 0 if the model value is
    also zero; otherwise the data point is skipped, because a relative error
    is undefined for a zero reference.
    """
    errors = {}
    for cell in sorted(fits_info.keys()):
        model_behaviour, cell_behaviour = fits_info[cell][1], fits_info[cell][2]
        for behaviour in model_behaviour.keys():
            errors.setdefault(behaviour, [])
            model_value = model_behaviour[behaviour]
            cell_value = cell_behaviour[behaviour]
            if cell_value == 0:
                if model_value == 0:
                    errors[behaviour].append(0)
                else:
                    print("Cannot calc % error if reference is 0")
                continue
            errors[behaviour].append((model_value - cell_value) / cell_value)
    return errors


def get_parameter_values(fits_info):
    """Gather each fitted model parameter's values over all cells.

    Returns a dict: parameter name -> list of values (cells in sorted order).
    """
    par_keys = sorted(["input_scaling", "delta_a", "mem_tau", "noise_strength",
                       "refractory_period", "tau_a", "v_offset", "dend_tau"])
    parameter_values = {par: [] for par in par_keys}
    for cell in sorted(fits_info.keys()):
        final_parameters = fits_info[cell][0]
        for par in par_keys:
            parameter_values[par].append(final_parameters[par])
    return parameter_values


def get_behaviour_values(fits_info):
    """Split fits_info into per-behaviour value lists.

    Returns (cell_values, model_values); both map behaviour name -> list of
    values over all cells (in sorted cell order).
    """
    behaviour_values_cell = {}
    behaviour_values_model = {}
    for cell in sorted(fits_info.keys()):
        model_behaviour, cell_behaviour = fits_info[cell][1], fits_info[cell][2]
        for behaviour in model_behaviour.keys():
            behaviour_values_model.setdefault(behaviour, []).append(model_behaviour[behaviour])
            behaviour_values_cell.setdefault(behaviour, []).append(cell_behaviour[behaviour])
    return behaviour_values_cell, behaviour_values_model


def _correlation_matrix(values):
    """Pairwise Pearson correlations between all keys of *values*.

    Shared implementation for behaviour_correlations / parameter_correlations.
    Returns (sorted labels, correlation matrix, Bonferroni-corrected p-values);
    the correction factor is the number of unique label pairs.
    """
    labels = sorted(values.keys())
    n = len(labels)
    corr_values = np.zeros((n, n))
    p_values = np.ones((n, n))
    for i in range(n):
        for j in range(n):
            c, p = pearsonr(values[labels[i]], values[labels[j]])
            corr_values[i, j] = c
            p_values[i, j] = p
    # sum(range(n)) == n*(n-1)/2 == number of distinct pairs tested
    corrected_p_values = p_values * sum(range(n))
    return labels, corr_values, corrected_p_values


def behaviour_correlations(fits_info, model_values=True):
    """Correlations between behaviours, over model (default) or cell values."""
    bv_cell, bv_model = get_behaviour_values(fits_info)
    behaviour_values = bv_model if model_values else bv_cell
    return _correlation_matrix(behaviour_values)


def parameter_correlations(fits_info):
    """Correlations between the fitted model parameters over all cells."""
    return _correlation_matrix(get_parameter_values(fits_info))


def create_correlation_plot(labels, correlations, p_values):
    """Heat map of *correlations*; entries with p >= 0.05 are masked to 0.

    Every matrix cell is annotated with its (unmasked) coefficient.
    """
    # p-values are non-negative, so a plain comparison suffices
    cleaned_cors = np.where(p_values < 0.05, correlations, 0.0)

    fig, ax = plt.subplots()
    im = ax.imshow(cleaned_cors, vmin=-1, vmax=1)
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Correlation coefficient", rotation=-90, va="bottom")

    # show and label all ticks
    ax.set_xticks(np.arange(len(labels)))
    ax.set_yticks(np.arange(len(labels)))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    # rotate the x tick labels and set their alignment
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # annotate each cell with the raw coefficient
    for i in range(len(labels)):
        for j in range(len(labels)):
            ax.text(j, i, "{:.2f}".format(correlations[i, j]),
                    ha="center", va="center", color="w")

    fig.tight_layout()
    plt.show()


def create_boxplots(errors):
    """Boxplot of the relative errors per behaviour (sample size in label)."""
    keys = sorted(errors.keys())
    labels = ["{}_n:{}".format(k, len(errors[k])) for k in keys]
    for k in keys:
        print("{}: median %-error: {:.2f}".format(k, np.median(errors[k])))
    y_values = [errors[k] for k in keys]

    plt.boxplot(y_values)
    plt.xticks(np.arange(1, len(y_values) + 1, 1), labels, rotation=45)
    plt.tight_layout()
    plt.show()
    plt.close()


def plot_overview_plus_hist(fits_info):
    """One cell-vs-model scatter plot (plus histogram) per behaviour."""
    pairs = {}
    for cell in sorted(fits_info.keys()):
        for behaviour in fits_info[cell][1].keys():
            if behaviour not in pairs:
                pairs[behaviour] = [[], []]
            # index 1: model value, index 0: cell value
            pairs[behaviour][1].append(fits_info[cell][1][behaviour])
            pairs[behaviour][0].append(fits_info[cell][2][behaviour])

    for behaviour in pairs.keys():
        error_overview_with_behaviour_dist(pairs[behaviour][0], pairs[behaviour][1], behaviour)


def error_overview_with_behaviour_dist(x, y, title):
    """Scatter of cell (*x*) vs model (*y*) values with a shared-x histogram on top."""
    # axes layout as fractions of the figure
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    spacing = 0.005

    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom + height + spacing, width, 0.2]

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_axes(rect_scatter)
    ax_histx = fig.add_axes(rect_histx, sharex=ax)
    ax_histy = None  # right-hand histogram currently disabled

    scatter_hist(x, y, ax, ax_histx, ax_histy)
    plt.title(title)
    plt.show()


def scatter_hist(cell_values, model_values, ax, ax_histx, ax_histy):
    """Cell-vs-model scatter with identity line, plus overlaid histograms.

    Adapted from the matplotlib scatter-hist example. *ax_histy* is accepted
    for interface compatibility but is currently unused.
    """
    # Hide the histogram's x tick labels. The axis kwarg must be 'x'/'y'/
    # 'both'; the previous value "cell_values" is rejected by matplotlib.
    ax_histx.tick_params(axis="x", labelbottom=False)

    # the scatter plot with an identity line: a perfect fit lies on the line
    ax.scatter(cell_values, model_values)
    minimum = min(min(cell_values), min(model_values))
    maximum = max(max(cell_values), max(model_values))
    ax.plot((minimum, maximum), (minimum, maximum), color="grey")
    ax.set_xlabel("cell value")
    ax.set_ylabel("model value")

    ax_histx.hist(cell_values, color="blue", alpha=0.5)
    ax_histx.hist(model_values, color="orange", alpha=0.5)


def _padded_range(min_v, max_v):
    """Pad [min_v, max_v] by 5% of each bound's magnitude.

    Replaces the old min*0.95 / max*1.05 padding, which shrank the range for
    negative values; also guards against a degenerate (zero-width) range.
    """
    min_v = min_v - 0.05 * abs(min_v)
    max_v = max_v + 0.05 * abs(max_v)
    if max_v <= min_v:  # all values identical and zero
        max_v = min_v + 1
    return min_v, max_v


def create_parameter_distributions(par_values):
    """Histogram of each of the eight model parameters in a 4x2 grid."""
    fig, axes = plt.subplots(4, 2)

    if len(par_values.keys()) != 8:
        print("not eight parameters")

    labels = sorted(par_values.keys())
    axes_flat = axes.flatten()
    for i, l in enumerate(labels):
        min_v, max_v = _padded_range(min(par_values[l]), max(par_values[l]))
        bins = np.linspace(min_v, max_v, 16)  # 15 equal-width bins
        axes_flat[i].hist(par_values[l], bins=bins)
        axes_flat[i].set_title(l)

    plt.tight_layout()
    plt.show()
    plt.close()


def create_behaviour_distributions(cell_b_values, model_b_values):
    """Overlaid cell/model histograms of each behaviour in a 4x2 grid.

    Values above an upper limit are clipped out of the binning range (and
    reported), because some behaviours can produce extreme outliers.
    """
    fig, axes = plt.subplots(4, 2)

    labels = sorted(cell_b_values.keys())
    axes_flat = axes.flatten()
    for i, l in enumerate(labels):
        min_v, max_v = _padded_range(min(min(cell_b_values[l]), min(model_b_values[l])),
                                     max(max(cell_b_values[l]), max(model_b_values[l])))

        limit = 50000
        if max_v > limit:
            # count excluded values on both sides (the old code only counted
            # cell values, although model values are clipped as well)
            excluded = (np.sum(np.array(cell_b_values[l]) > limit)
                        + np.sum(np.array(model_b_values[l]) > limit))
            print("For {} the max value was limited to {}, {} values were excluded!".format(
                l, limit, excluded))
            max_v = limit

        bins = np.linspace(min_v, max_v, 16)  # 15 equal-width bins
        axes_flat[i].hist(cell_b_values[l], bins=bins, alpha=0.5)
        axes_flat[i].hist(model_b_values[l], bins=bins, alpha=0.5)
        axes_flat[i].set_title(l)

    plt.tight_layout()
    plt.show()
    plt.close()


def sensitivity_analysis(dir_path, par_range=(0.5, 1.6, 0.1), contrast_range=(-0.3, 0.4, 0.1),
                         parameters=None, behaviours=None, max_models=None):
    """Scale each model parameter by factors in *par_range* and record behaviours.

    *par_range* and *contrast_range* are (start, stop, step) triples passed
    to np.arange. At most *max_models* cells are analysed when it is given.
    Plots the results via plot_sensitivity_analysis.
    """
    models = []
    eods = []
    base_freqs = []
    count = 0
    for item in sorted(os.listdir(dir_path)):
        count += 1
        if max_models is not None and count > max_models:
            break
        cell_folder = os.path.join(dir_path, item)
        results = get_best_fit(cell_folder)
        models.append(results.get_model())
        eods.append(CellData(results.get_cell_path()).get_eod_frequency())
        cell, model = results.get_behaviour_values()
        base_freqs.append(cell["baseline_frequency"])

    if parameters is None:
        parameters = ["input_scaling", "delta_a", "mem_tau", "noise_strength",
                      "refractory_period", "tau_a", "dend_tau"]
    if behaviours is None:
        behaviours = ["burstiness", "coefficient_of_variation", "serial_correlation",
                      "vector_strength", "f_inf_slope", "f_zero_slope", "f_zero_middle"]

    contrasts = np.arange(contrast_range[0], contrast_range[1], contrast_range[2])
    factors = np.arange(par_range[0], par_range[1], par_range[2])

    model_behaviour_responses = []
    for model, eod, base_freq in zip(models, eods, base_freqs):
        par_responses = {}
        for par in parameters:
            par_responses[par] = {b: np.zeros(len(factors)) for b in behaviours}

            for i, factor in enumerate(factors):
                model_copy = model.get_model_copy()
                model_copy.set_variable(par, model.get_parameters()[par] * factor)
                print("{} at {}, ({} of {})".format(par, model.get_parameters()[par] * factor,
                                                    i + 1, len(factors)))
                # re-calibrate v_offset so the scaled model keeps the cell's
                # baseline firing rate
                base_stimulus = SinusoidalStepStimulus(eod, 0)
                v_offset = model_copy.find_v_offset(base_freq, base_stimulus)
                model_copy.set_variable("v_offset", v_offset)

                baseline = BaselineModel(model_copy, eod, trials=3)
                print(baseline.get_baseline_frequency())
                if "burstiness" in behaviours:
                    par_responses[par]["burstiness"][i] = baseline.get_burstiness()
                if "coefficient_of_variation" in behaviours:
                    par_responses[par]["coefficient_of_variation"][i] = \
                        baseline.get_coefficient_of_variation()
                if "serial_correlation" in behaviours:
                    par_responses[par]["serial_correlation"][i] = \
                        baseline.get_serial_correlation(1)[0]
                if "vector_strength" in behaviours:
                    par_responses[par]["vector_strength"][i] = baseline.get_vector_strength()

                fi_curve = FICurveModel(model_copy, contrasts, eod, trials=20)
                if "f_inf_slope" in behaviours:
                    par_responses[par]["f_inf_slope"][i] = fi_curve.get_f_inf_slope()
                if "f_zero_slope" in behaviours:
                    par_responses[par]["f_zero_slope"][i] = \
                        fi_curve.get_f_zero_fit_slope_at_straight()
                if "f_zero_middle" in behaviours:
                    par_responses[par]["f_zero_middle"][i] = fi_curve.f_zero_fit[3]

        model_behaviour_responses.append(par_responses)

    print("sensitivity analysis done!")
    plot_sensitivity_analysis(model_behaviour_responses, behaviours, parameters, factors)


def plot_sensitivity_analysis(responses, behaviours, parameters, factors):
    """Grid of line plots: one row per behaviour, one column per parameter.

    Each line is one model's response to scaling that parameter.
    """
    # squeeze=False keeps axes 2-D even for a single behaviour or parameter
    fig, axes = plt.subplots(len(behaviours), len(parameters),
                             sharex="all", sharey="row", figsize=(8, 8), squeeze=False)

    for i, behaviour in enumerate(behaviours):
        for j, par in enumerate(parameters):
            for model in responses:
                axes[i, j].plot(factors, model[par][behaviour])
            if j == 0:
                axes[i, j].set_ylabel("{}".format(behaviour))
            if i == 0:
                axes[i, j].set_title("{}".format(par))

    plt.tight_layout()
    plt.show()
    plt.close()


if __name__ == '__main__':
    main()