"""Figures comparing the fitted models with the recorded cell data.

Creates scatter/histogram comparisons of cell vs. model behaviour,
correlation matrices of behaviours and fitted parameters, parameter
distributions and example ISI histogram fits. All figures are written to
``consts.SAVE_FOLDER``.
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from analysis import (get_filtered_fit_info, get_behaviour_values, get_parameter_values,
                      behaviour_correlations, parameter_correlations)
from ModelFit import get_best_fit
from Baseline import BaselineModel, BaselineCellData
import Figure_constants as consts


parameter_titles = {"input_scaling": r"$\alpha$", "delta_a": r"$\Delta_A$",
                    "mem_tau": r"$\tau_m$", "noise_strength": r"$\sqrt{2D}$",
                    "refractory_period": "$t_{ref}$", "tau_a": r"$\tau_A$",
                    "v_offset": r"$I_{Bias}$", "dend_tau": r"$\tau_{dend}$"}

behaviour_titles = {"baseline_frequency": "base rate", "Burstiness": "Burst",
                    "coefficient_of_variation": "CV", "serial_correlation": "SC",
                    "vector_strength": "VS", "f_inf_slope": r"$f_{\infty}$ Slope",
                    "f_zero_slope": r"$f_0$ Slope", "f_zero_middle": r"$f_0$ middle"}

# Number of equal-width histogram bins used by the histogram figures.
NUM_OF_BINS = 20
# Correlations whose p-value is not below this level are blanked out.
SIGNIFICANCE_LEVEL = 0.05
# Pairs whose cell or model f_zero_slope is >= this value are left out of the
# f_zero_slope comparison plot (the number removed is printed).
F_ZERO_SLOPE_LIMIT = 25000


def main():
    """Load the filtered fits and create the comparison figures."""
    dir_path = "results/final_2/"
    fits_info = get_filtered_fit_info(dir_path)
    print("Cells left:", len(fits_info))

    cell_behaviour, model_behaviour = get_behaviour_values(fits_info)
    plot_cell_model_comp_baseline(cell_behaviour, model_behaviour)
    plot_cell_model_comp_adaption(cell_behaviour, model_behaviour)
    plot_cell_model_comp_burstiness(cell_behaviour, model_behaviour)

    # behaviour_correlations_plot(fits_info)

    labels, corr_values, corrected_p_values = parameter_correlations(fits_info)
    par_labels = [parameter_titles[key] for key in labels]
    fig, ax = plt.subplots(1, 1)
    create_correlation_plot(ax, par_labels, corr_values, corrected_p_values, "")
    plt.savefig(consts.SAVE_FOLDER + "parameter_correlations.png")
    plt.close(fig)

    # Optional figures, enable as needed:
    # create_parameter_distributions(get_parameter_values(fits_info))
    # NOTE(review): calculate_percent_errors is not imported in this module.
    # errors = calculate_percent_errors(fits_info)
    # create_boxplots(errors)
    # example_good_hist_fits(dir_path)


def _equal_width_bins(*value_sets, num_of_bins=NUM_OF_BINS, padding=0.0):
    """Return ``num_of_bins + 1`` equally spaced bin edges covering all values.

    :param value_sets: one or more sequences of numbers; the edges span the
        union of their ranges.
    :param num_of_bins: number of bins.
    :param padding: fraction of each limit's magnitude added outwards, e.g.
        0.05 widens [2, 10] to [1.9, 10.5] and [-10, -2] to [-10.5, -1.9]
        (scaling by 0.95/1.05 would instead cut off negative minima).
    :return: 1-D array of bin edges.
    """
    values = np.concatenate([np.ravel(np.asarray(v, dtype=float)) for v in value_sets])
    minimum = values.min()
    maximum = values.max()
    minimum -= padding * abs(minimum)
    maximum += padding * abs(maximum)
    if minimum == maximum:
        # All values identical: widen the range so the edges still increase.
        minimum -= 0.5
        maximum += 0.5
    return np.linspace(minimum, maximum, num_of_bins + 1)


def create_parameter_distributions(par_values):
    """Plot a histogram of each fitted model parameter in a 4x2 grid.

    :param par_values: dict mapping parameter name -> sequence of fitted
        values; the eight parameters listed in ``labels`` are plotted.
    """
    fig, axes = plt.subplots(4, 2)

    if len(par_values.keys()) != 8:
        print("not eight parameters")

    labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
              "tau_a", "delta_a", "dend_tau", "refractory_period"]

    for ax, key in zip(axes.flatten(), labels):
        # 5 % margin on both sides of the value range.
        bins = _equal_width_bins(par_values[key], num_of_bins=NUM_OF_BINS, padding=0.05)
        ax.hist(par_values[key], bins=bins, color=consts.COLOR_MODEL, alpha=0.75)
        ax.set_title(parameter_titles[key])

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "parameter_distributions.png")
    plt.close(fig)


def behaviour_correlations_plot(fits_info):
    """Plot the behaviour correlation matrices of cells ("Data") and models
    ("Model") side by side, with a colour scale strip underneath."""
    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
    gs = gridspec.GridSpec(2, 2, width_ratios=(1, 1), height_ratios=(5, 1),
                           hspace=0.025, wspace=0.05)

    for col, (model_values, title) in enumerate(((False, "Data"), (True, "Model"))):
        keys, corr_values, corrected_p_values = behaviour_correlations(fits_info,
                                                                       model_values=model_values)
        labels = [behaviour_titles[k] for k in keys]
        # Only the left panel carries the row labels; both panels share them.
        create_correlation_plot(fig.add_subplot(gs[0, col]), labels, corr_values,
                                corrected_p_values, title, y_label=(col == 0))

    # Hand-made horizontal colour bar: an image of evenly spaced values in
    # [-1, 1] drawn with the same colour limits as the correlation matrices.
    ax_col = fig.add_subplot(gs[1, :])
    data = [np.arange(-1, 1.001, 0.01)] * 10
    ax_col.imshow(data, vmin=-1, vmax=1)
    ax_col.set_xticks([0, 25, 50, 75, 100, 125, 150, 175, 200])
    ax_col.set_xticklabels([-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1])
    ax_col.set_yticks([])
    ax_col.set_xlabel("Correlation Coefficients")

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "behaviour_correlations.png")
    plt.close(fig)


def _significance_stars(p_value):
    """Return the star annotation for a p-value ("" when not significant)."""
    if p_value < 0.0001:
        return "***"
    if p_value < 0.001:
        return "**"
    if p_value < SIGNIFICANCE_LEVEL:
        return "*"
    return ""


def create_correlation_plot(ax, labels, correlations, p_values, title, y_label=True):
    """Draw a correlation matrix on ``ax`` showing only significant entries.

    Entries whose |p-value| is not below 0.05 are set to NaN (left blank);
    significant entries are annotated with "*" (p < 0.05), "**" (p < 0.001)
    or "***" (p < 0.0001).

    :param ax: axes to draw on.
    :param labels: tick labels, one per row/column of the matrix.
    :param correlations: square 2-D array of correlation coefficients.
    :param p_values: 2-D array of (corrected) p-values, same shape.
    :param title: axes title.
    :param y_label: show the row labels on the y-axis; when False the y tick
        labels are hidden.
    :return: the image returned by ``ax.imshow`` (usable for a colour bar).
    """
    correlations = np.asarray(correlations, dtype=float)
    p_values = np.asarray(p_values, dtype=float)

    # np.nan (the np.NAN alias no longer exists in NumPy 2) leaves the cell blank.
    cleaned_cors = np.where(np.abs(p_values) < SIGNIFICANCE_LEVEL, correlations, np.nan)
    im = ax.imshow(cleaned_cors, vmin=-1, vmax=1)

    # One tick per matrix row/column, labelled with the given labels.
    ticks = np.arange(len(labels))
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels)
    ax.set_yticks(ticks)
    if y_label:
        ax.set_yticklabels(labels)
    else:
        ax.tick_params(axis="y", labelleft=False)
    ax.set_title(title)

    # Rotate the x tick labels so long names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Star annotations for significant entries (x = column, y = row).
    for i in range(len(labels)):
        for j in range(len(labels)):
            stars = _significance_stars(p_values[i, j])
            if stars:
                ax.text(j, i, stars, ha="center", va="center", color="b")

    return im


def example_good_hist_fits(dir_path):
    """Plot cell vs. model ISI histograms for three hard-coded example cells.

    ISIs are multiplied by the cell's EOD frequency (x-axis "ISI in EOD
    periods") and the histograms are density-normalised.

    :param dir_path: directory containing one fit sub-directory per cell.
    """
    strong_bursty_cell = "2018-05-08-ac-invivo-1"
    bursty_cell = "2014-03-19-ad-invivo-1"
    non_bursty_cell = "2012-12-21-am-invivo-1"

    fig, axes = plt.subplots(1, 3, sharex="all", figsize=consts.FIG_SIZE_MEDIUM_WIDE)

    for ax, cell in zip(axes, (non_bursty_cell, bursty_cell, strong_bursty_cell)):
        fit = get_best_fit(dir_path + cell + "/")
        cell_data = fit.get_cell_data()
        eodf = cell_data.get_eod_frequency()

        baseline_model = BaselineModel(fit.get_model(), eodf, trials=5)
        model_isi = np.array(baseline_model.get_interspike_intervals()) * eodf
        cell_isi = BaselineCellData(cell_data).get_interspike_intervals() * eodf

        # Bin edges 0 .. 0.025 in steps of 0.0001 (presumably seconds -
        # confirm), converted to EOD periods like the ISIs.
        bins = np.arange(0, 0.025, 0.0001) * eodf
        ax.hist(cell_isi, bins=bins, density=True, alpha=0.5, color=consts.COLOR_DATA)
        ax.hist(model_isi, bins=bins, density=True, alpha=0.5, color=consts.COLOR_MODEL)
        ax.set_xlabel("ISI in EOD periods")

    axes[0].set_ylabel("Density")
    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    fig.label_axes()
    plt.savefig(consts.SAVE_FOLDER + "example_good_isi_hist_fits.png", transparent=True)
    plt.close(fig)


def create_boxplots(errors):
    """Print the median percent error per key and show one box per key.

    :param errors: dict mapping a name -> sequence of percent errors.
    """
    keys = sorted(errors.keys())
    for k in keys:
        print("{}: median %-error: {:.2f}".format(k, np.median(errors[k])))

    labels = ["{}_n:{}".format(k, len(errors[k])) for k in keys]
    y_values = [errors[k] for k in keys]

    plt.boxplot(y_values)
    plt.xticks(np.arange(1, len(y_values) + 1, 1), labels, rotation=45)
    plt.tight_layout()
    plt.show()
    plt.close()


def _drop_above_limit(cell, model, limit, name):
    """Drop the pairs whose cell or model value is >= ``limit``.

    Prints how many pairs were removed and returns the filtered arrays.
    """
    cell = np.asarray(cell)
    model = np.asarray(model)
    keep = (cell < limit) & (model < limit)
    print("removed {} values from {} plot.".format(len(cell) - int(np.count_nonzero(keep)), name))
    return cell[keep], model[keep]


def _plot_behaviour_comparison(cell_behavior, model_behaviour, keys, figsize, file_name,
                               value_limits=None):
    """Draw one scatter + marginal-histogram column per behaviour and save it.

    :param cell_behavior: dict behaviour -> cell values.
    :param model_behaviour: dict behaviour -> model values (same order).
    :param keys: behaviours to plot, one column each.
    :param figsize: figure size passed to ``plt.figure``.
    :param file_name: file name inside ``consts.SAVE_FOLDER``.
    :param value_limits: optional dict behaviour -> upper limit; pairs at or
        above the limit are dropped (see ``_drop_above_limit``).
    """
    value_limits = value_limits or {}
    columns = len(keys)

    fig = plt.figure(figsize=figsize)
    # Row 0: marginal histograms, row 1: scatter plots (height ratio 3:7);
    # all columns equally wide.
    gs = fig.add_gridspec(2, columns, width_ratios=[5] * columns, height_ratios=[3, 7],
                          left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.25, hspace=0.05)

    for i, key in enumerate(keys):
        cell = cell_behavior[key]
        model = model_behaviour[key]
        if key in value_limits:
            cell, model = _drop_above_limit(cell, model, value_limits[key], key)

        bins = _equal_width_bins(cell, model, num_of_bins=NUM_OF_BINS)
        ax = fig.add_subplot(gs[1, i])
        ax_histx = fig.add_subplot(gs[0, i], sharex=ax)
        scatter_hist(cell, model, ax, ax_histx, behaviour_titles[key], bins)

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + file_name, transparent=True)
    plt.close(fig)


def plot_cell_model_comp_baseline(cell_behavior, model_behaviour):
    """Compare cell and model baseline rate, vector strength and serial correlation."""
    _plot_behaviour_comparison(cell_behavior, model_behaviour,
                               ("baseline_frequency", "vector_strength", "serial_correlation"),
                               figsize=(12, 6), file_name="fit_baseline_comparison.png")


def plot_cell_model_comp_adaption(cell_behavior, model_behaviour):
    """Compare cell and model f_inf and f_zero slopes.

    f_zero_slope pairs with a value >= F_ZERO_SLOPE_LIMIT are left out.
    """
    _plot_behaviour_comparison(cell_behavior, model_behaviour,
                               ("f_inf_slope", "f_zero_slope"),
                               figsize=(8, 6), file_name="fit_adaption_comparison.png",
                               value_limits={"f_zero_slope": F_ZERO_SLOPE_LIMIT})


def plot_cell_model_comp_burstiness(cell_behavior, model_behaviour):
    """Compare cell and model burstiness and coefficient of variation."""
    _plot_behaviour_comparison(cell_behavior, model_behaviour,
                               ("Burstiness", "coefficient_of_variation"),
                               figsize=(8, 6), file_name="fit_burstiness_comparison.png")


def scatter_hist(cell_values, model_values, ax, ax_histx, behaviour, bins, ax_histy=None):
    """Scatter model vs. cell values with overlaid histograms above.

    Adapted from the matplotlib "scatter plot with histograms" example.

    :param cell_values: cell values (x-axis of the scatter plot).
    :param model_values: model values (y-axis), paired with ``cell_values``.
    :param ax: axes for the scatter plot.
    :param ax_histx: axes above ``ax`` for the cell/model histograms.
    :param behaviour: title shown above the histogram.
    :param bins: histogram bin edges.
    :param ax_histy: unused; kept for backward compatibility.
    """
    # Hide the histogram's x tick labels; the scatter plot below shows them.
    # (matplotlib only accepts axis="x", "y" or "both".)
    ax_histx.tick_params(axis="x", labelbottom=False)

    minimum = min(min(cell_values), min(model_values))
    maximum = max(max(cell_values), max(model_values))
    # Diagonal y = x: model value equal to cell value.
    ax.plot((minimum, maximum), (minimum, maximum), color="grey")
    ax.scatter(cell_values, model_values, color="black")
    ax.set_xlabel("cell")
    ax.set_ylabel("model")

    ax_histx.hist(cell_values, bins=bins, color=consts.COLOR_DATA, alpha=0.75)
    ax_histx.hist(model_values, bins=bins, color=consts.COLOR_MODEL, alpha=0.75)
    # Keep the histogram aligned with the scatter plot even when the caller
    # did not share the x-axis.
    ax_histx.set_xlim(ax.get_xlim())
    ax_histx.set_title(behaviour)


if __name__ == '__main__':
    main()