diff --git a/ethogram.py b/ethogram.py new file mode 100644 index 0000000..3f3e59c --- /dev/null +++ b/ethogram.py @@ -0,0 +1,164 @@ +import os +import sys +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import numpy as np +import pandas as pd +import scipy.stats as scp +import networkx as nx + +from IPython import embed +from event_time_correlations import load_and_converete_boris_events + +def plot_transition_matrix(matrix, labels): + fig, ax = plt.subplots() + im = ax.imshow(matrix) + ax.set_xticks(list(range(len(matrix)))) + ax.set_yticks(list(range(len(matrix)))) + ax.set_xticklabels(labels) + ax.set_yticklabels(labels) + fig.colorbar(im) + plt.show() + + +def plot_transition_diagram(matrix, labels, node_size): + + matrix[matrix <= 5] = 0 + matrix = np.around(matrix, decimals=1) + fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) + fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) + Graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph) + + node_labels = dict(zip(Graph, labels)) + # Graph = nx.relabel_nodes(Graph, node_labels) + + edge_labels = nx.get_edge_attributes(Graph, 'weight') + positions = nx.circular_layout(Graph) + positions2 = nx.circular_layout(Graph) + for p in positions: + positions2[p][0] -= .1 + positions2[p][1] -= .1 + + # ToDo: nodes + nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size*2, ax=ax, alpha=0.5) + nx.draw_networkx_labels(Graph, pos=positions, labels=node_labels) + # google networkx drawing to get better graphs with networkx + # nx.draw(Graph, pos=positions, node_size=node_size, label=labels, with_labels=True, ax=ax) + # # ToDo: edges + edge_width = [x / 10 for x in [*edge_labels.values()]] + # nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width, + # arrows=True, arrowsize=20, + # min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0", + # ax=ax) + # nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.5, edge_labels=edge_labels, ax=ax, rotate=False) + + nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width, + arrows=True, arrowsize=20, + min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0", + ax=ax) + nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=False) + + ax.spines["top"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.spines["left"].set_visible(False) + + # plt.title(title) + + plt.show() + +def main(base_path): + trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0) + chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0) + # trial_summary = trial_summary[chirp_notes['good'] == 1] + trial_mask = chirp_notes['good'] == 1 + + all_marcov_matrix = [] + all_event_counts = [] + for index, trial in trial_summary.iterrows(): + trial_path = os.path.join(base_path, trial['recording']) + + if not trial_mask[index]: + continue + if trial['group'] < 5: + continue + if not os.path.exists(os.path.join(trial_path, 'led_idxs.csv')): + continue + if not os.path.exists(os.path.join(trial_path, 'LED_frames.npy')): + continue + if trial['draw'] == 1: + continue + + ids = np.load(os.path.join(trial_path, 'analysis', 'ids.npy')) + times = np.load(os.path.join(trial_path, 'times.npy')) + sorter = -1 if trial['win_ID'] != ids[0] else 1 + + ### event times --> BORIS behavior + contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames = \ + load_and_converete_boris_events(trial_path, trial['recording'], sr=20_000) + + ### communication + if not os.path.exists(os.path.join(trial_path, 'chirp_times_cnn.npy')): + continue + + chirp_t = np.load(os.path.join(trial_path, 'chirp_times_cnn.npy')) + chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy')) + chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]] + + rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter] + rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))] + rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]] + + event_times = [] + event_labels = [] + loop_times = [chirp_times[1], rise_times[1], chirp_times[0], rise_times[0], ag_on_off_t_GRID[:, 0], + ag_on_off_t_GRID[:, 1], contact_t_GRID] + loop_labels = [r'chirp$_{lose}$', r'rise$_{lose}$', r'chirp$_{win}$', r'rise$_{win}$', r'chace$_{on}$', r'chace$_{off}$', 'contact'] + event_counts = np.array([len(chirp_times[1]), len(rise_times[1]), len(chirp_times[0]), len(rise_times[0]), len(ag_on_off_t_GRID), len(ag_on_off_t_GRID), len(contact_t_GRID)]) + for ll, t in zip(loop_labels, loop_times): + event_times.extend(t) + event_labels.extend(np.full(len(t), ll)) + + time_sorter = np.argsort(event_times) + event_times = np.array(event_times)[time_sorter] + event_labels = np.array(event_labels)[time_sorter] + + marcov_matrix = np.zeros((len(loop_labels), len(loop_labels))) + + for enu_ori, label_ori in enumerate(loop_labels): + for enu_tar, label_tar in enumerate(loop_labels): + n = len(event_times[:-1][(event_labels[:-1] == label_ori) & (event_labels[1:] == label_tar) & (np.diff(event_times) <= 5)]) + marcov_matrix[enu_ori, enu_tar] = n + + + ### get those cases where ag_on does not point to event and no event points to corresponding ag_off ... add thise cases in marcov matrix + chase_on_idx = np.where(event_labels == loop_labels[4])[0] + chase_off_idx = np.where(event_labels == loop_labels[5])[0] + helper_mask = np.ones_like(chase_on_idx) + helper_mask[np.diff(event_times)[chase_on_idx] <= 5] = 0 + helper_mask[np.diff(event_times)[chase_off_idx-1] <= 5] = 0 + + marcov_matrix[4, 5] += np.sum(helper_mask) + + all_marcov_matrix.append(marcov_matrix) + all_event_counts.append(event_counts) + # plot_transition_matrix(marcov_matrix, loop_labels) + + # plot_transition_diagram(marcov_matrix, loop_labels, node_size=event_counts) + # plot_transition_diagram(marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100, loop_labels, node_size=event_counts) + all_marcov_matrix = np.array(all_marcov_matrix) + all_event_counts = np.array(all_event_counts) + + collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0) + collective_event_counts = np.sum(all_event_counts, axis=0) + + plot_transition_matrix(collective_marcov_matrix, loop_labels) + plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, loop_labels, collective_event_counts) + + embed() + quit() + pass + + +if __name__ == '__main__': + main(sys.argv[1]) diff --git a/event_time_correlations.py b/event_time_correlations.py index 9c7404b..b80017c 100644 --- a/event_time_correlations.py +++ b/event_time_correlations.py @@ -240,15 +240,17 @@ def single_kde(event_dt, conv_t, kernal_w = 1, kernal_h = 0.2): return cp.asnumpy(single_kdes) + def main(base_path): if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr')): os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr')) - trial_summary = pd.read_csv('trial_summary.csv', index_col=0) + trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0) chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0) # trial_summary = trial_summary[chirp_notes['good'] == 1] trial_mask = chirp_notes['good'] == 1 + # ToDo: do chirp on chirp and rise on rise lose_chrips_centered_on_ag_off_t = [] lose_chrips_centered_on_ag_on_t = [] lose_chrips_centered_on_contact_t = []