From 3a4ece990271fe017607f499816ba881fcb2534d Mon Sep 17 00:00:00 2001 From: Till Raab Date: Mon, 19 Jun 2023 14:33:32 +0200 Subject: [PATCH] savety --- ethogram.py | 95 +++++++++++++++++++++++++++++++----------- event_time_analysis.py | 23 ++++++++-- 2 files changed, 90 insertions(+), 28 deletions(-) diff --git a/ethogram.py b/ethogram.py index 2b4d544..b6b0ee1 100644 --- a/ethogram.py +++ b/ethogram.py @@ -2,6 +2,7 @@ import os import sys import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec +from mpl_toolkits.axes_grid1 import make_axes_locatable import numpy as np import pandas as pd import scipy.stats as scp @@ -13,21 +14,36 @@ glob_colors = ['#BA2D22', '#53379B', '#F47F17', '#3673A4', '#AAB71B', '#DC143C', def plot_transition_matrix(matrix, labels): - fig, ax = plt.subplots() + fig = plt.figure(figsize=(20/2.54, 20/2.54)) + #gs = gridspec.GridSpec(1, 2, left=0.1, bottom=0.1, right=0.9, top=0.95, wspace=0.1, width_ratios=[8, 1]) + gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.925, top=0.95) + ax = fig.add_subplot(gs[0, 0]) + + divider = make_axes_locatable(ax) + cax = divider.append_axes('right', size='5%', pad=0.05) + + # cax = fig.add_subplot(gs[0, 1]) im = ax.imshow(matrix) ax.set_xticks(list(range(len(matrix)))) ax.set_yticks(list(range(len(matrix)))) - ax.set_xticklabels(labels) + ax.set_xticklabels(labels, rotation=45) ax.set_yticklabels(labels) - fig.colorbar(im) - plt.show() + + fig.colorbar(im, cax=cax, orientation='vertical') + + ax.tick_params(labelsize=10) + cax.tick_params(labelsize=10) + plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'event_counts' + '.png'), dpi=300) + plt.close() + + +def plot_transition_diagram(matrix, labels, node_size, ax, threshold=5, + color_by_origin=False, color_by_target=False, title=''): + -def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = 'rdm'): matrix[matrix <= threshold] = 0 matrix = np.around(matrix, decimals=1) - fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) - fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) Graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph) node_labels = dict(zip(Graph, labels)) @@ -42,24 +58,24 @@ def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = ' # ToDo: nodes nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size, ax=ax, alpha=0.5, node_color=np.array(glob_colors)[:len(node_size)]) - nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels) + nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels, ax=ax) # google networkx drawing to get better graphs with networkx # nx.draw(Graph, pos=positions, node_size=node_size, label=labels, with_labels=True, ax=ax) # # ToDo: edges edge_width = np.array([x / 5 for x in [*edge_labels.values()]]) - edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]] + if color_by_origin: + edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]] + elif color_by_target: + edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 1]] + else: + edge_colors = 'k' edge_width[edge_width >= 6] = 6 - # nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width, - # arrows=True, arrowsize=20, - # min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0", - # ax=ax) - # nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.5, edge_labels=edge_labels, ax=ax, rotate=False) nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width, arrows=True, arrowsize=20, - min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.025", + min_target_margin=25, min_source_margin=25, connectionstyle="arc3, rad=0.025", ax=ax, edge_color=edge_colors) nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=True) @@ -70,11 +86,11 @@ def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = ' ax.set_xlim(-1.3, 1.3) ax.set_ylim(-1.3, 1.3) - # plt.title(title) - plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', save_str + '.png'), dpi=300) - plt.close() - + ax.set_title(title, fontsize=12) def main(base_path): + if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')): + os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')) + trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0) chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0) # trial_summary = trial_summary[chirp_notes['good'] == 1] @@ -130,7 +146,9 @@ def main(base_path): event_times = np.array(event_times)[time_sorter] event_labels = np.array(event_labels)[time_sorter] + ### create marcov_matrix 1: which beh 2 is triggered by beh. 1 ? marcov_matrix = np.zeros((len(loop_labels)+1, len(loop_labels)+1)) + ### create marcov_matrix 2: beh 2 is triggered by which beh. 1 ? for enu_ori, label_ori in enumerate(loop_labels): for enu_tar, label_tar in enumerate(loop_labels): @@ -164,14 +182,43 @@ def main(base_path): collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0) collective_event_counts = np.sum(all_event_counts, axis=0) + plot_transition_matrix(collective_marcov_matrix, loop_labels) + + fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) + fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) + plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, - loop_labels, collective_event_counts, threshold=5, save_str='markov_all') + loop_labels, collective_event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]') + plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_destination' + '.png'), dpi=300) + plt.close() + + fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) + fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) + plot_transition_diagram(collective_marcov_matrix / collective_event_counts * 100, + loop_labels, collective_event_counts, ax, threshold=5, color_by_target=True, title='target triggered by origin [%]') + plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_origin' + '.png'), dpi=300) + plt.close() + + for i, (marcov_matrix, event_counts) in enumerate(zip(all_marcov_matrix, all_event_counts)): + fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) + fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) + + plot_transition_diagram( + marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100, + loop_labels, event_counts, ax, threshold=5, color_by_origin=True, + title='origin triggers target [%]') + plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_destination' + '.png'), dpi=300) + plt.close() + + fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) + fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) + plot_transition_diagram(marcov_matrix / event_counts * 100, + loop_labels, event_counts, ax, threshold=5, color_by_target=True, + title='target triggered by origin [%]') + plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_origin' + '.png'), dpi=300) + plt.close() - # for i in range(len(all_marcov_matrix)): - # plot_transition_diagram( - # all_marcov_matrix[i] / all_event_counts[i].reshape(len(all_event_counts[i]), 1) * 100, - # loop_labels, all_event_counts[i], threshold=5) embed() quit() pass diff --git a/event_time_analysis.py b/event_time_analysis.py index 4cf2595..0fcdf2f 100644 --- a/event_time_analysis.py +++ b/event_time_analysis.py @@ -13,9 +13,23 @@ from event_time_correlations import load_and_converete_boris_events, kde, gauss female_color, male_color = '#e74c3c', '#3498db' def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''): + # ToDo: finish this !!! iei = [] + weighted_mean_iei = [] + median_iei = [] for i in range(len(event_times)): - iei.append(np.diff(event_times[i])) + trial_iei = np.diff(event_times[i][event_times[i] <= 3600*3]) + iei.append(trial_iei) + + if len(trial_iei) == 0: + weighted_mean_iei.append(np.nan) + median_iei.append(np.nan) + else: + weighted_mean_iei.append(np.sum((trial_iei) * trial_iei) / np.sum(trial_iei)) + median_iei.append(np.median(trial_iei)) + + weighted_mean_iei = np.array(weighted_mean_iei) + median_iei = np.array(median_iei) fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54)) gs = gridspec.GridSpec(2, 2, left=0.1, bottom=0.1, right=0.95, top=0.9) @@ -41,13 +55,16 @@ def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''): color, linestyle = female_color, '-' sp = 3 - conv_y = np.arange(0, np.percentile(np.hstack(iei), 80), .5) kde_array = kde(iei[i], conv_y, kernal_w=kernal_w, kernal_h=1) # kde_array /= np.sum(kde_array) ax[sp].plot(conv_y, kde_array, zorder=2, color=color, linestyle=linestyle, lw=2) + # ax_m = ax[0].twinx() + # ax_m.boxplot([weighted_mean_iei[(win_sex == 'm') & (win_sex == 'm') & ~np.isnan(weighted_mean_iei)], + # median_iei[(win_sex == 'm') & (win_sex == 'm') & ~np.isnan(median_iei)]], sym='', vert=False) + ax[0].set_xlim(conv_y[0], conv_y[-1]) ax[0].set_ylabel('KDE', fontsize=12) ax[2].set_ylabel('KDE', fontsize=12) @@ -170,7 +187,6 @@ def relative_rate_progression(all_event_t, title=''): def main(base_path): - # ToDo: for chirp and rise analysis different datasets!!! if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta')): os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta')) @@ -178,7 +194,6 @@ def main(base_path): os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr')) - trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0) chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0) trial_mask = chirp_notes['good'] == 1