diff --git a/ethogram.py b/ethogram.py index 3f3e59c..77e79a5 100644 --- a/ethogram.py +++ b/ethogram.py @@ -9,6 +9,8 @@ import networkx as nx from IPython import embed from event_time_correlations import load_and_converete_boris_events +glob_colors = ['#BA2D22', '#53379B', '#F47F17', '#3673A4', '#AAB71B', '#DC143C', '#1E90FF'] + def plot_transition_matrix(matrix, labels): fig, ax = plt.subplots() @@ -21,9 +23,8 @@ def plot_transition_matrix(matrix, labels): plt.show() -def plot_transition_diagram(matrix, labels, node_size): - - matrix[matrix <= 5] = 0 +def plot_transition_diagram(matrix, labels, node_size, threshold=5): + matrix[matrix <= threshold] = 0 matrix = np.around(matrix, decimals=1) fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) @@ -36,16 +37,20 @@ def plot_transition_diagram(matrix, labels, node_size): positions = nx.circular_layout(Graph) positions2 = nx.circular_layout(Graph) for p in positions: - positions2[p][0] -= .1 - positions2[p][1] -= .1 + positions2[p][0] *= 1.2 + positions2[p][1] *= 1.2 # ToDo: nodes - nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size*2, ax=ax, alpha=0.5) - nx.draw_networkx_labels(Graph, pos=positions, labels=node_labels) + nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size, ax=ax, alpha=0.5, node_color=np.array(glob_colors)[:len(node_size)]) + nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels) # google networkx drawing to get better graphs with networkx # nx.draw(Graph, pos=positions, node_size=node_size, label=labels, with_labels=True, ax=ax) # # ToDo: edges - edge_width = [x / 10 for x in [*edge_labels.values()]] + edge_width = np.array([x / 5 for x in [*edge_labels.values()]]) + edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]] + + + edge_width[edge_width >= 6] = 6 # nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width, # arrows=True, arrowsize=20, # min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0", @@ -54,17 +59,18 @@ def plot_transition_diagram(matrix, labels, node_size): nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width, arrows=True, arrowsize=20, - min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0", - ax=ax) - nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=False) + min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.025", + ax=ax, edge_color=edge_colors) + nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=True) ax.spines["top"].set_visible(False) ax.spines["bottom"].set_visible(False) ax.spines["right"].set_visible(False) ax.spines["left"].set_visible(False) + ax.set_xlim(-1.3, 1.3) + ax.set_ylim(-1.3, 1.3) # plt.title(title) - plt.show() def main(base_path): @@ -153,8 +159,12 @@ def main(base_path): collective_event_counts = np.sum(all_event_counts, axis=0) plot_transition_matrix(collective_marcov_matrix, loop_labels) - plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, loop_labels, collective_event_counts) + plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, loop_labels, collective_event_counts, threshold=5) + # for i in range(len(all_marcov_matrix)): + # plot_transition_diagram( + # all_marcov_matrix[i] / all_event_counts[i].reshape(len(all_event_counts[i]), 1) * 100, + # loop_labels, all_event_counts[i], threshold=5) embed() quit() pass