"""Event-transition ("Markov") analysis of chirps, rises, chases and contacts.

For every usable trial below the directory given on the command line this
script

* assigns each chase event to one of four communication categories and draws
  one example per category,
* counts which event type follows which within a short time window
  (transition matrix), and
* renders the transition matrices as an image and as network diagrams that
  are written to ``figures/markov`` next to this file.
"""
import os
import sys

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import pandas as pd
import scipy.stats as scp
import networkx as nx
from IPython import embed

from event_time_correlations import load_and_converete_boris_events

glob_colors = ['#BA2D22', '#53379B', '#F47F17', '#3673A4', '#AAB71B', '#DC143C', '#1E90FF', 'k']


def _markov_figure_path(filename):
    """Return the path of *filename* inside ``figures/markov`` next to this file.

    The directory is created on demand so every plotting function can be used
    on its own.
    """
    figure_dir = os.path.join(os.path.split(__file__)[0], 'figures', 'markov')
    os.makedirs(figure_dir, exist_ok=True)
    return os.path.join(figure_dir, filename)


def plot_transition_matrix(matrix, labels):
    """Save *matrix* as a colour-coded image with a colour bar.

    Parameters
    ----------
    matrix : 2D array-like
        Square matrix to display; ``len(matrix)`` ticks are drawn on both axes.
    labels : sequence of str
        Tick labels, used for both axes.

    The figure is written to ``figures/markov/event_counts.png`` and closed.
    """
    fig = plt.figure(figsize=(20 / 2.54, 20 / 2.54))
    gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.925, top=0.95)
    ax = fig.add_subplot(gs[0, 0])

    # Attach the colour-bar axis to the image axis so both share one height.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)

    im = ax.imshow(matrix)
    tick_positions = list(range(len(matrix)))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(labels, rotation=45)
    ax.set_yticklabels(labels)
    fig.colorbar(im, cax=cax, orientation='vertical')

    ax.tick_params(labelsize=10)
    cax.tick_params(labelsize=10)

    plt.savefig(_markov_figure_path('event_counts' + '.png'), dpi=300)
    plt.close()


def plot_transition_diagram(matrix, labels, node_size, ax, threshold=5, color_by_origin=False,
                            color_by_target=False, title=''):
    """Draw *matrix* as a directed graph on a circular layout.

    Parameters
    ----------
    matrix : 2D numpy array
        Square weight matrix; entry ``[i, j]`` becomes the edge ``i -> j``.
        The array is NOT modified (a copy is thresholded).
    labels : sequence of str
        One label per node, in node order.
    node_size : sequence of numbers
        Marker size per node; its length also selects how many entries of
        ``glob_colors`` are used as node colours.
    ax : matplotlib axes
        Axes to draw into.
    threshold : number
        Entries ``<= threshold`` are dropped (set to 0) before drawing.
    color_by_origin, color_by_target : bool
        Colour every edge like its source node or like its target node.
        ``color_by_origin`` wins if both are set; with neither, edges are black.
    title : str
        Axes title.
    """
    # Work on a copy: the caller's array must not change as a side effect.
    matrix = np.array(matrix)
    # 0/0 from a normalisation yields NaN, which would survive the threshold
    # below and be drawn as an edge with an undefined weight.
    matrix[~np.isfinite(matrix)] = 0
    matrix[matrix <= threshold] = 0
    matrix = np.around(matrix, decimals=1)

    Graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
    node_labels = dict(zip(Graph, labels))
    edge_labels = nx.get_edge_attributes(Graph, 'weight')

    # Node labels sit on a slightly larger circle than the nodes themselves.
    positions = nx.circular_layout(Graph)
    positions2 = nx.circular_layout(Graph)
    for node in positions:
        positions2[node][0] *= 1.2
        positions2[node][1] *= 1.2

    # nodes
    nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size, ax=ax, alpha=0.5,
                           node_color=np.array(glob_colors)[:len(node_size)])
    nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels, ax=ax)

    # edges: width scales with the weight and is capped at 6
    edge_width = np.minimum(np.array(list(edge_labels.values()), dtype=float) / 5, 6)
    # reshape keeps the (n_edges, 2) layout even when no edge survived the threshold
    edge_nodes = np.array(list(edge_labels.keys()), dtype=int).reshape(-1, 2)
    if color_by_origin:
        edge_colors = np.array(glob_colors)[edge_nodes[:, 0]]
    elif color_by_target:
        edge_colors = np.array(glob_colors)[edge_nodes[:, 1]]
    else:
        edge_colors = 'k'

    nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
                           arrows=True, arrowsize=20,
                           min_target_margin=25, min_source_margin=25, connectionstyle="arc3, rad=0.025",
                           ax=ax, edge_color=edge_colors)
    nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=True)

    for side in ("top", "bottom", "right", "left"):
        ax.spines[side].set_visible(False)

    ax.set_xlim(-1.3, 1.3)
    ax.set_ylim(-1.3, 1.3)
    ax.set_title(title, fontsize=12)


