diff --git a/ethogram.py b/ethogram.py index 72d97bb..25acc34 100644 --- a/ethogram.py +++ b/ethogram.py @@ -2,6 +2,8 @@ import os import sys import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec +from matplotlib.patches import Patch +from matplotlib.lines import Line2D from mpl_toolkits.axes_grid1 import make_axes_locatable import numpy as np import pandas as pd @@ -87,6 +89,7 @@ def plot_transition_diagram(matrix, labels, node_size, ax, threshold=5, ax.set_xlim(-1.3, 1.3) ax.set_ylim(-1.3, 1.3) ax.set_title(title, fontsize=12) + def main(base_path): if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')): os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')) @@ -103,10 +106,20 @@ def main(base_path): # agonistic categorie plot fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54)) - gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.95, top=0.95) - ax = fig.add_subplot(gs[0, 0]) + gs = gridspec.GridSpec(2, 1, left=0.1, bottom=0.1, right=0.9, top=0.95, height_ratios=[1, 4], hspace=0) + ax = fig.add_subplot(gs[1, 0]) + ax_spec = fig.add_subplot(gs[0, 0], sharex=ax) + plt.setp(ax_spec.get_xticklabels(), visible=False) + + for i in range(1, 5): + ax.fill_between([0, 4], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey') + ax.fill_between([5, 10], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey') + + fill_dots = np.arange(4, 5.1, 0.125) + ax.plot(fill_dots, np.ones_like(fill_dots)*i, '.', color='tab:grey', markersize=3) + got_examples = [False, False, False, False] - example_skips = [3, 4, 1, 0] + example_skips = [3, 4, 3, 0] for index, trial in trial_summary.iterrows(): trial_path = os.path.join(base_path, trial['recording']) @@ -209,7 +222,8 @@ def main(base_path): rise_before = True if np.any( ((chase_off_time - chirp_times[1]) < chirp_dt) & ((chirp_times[1] - chase_off_time) < max_dt)): - chirp_time_oi = chirp_times[1][((chase_off_time - chirp_times[1]) < chirp_dt) & ((chirp_times[1] - chase_off_time) < max_dt)] + # chirp_time_oi = chirp_times[1][((chase_off_time - chirp_times[1]) < chirp_dt) & ((chirp_times[1] - chase_off_time) < max_dt)] + chirp_time_oi = chirp_times[1][((chase_off_time - chirp_times[1]) < chase_dur) & ((chirp_times[1] - chase_off_time) < max_dt)] chirp_arround_end = True if rise_before: @@ -227,9 +241,6 @@ def main(base_path): if chase_dur > 10: if np.any((chirp_time_oi - chase_off_time) < 0) and np.any((chirp_time_oi - chase_off_time) > 0): if example_skips[int(agonitic_categorie[enu] - 1)] == 0: - ax.fill_between([0, 10], - np.array([-.2, -.2]) + agonitic_categorie[enu], - np.array([.2, .2]) + agonitic_categorie[enu], color='tab:grey') for ct in chirp_time_oi: ax.plot([ct - chase_off_time + 10, ct - chase_off_time + 10], [agonitic_categorie[enu] - .2, agonitic_categorie[enu] + .2], color='k', lw=2) @@ -239,73 +250,109 @@ def main(base_path): got_examples[0] = True else: example_skips[int(agonitic_categorie[enu] - 1)] -= 1 - elif agonitic_categorie[enu] == 2 and not got_examples[1]: if chase_dur > 10: if example_skips[int(agonitic_categorie[enu] - 1)] == 0: - ax.fill_between([0, 10], - np.array([-.2, -.2]) + agonitic_categorie[enu], - np.array([.2, .2]) + agonitic_categorie[enu], color='tab:grey') for rt in rise_times_oi: ax.plot([rt - chase_on_time, rt - chase_on_time], [agonitic_categorie[enu] - .2, agonitic_categorie[enu] + .2], color='firebrick', lw=2) got_examples[1] = True else: example_skips[int(agonitic_categorie[enu] - 1)] -= 1 - elif agonitic_categorie[enu] == 3 and not got_examples[2]: if chase_dur > 10: if np.any((chirp_time_oi - chase_off_time) < 0) and np.any((chirp_time_oi - chase_off_time) > 0): if example_skips[int(agonitic_categorie[enu] - 1)] == 0: - ax.fill_between([0, 10], - np.array([-.2, -.2]) + agonitic_categorie[enu], - np.array([.2, .2]) + agonitic_categorie[enu], color='tab:grey') for ct in chirp_time_oi: ax.plot([ct - chase_off_time + 10, ct - chase_off_time + 10], [agonitic_categorie[enu] - .2, agonitic_categorie[enu] + .2], color='k', lw=2) got_examples[2] = True else: example_skips[int(agonitic_categorie[enu] - 1)] -= 1 - - - elif agonitic_categorie[enu] == 4 and not got_examples[3]: - if chase_dur > 10: - ax.fill_between([0, 10], - np.array([-.2, -.2]) + agonitic_categorie[enu], - np.array([.2, .2]) + agonitic_categorie[enu], color='tab:grey') - got_examples[3] = True else: pass - for i in range(4): + ### agonistic categories + stacked_agonistic_categories = np.hstack(all_agonistic_categorie) + stacked_all_chase_durs = np.hstack(all_chase_durs) - ax.plot([0, 0], [0, 5], '--', color='k', lw=1) - ax.plot([10, 10], [0, 5], '--', color='k', lw=1) - ax.set_ylim(0.5, 4.5) + pct_each_categorie = np.zeros(4) + for enu, cat in enumerate(range(1, 5)): + pct_each_categorie[enu] = len(stacked_agonistic_categories[stacked_agonistic_categories == cat]) / len(stacked_agonistic_categories) + + # example plot + for enu, cat_pct in enumerate(pct_each_categorie): + ax.text(15.2, enu+1, f'{cat_pct*100:.1f}' + ' $\%$', clip_on=False, fontsize=14, ha='left', va='center') + + + ax.plot([0, 0], [0.8, 5], '--', color='k', lw=1) + ax.plot([10, 10], [0.8, 5], '--', color='k', lw=1) + ax.set_ylim(0.25, 4.5) + ax.set_xlim(-5, 15) + ax.set_yticks([1, 2, 3, 4]) + # ax.set_yticklabels([r'rise$_{pre}$ $&$ chirp$_{end}$', r'only rise$_{pre}$', r'only chirp$_{end}$', 'no communication']) + ax.set_yticklabels(['A ', 'B ', 'C ', 'D ']) + ax.invert_yaxis() + ax.set_xlabel('time [s]', fontsize=12) + ax.tick_params(axis='y', labelsize=20) + ax.tick_params(axis = 'x', labelsize=10) + + legend_elements = [Line2D([0], [0], color='firebrick', lw=2, label=r'rise$_{lose}$'), + Line2D([0], [0], color='k', lw=2, label=r'chirp$_{lose}$'), + Patch(facecolor='tab:grey', edgecolor='w', label= 'chase event')] + + ax.legend(handles=legend_elements, loc='upper right', ncol=3, bbox_to_anchor=(1, 1), frameon=False, fontsize=10, facecolor='white') + ax.spines[['right', 'top']].set_visible(False) plt.show() - embed() - quit() + + # bar plot + fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54)) + ax.bar(np.arange(4), + [len(stacked_agonistic_categories[stacked_agonistic_categories == 1]), + len(stacked_agonistic_categories[stacked_agonistic_categories == 2]), + len(stacked_agonistic_categories[stacked_agonistic_categories == 3]), + len(stacked_agonistic_categories[stacked_agonistic_categories == 4])]) + ax.set_xticks(np.arange(4)) + ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _']) + plt.show() + + # pct + pct_agon_categorie = np.zeros(shape=(len(all_agonistic_categorie), 4)) + for enu, agonitic_categorie in enumerate(all_agonistic_categorie): + for cat in np.arange(4): + pct_agon_categorie[enu, cat] = len(agonitic_categorie[agonitic_categorie == cat+1]) / len(agonitic_categorie) + + fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54)) + ax.bar(np.arange(4), pct_agon_categorie.mean(0)) + ax.errorbar(np.arange(4), pct_agon_categorie.mean(0), yerr=pct_agon_categorie.std(0), fmt='', color='k', linestyle='None') + ax.set_xticks(np.arange(4)) + ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _']) + plt.show() + + + ### marcov models all_marcov_matrix = np.array(all_marcov_matrix) all_event_counts = np.array(all_event_counts) collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0) collective_event_counts = np.sum(all_event_counts, axis=0) - plot_transition_matrix(collective_marcov_matrix, loop_labels) fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) - plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, - loop_labels, collective_event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]') + plot_transition_diagram( + collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, + loop_labels, collective_event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]') plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_destination' + '.png'), dpi=300) plt.close() fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) plot_transition_diagram(collective_marcov_matrix / collective_event_counts * 100, - loop_labels, collective_event_counts, ax, threshold=5, color_by_target=True, title='target triggered by origin [%]') + loop_labels, collective_event_counts, ax, threshold=5, color_by_target=True, + title='target triggered by origin [%]') plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_origin' + '.png'), dpi=300) plt.close() @@ -317,7 +364,8 @@ def main(base_path): marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100, loop_labels, event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]') - plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_destination' + '.png'), dpi=300) + plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_destination' + '.png'), + dpi=300) plt.close() fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) @@ -325,58 +373,12 @@ def main(base_path): plot_transition_diagram(marcov_matrix / event_counts * 100, loop_labels, event_counts, ax, threshold=5, color_by_target=True, title='target triggered by origin [%]') - plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_origin' + '.png'), dpi=300) + plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_origin' + '.png'), + dpi=300) plt.close() embed() quit() - ### agonistic categories - - # stacked - stacked_agonistic_categories = np.hstack(all_agonistic_categorie) - stacked_all_chase_durs = np.hstack(all_chase_durs) - - # idx_cat_4 = np.where(stacked_agonistic_categories == 4)[0] - # idx_cat_4 = idx_cat_4[np.argsort(stacked_all_chase_durs[idx_cat_4])] - # idx_cat_3 = np.where(stacked_agonistic_categories == 3)[0] - # idx_cat_3 = idx_cat_3[np.argsort(stacked_all_chase_durs[idx_cat_3])] - # idx_cat_2 = np.where(stacked_agonistic_categories == 2)[0] - # idx_cat_2 = idx_cat_2[np.argsort(stacked_all_chase_durs[idx_cat_2])] - # idx_cat_1 = np.where(stacked_agonistic_categories == 1)[0] - # idx_cat_1 = idx_cat_1[np.argsort(stacked_all_chase_durs[idx_cat_1])] - # - # fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54)) - # ax.plot(10 - stacked_all_chase_durs[idx_cat_4], np.arange(len(idx_cat_4))) - # ax.plot(10 - stacked_all_chase_durs[idx_cat_3], np.arange(len(idx_cat_4), len(idx_cat_3) + len(idx_cat_4))) - # ax.plot(10 - stacked_all_chase_durs[idx_cat_2], np.arange(len(idx_cat_3) + len(idx_cat_4), len(idx_cat_2) + len(idx_cat_3) + len(idx_cat_4))) - # ax.plot(10 - stacked_all_chase_durs[idx_cat_1], np.arange(len(idx_cat_2) + len(idx_cat_3) + len(idx_cat_4), len(idx_cat_1) + len(idx_cat_2) + len(idx_cat_3) + len(idx_cat_4))) - # ax.set_xlim(0, 10) - # plt.show() - - - fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54)) - ax.bar(np.arange(4), - [len(stacked_agonistic_categories[stacked_agonistic_categories == 1]), - len(stacked_agonistic_categories[stacked_agonistic_categories == 2]), - len(stacked_agonistic_categories[stacked_agonistic_categories == 3]), - len(stacked_agonistic_categories[stacked_agonistic_categories == 4])]) - ax.set_xticks(np.arange(4)) - ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _']) - plt.show() - - # pct - pct_agon_categorie = np.zeros(shape=(len(all_agonistic_categorie), 4)) - for enu, agonitic_categorie in enumerate(all_agonistic_categorie): - for cat in np.arange(4): - pct_agon_categorie[enu, cat] = len(agonitic_categorie[agonitic_categorie == cat+1]) / len(agonitic_categorie) - - fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54)) - ax.bar(np.arange(4), pct_agon_categorie.mean(0)) - ax.errorbar(np.arange(4), pct_agon_categorie.mean(0), yerr=pct_agon_categorie.std(0), fmt='', color='k', linestyle='None') - ax.set_xticks(np.arange(4)) - ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _']) - plt.show() - pass