From e22108cdd758c9cacaf995f57e25c9c8d72801e9 Mon Sep 17 00:00:00 2001
From: Till Raab <till.raab@uni-tuebingen.de>
Date: Fri, 16 Jun 2023 10:52:09 +0200
Subject: [PATCH] add void to ethogram. ethogram is now colorful :D

---
 ethogram.py | 36 +++++++++++++++++++++++-------------
 1 file changed, 23 insertions(+), 13 deletions(-)

diff --git a/ethogram.py b/ethogram.py
index 3f3e59c..77e79a5 100644
--- a/ethogram.py
+++ b/ethogram.py
@@ -9,6 +9,8 @@ import networkx as nx
 
 from IPython import embed
 from event_time_correlations import load_and_converete_boris_events
+glob_colors = ['#BA2D22', '#53379B', '#F47F17', '#3673A4', '#AAB71B', '#DC143C', '#1E90FF']
+
 
 def plot_transition_matrix(matrix, labels):
     fig, ax = plt.subplots()
@@ -21,9 +23,8 @@ def plot_transition_matrix(matrix, labels):
     plt.show()
 
 
-def plot_transition_diagram(matrix, labels, node_size):
-
-    matrix[matrix <= 5] = 0
+def plot_transition_diagram(matrix, labels, node_size, threshold=5):
+    matrix[matrix <= threshold] = 0
     matrix = np.around(matrix, decimals=1)
     fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
     fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
@@ -36,16 +37,20 @@ def plot_transition_diagram(matrix, labels, node_size):
     positions = nx.circular_layout(Graph)
     positions2 = nx.circular_layout(Graph)
     for p in positions:
-        positions2[p][0] -= .1
-        positions2[p][1] -= .1
+        positions2[p][0] *= 1.2
+        positions2[p][1] *= 1.2
 
     # ToDo: nodes
-    nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size*2, ax=ax, alpha=0.5)
-    nx.draw_networkx_labels(Graph, pos=positions, labels=node_labels)
+    nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size, ax=ax, alpha=0.5, node_color=np.array(glob_colors)[:len(node_size)])
+    nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels)
     # google networkx drawing to get better graphs with networkx
     # nx.draw(Graph, pos=positions, node_size=node_size, label=labels, with_labels=True, ax=ax)
     # # ToDo: edges
-    edge_width = [x / 10 for x in [*edge_labels.values()]]
+    edge_width = np.array([x / 5 for x in [*edge_labels.values()]])
+    edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]]
+
+
+    edge_width[edge_width >= 6] = 6
     # nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
     #                        arrows=True, arrowsize=20,
     #                        min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0",
@@ -54,17 +59,18 @@ def plot_transition_diagram(matrix, labels, node_size):
 
     nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
                            arrows=True, arrowsize=20,
-                           min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0",
-                           ax=ax)
-    nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=False)
+                           min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.025",
+                           ax=ax, edge_color=edge_colors)
+    nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=True)
 
     ax.spines["top"].set_visible(False)
     ax.spines["bottom"].set_visible(False)
     ax.spines["right"].set_visible(False)
     ax.spines["left"].set_visible(False)
 
+    ax.set_xlim(-1.3, 1.3)
+    ax.set_ylim(-1.3, 1.3)
     # plt.title(title)
-
     plt.show()
 
 def main(base_path):
@@ -153,8 +159,12 @@ def main(base_path):
     collective_event_counts = np.sum(all_event_counts, axis=0)
 
     plot_transition_matrix(collective_marcov_matrix, loop_labels)
-    plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, loop_labels, collective_event_counts)
+    plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, loop_labels, collective_event_counts, threshold=5)
 
+    # for i in range(len(all_marcov_matrix)):
+    #     plot_transition_diagram(
+    #         all_marcov_matrix[i] / all_event_counts[i].reshape(len(all_event_counts[i]), 1) * 100,
+    #         loop_labels, all_event_counts[i], threshold=5)
     embed()
     quit()
     pass