def _draw_event_ticks(ax, tick_times, row, color):
    """Draw one short vertical tick per entry of *tick_times* in row *row*."""
    for tick in tick_times:
        ax.plot([tick, tick], [row - .2, row + .2], color=color, lw=2)


def _save_transition_diagrams(matrix, event_counts, labels, file_stem):
    """Save the row- and the column-normalised diagram of one count matrix.

    ``<file_stem>_destination.png`` shows each row divided by the count of its
    origin event, ``<file_stem>_origin.png`` each column divided by the count
    of its target event; both in percent.
    """
    # Event types that never occurred give 0/0; plot_transition_diagram drops
    # the resulting NaN entries, so the warnings carry no information here.
    with np.errstate(divide='ignore', invalid='ignore'):
        pct_of_origin = matrix / event_counts.reshape(len(event_counts), 1) * 100
        pct_of_target = matrix / event_counts * 100

    variants = (
        (pct_of_origin, 'destination', {'color_by_origin': True}, 'origin triggers target [%]'),
        (pct_of_target, 'origin', {'color_by_target': True}, 'target triggered by origin [%]'),
    )
    for pct_matrix, suffix, color_kwargs, title in variants:
        fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
        fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
        plot_transition_diagram(pct_matrix, labels, event_counts, ax, threshold=5, title=title, **color_kwargs)
        plt.savefig(_markov_figure_path(f'{file_stem}_{suffix}' + '.png'), dpi=300)
        plt.close()


def main(base_path):
    """Run the category and transition analysis for all trials in *base_path*.

    *base_path* must contain ``trial_summary.csv`` (columns used here:
    ``recording``, ``group``, ``draw``, ``win_ID``, ``lose_ID``) and
    ``chirp_notes.csv`` (column ``good``), plus one sub-directory per
    recording. Trials are skipped when they are not marked good, belong to a
    group below 5, ended in a draw, or lack the LED or chirp-detection files.

    Raises
    ------
    ValueError
        If not a single trial passes these filters.
    """
    # Window that links two events or an event and a chase
    # (presumably seconds, like the x-axis of the example plot - TODO confirm).
    max_dt = 5

    trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
    chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
    trial_mask = chirp_notes['good'] == 1

    all_marcov_matrix = []
    all_event_counts = []
    all_agonistic_categorie = []
    all_chase_durs = []
    loop_labels = None

    # agonistic categorie plot: one grey bar per category, standing for the
    # first 4 and the last 5 time units of a chase, joined by a dotted gap.
    fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
    gs = gridspec.GridSpec(2, 1, left=0.1, bottom=0.1, right=0.9, top=0.95, height_ratios=[1, 4], hspace=0)
    ax = fig.add_subplot(gs[1, 0])
    ax_spec = fig.add_subplot(gs[0, 0], sharex=ax)
    plt.setp(ax_spec.get_xticklabels(), visible=False)

    fill_dots = np.arange(4, 5.1, 0.125)
    for i in range(1, 5):
        ax.fill_between([0, 4], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey')
        ax.fill_between([5, 10], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey')
        ax.plot(fill_dots, np.ones_like(fill_dots) * i, '.', color='tab:grey', markersize=3)

    # Per category: has an example been drawn yet, and how many suitable
    # chases are passed over before one is drawn.
    got_examples = [False, False, False, False]
    example_skips = [3, 4, 3, 0]

    for index, trial in trial_summary.iterrows():
        trial_path = os.path.join(base_path, trial['recording'])

        if not trial_mask[index]:
            continue
        if trial['group'] < 5:
            continue
        if trial['draw'] == 1:
            continue
        # All file checks happen before anything is loaded.
        required_files = ('led_idxs.csv', 'LED_frames.npy', 'chirp_times_cnn.npy')
        if not all(os.path.exists(os.path.join(trial_path, name)) for name in required_files):
            continue

        ids = np.load(os.path.join(trial_path, 'analysis', 'ids.npy'))
        times = np.load(os.path.join(trial_path, 'times.npy'))
        # Flip the per-fish arrays if the winner is not stored first.
        sorter = -1 if trial['win_ID'] != ids[0] else 1

        ### event times --> BORIS behavior
        contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames = \
            load_and_converete_boris_events(trial_path, trial['recording'], sr=20_000)

        ### communication (index 0: winner, index 1: loser)
        chirp_t = np.load(os.path.join(trial_path, 'chirp_times_cnn.npy'))
        chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
        chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]]

        rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
        rise_idx_int = [np.array(fish_idx[~np.isnan(fish_idx)], dtype=int) for fish_idx in rise_idx]
        rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]]

        loop_times = [chirp_times[1], rise_times[1], chirp_times[0], rise_times[0],
                      ag_on_off_t_GRID[:, 0], ag_on_off_t_GRID[:, 1], contact_t_GRID]
        loop_labels = [r'chirp$_{lose}$', r'rise$_{lose}$', r'chirp$_{win}$', r'rise$_{win}$',
                       r'chace$_{on}$', r'chace$_{off}$', 'contact']
        event_counts = np.array([len(chirp_times[1]), len(rise_times[1]), len(chirp_times[0]),
                                 len(rise_times[0]), len(ag_on_off_t_GRID), len(ag_on_off_t_GRID),
                                 len(contact_t_GRID)])

        # One time-sorted sequence of all events with their type labels.
        event_times = []
        event_labels = []
        for ll, t in zip(loop_labels, loop_times):
            event_times.extend(t)
            event_labels.extend([ll] * len(t))
        time_sorter = np.argsort(event_times)
        event_times = np.array(event_times)[time_sorter]
        event_labels = np.array(event_labels)[time_sorter]

        ### create marcov_matrix for one trial
        # [i, j] counts how often event j directly follows event i within
        # max_dt; the extra last row ("void") counts events whose predecessor
        # is further away than max_dt.
        event_dt = np.diff(event_times)
        prev_labels = event_labels[:-1]
        next_labels = event_labels[1:]
        marcov_matrix = np.zeros((len(loop_labels) + 1, len(loop_labels) + 1))
        for enu_ori, label_ori in enumerate(loop_labels):
            for enu_tar, label_tar in enumerate(loop_labels):
                marcov_matrix[enu_ori, enu_tar] = np.count_nonzero(
                    (prev_labels == label_ori) & (next_labels == label_tar) & (event_dt <= max_dt))
        for enu_tar, label_tar in enumerate(loop_labels):
            marcov_matrix[-1, enu_tar] = np.count_nonzero((next_labels == label_tar) & (event_dt > max_dt))
        # void -> chase-off is cleared; chase-offs are attributed to their chase-on below.
        marcov_matrix[-1, 5] = 0

        loop_labels.append('void')
        event_counts = np.append(event_counts, marcov_matrix[-1].sum())

        ### chases where chase-on is not followed by an event within max_dt and
        ### no event precedes the corresponding chase-off within max_dt are
        ### added as chase-on -> chase-off transitions.
        chase_on_idx = np.where(event_labels == loop_labels[4])[0]
        chase_off_idx = np.where(event_labels == loop_labels[5])[0]
        helper_mask = np.ones_like(chase_on_idx)
        helper_mask[event_dt[chase_on_idx] <= max_dt] = 0
        helper_mask[event_dt[chase_off_idx - 1] <= max_dt] = 0
        marcov_matrix[4, 5] += np.sum(helper_mask)

        all_marcov_matrix.append(marcov_matrix)
        all_event_counts.append(event_counts)

        ### categorise every chase
        # 1: loser rise before AND loser chirp around the end, 2: rise only,
        # 3: chirp only, 4: neither.
        agonitic_categorie = np.zeros(len(ag_on_off_t_GRID))
        chase_durs = np.zeros_like(agonitic_categorie)
        all_agonistic_categorie.append(agonitic_categorie)
        all_chase_durs.append(chase_durs)

        for enu, (chase_on_time, chase_off_time) in enumerate(ag_on_off_t_GRID):
            chase_dur = chase_off_time - chase_on_time
            chase_durs[enu] = chase_dur
            # For short chases the "around the end" window must not start before the chase.
            chirp_dt = chase_dur if chase_dur < max_dt else max_dt

            rise_lead = chase_on_time - rise_times[1]
            rise_times_oi = rise_times[1][(rise_lead > 0) & (rise_lead < max_dt)]
            rise_before = len(rise_times_oi) > 0

            chirp_lead = chase_off_time - chirp_times[1]
            chirp_lag = chirp_times[1] - chase_off_time
            chirp_arround_end = np.any((chirp_lead < chirp_dt) & (chirp_lag < max_dt))
            # The chirps kept for plotting cover the whole chase, not just chirp_dt.
            chirp_time_oi = chirp_times[1][(chirp_lead < chase_dur) & (chirp_lag < max_dt)]

            if rise_before:
                category = 1 if chirp_arround_end else 2
            else:
                category = 3 if chirp_arround_end else 4
            agonitic_categorie[enu] = category

            ### plotting: one example per category (none for category 4)
            slot = category - 1
            if category == 4 or got_examples[slot] or not chase_dur > 10:
                continue
            if category in (1, 3):
                # Examples with chirps need chirps both before and after the chase end.
                chirp_offsets = chirp_time_oi - chase_off_time
                if not (np.any(chirp_offsets < 0) and np.any(chirp_offsets > 0)):
                    continue
            if example_skips[slot] != 0:
                example_skips[slot] -= 1
                continue

            if category in (1, 3):
                # Chirps are aligned to the chase end, which is drawn at x = 10.
                _draw_event_ticks(ax, chirp_time_oi - chase_off_time + 10, category, 'k')
            if category in (1, 2):
                # Rises are aligned to the chase onset at x = 0.
                _draw_event_ticks(ax, rise_times_oi - chase_on_time, category, 'firebrick')
            got_examples[slot] = True

    if not all_marcov_matrix:
        raise ValueError(f'no usable trial found in {base_path!r}')

    ### agonistic categories
    stacked_agonistic_categories = np.hstack(all_agonistic_categorie)
    stacked_all_chase_durs = np.hstack(all_chase_durs)  # kept for interactive inspection in embed()
    category_counts = [int(np.count_nonzero(stacked_agonistic_categories == cat)) for cat in range(1, 5)]
    pct_each_categorie = np.zeros(4)
    if len(stacked_agonistic_categories) > 0:
        pct_each_categorie = np.array(category_counts) / len(stacked_agonistic_categories)

    # example plot
    for enu, cat_pct in enumerate(pct_each_categorie):
        ax.text(15.2, enu + 1, f'{cat_pct * 100:.1f}' + r' $\%$', clip_on=False, fontsize=14, ha='left',
                va='center')
    ax.plot([0, 0], [0.8, 5], '--', color='k', lw=1)
    ax.plot([10, 10], [0.8, 5], '--', color='k', lw=1)
    ax.set_ylim(0.25, 4.5)
    ax.set_xlim(-5, 15)
    ax.set_yticks([1, 2, 3, 4])
    # A: rise_pre & chirp_end, B: only rise_pre, C: only chirp_end, D: no communication
    ax.set_yticklabels(['A ', 'B ', 'C ', 'D '])
    ax.invert_yaxis()
    ax.set_xlabel('time [s]', fontsize=12)
    ax.tick_params(axis='y', labelsize=20)
    ax.tick_params(axis='x', labelsize=10)

    legend_elements = [Line2D([0], [0], color='firebrick', lw=2, label=r'rise$_{lose}$'),
                       Line2D([0], [0], color='k', lw=2, label=r'chirp$_{lose}$'),
                       Patch(facecolor='tab:grey', edgecolor='w', label='chase event')]
    ax.legend(handles=legend_elements, loc='upper right', ncol=3, bbox_to_anchor=(1, 1), frameon=False,
              fontsize=10, facecolor='white')
    ax.spines[['right', 'top']].set_visible(False)
    plt.show()

    category_labels = [r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _']

    # bar plot: absolute number of chases per category, all trials pooled
    fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54))
    ax.bar(np.arange(4), category_counts)
    ax.set_xticks(np.arange(4))
    ax.set_xticklabels(category_labels)
    plt.show()

    # pct: per-trial fraction of each category; trials without any chase have
    # no defined fractions and stay NaN so they do not enter mean and std.
    pct_agon_categorie = np.full((len(all_agonistic_categorie), 4), np.nan)
    for enu, trial_categories in enumerate(all_agonistic_categorie):
        if len(trial_categories) == 0:
            continue
        for cat in range(4):
            pct_agon_categorie[enu, cat] = np.count_nonzero(trial_categories == cat + 1) / len(trial_categories)
    valid_pct = pct_agon_categorie[np.isfinite(pct_agon_categorie).all(axis=1)]
    if len(valid_pct) > 0:
        mean_pct, std_pct = valid_pct.mean(0), valid_pct.std(0)
    else:
        mean_pct, std_pct = np.zeros(4), np.zeros(4)

    fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54))
    ax.bar(np.arange(4), mean_pct)
    ax.errorbar(np.arange(4), mean_pct, yerr=std_pct, fmt='', color='k', linestyle='None')
    ax.set_xticks(np.arange(4))
    ax.set_xticklabels(category_labels)
    plt.show()

    ### marcov models
    all_marcov_matrix = np.array(all_marcov_matrix)
    all_event_counts = np.array(all_event_counts)

    collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0)
    collective_event_counts = np.sum(all_event_counts, axis=0)

    plot_transition_matrix(collective_marcov_matrix, loop_labels)
    _save_transition_diagrams(collective_marcov_matrix, collective_event_counts, loop_labels, 'markov')

    for i, (marcov_matrix, event_counts) in enumerate(zip(all_marcov_matrix, all_event_counts)):
        _save_transition_diagrams(marcov_matrix, event_counts, loop_labels, f'markov_{i}')

    # Debugging hook: inspect the results interactively, then terminate.
    embed()
    sys.exit()


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} <base_path>')
    main(sys.argv[1])