make chirp-chrip cross correlation. marcov model works but requires refinement. arrows from each node should sum up to 100pct ?!

This commit is contained in:
Till Raab 2023-06-14 15:05:49 +02:00
parent 8af8c92b27
commit 725c0b426a
2 changed files with 167 additions and 1 deletions

164
ethogram.py Normal file
View File

@ -0,0 +1,164 @@
import os
import sys
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
import scipy.stats as scp
import networkx as nx
from IPython import embed
from event_time_correlations import load_and_converete_boris_events
def plot_transition_matrix(matrix, labels):
fig, ax = plt.subplots()
im = ax.imshow(matrix)
ax.set_xticks(list(range(len(matrix))))
ax.set_yticks(list(range(len(matrix))))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
fig.colorbar(im)
plt.show()
def plot_transition_diagram(matrix, labels, node_size):
matrix[matrix <= 5] = 0
matrix = np.around(matrix, decimals=1)
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
Graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
node_labels = dict(zip(Graph, labels))
# Graph = nx.relabel_nodes(Graph, node_labels)
edge_labels = nx.get_edge_attributes(Graph, 'weight')
positions = nx.circular_layout(Graph)
positions2 = nx.circular_layout(Graph)
for p in positions:
positions2[p][0] -= .1
positions2[p][1] -= .1
# ToDo: nodes
nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size*2, ax=ax, alpha=0.5)
nx.draw_networkx_labels(Graph, pos=positions, labels=node_labels)
# google networkx drawing to get better graphs with networkx
# nx.draw(Graph, pos=positions, node_size=node_size, label=labels, with_labels=True, ax=ax)
# # ToDo: edges
edge_width = [x / 10 for x in [*edge_labels.values()]]
# nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
# arrows=True, arrowsize=20,
# min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0",
# ax=ax)
# nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.5, edge_labels=edge_labels, ax=ax, rotate=False)
nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
arrows=True, arrowsize=20,
min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0",
ax=ax)
nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=False)
ax.spines["top"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_visible(False)
# plt.title(title)
plt.show()
def main(base_path):
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
# trial_summary = trial_summary[chirp_notes['good'] == 1]
trial_mask = chirp_notes['good'] == 1
all_marcov_matrix = []
all_event_counts = []
for index, trial in trial_summary.iterrows():
trial_path = os.path.join(base_path, trial['recording'])
if not trial_mask[index]:
continue
if trial['group'] < 5:
continue
if not os.path.exists(os.path.join(trial_path, 'led_idxs.csv')):
continue
if not os.path.exists(os.path.join(trial_path, 'LED_frames.npy')):
continue
if trial['draw'] == 1:
continue
ids = np.load(os.path.join(trial_path, 'analysis', 'ids.npy'))
times = np.load(os.path.join(trial_path, 'times.npy'))
sorter = -1 if trial['win_ID'] != ids[0] else 1
### event times --> BORIS behavior
contact_t_GRID, ag_on_off_t_GRID, led_idx, led_frames = \
load_and_converete_boris_events(trial_path, trial['recording'], sr=20_000)
### communication
if not os.path.exists(os.path.join(trial_path, 'chirp_times_cnn.npy')):
continue
chirp_t = np.load(os.path.join(trial_path, 'chirp_times_cnn.npy'))
chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]]
rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]
rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]]
event_times = []
event_labels = []
loop_times = [chirp_times[1], rise_times[1], chirp_times[0], rise_times[0], ag_on_off_t_GRID[:, 0],
ag_on_off_t_GRID[:, 1], contact_t_GRID]
loop_labels = [r'chirp$_{lose}$', r'rise$_{lose}$', r'chirp$_{win}$', r'rise$_{win}$', r'chace$_{on}$', r'chace$_{off}$', 'contact']
event_counts = np.array([len(chirp_times[1]), len(rise_times[1]), len(chirp_times[0]), len(rise_times[0]), len(ag_on_off_t_GRID), len(ag_on_off_t_GRID), len(contact_t_GRID)])
for ll, t in zip(loop_labels, loop_times):
event_times.extend(t)
event_labels.extend(np.full(len(t), ll))
time_sorter = np.argsort(event_times)
event_times = np.array(event_times)[time_sorter]
event_labels = np.array(event_labels)[time_sorter]
marcov_matrix = np.zeros((len(loop_labels), len(loop_labels)))
for enu_ori, label_ori in enumerate(loop_labels):
for enu_tar, label_tar in enumerate(loop_labels):
n = len(event_times[:-1][(event_labels[:-1] == label_ori) & (event_labels[1:] == label_tar) & (np.diff(event_times) <= 5)])
marcov_matrix[enu_ori, enu_tar] = n
### get those cases where ag_on does not point to event and no event points to corresponding ag_off ... add thise cases in marcov matrix
chase_on_idx = np.where(event_labels == loop_labels[4])[0]
chase_off_idx = np.where(event_labels == loop_labels[5])[0]
helper_mask = np.ones_like(chase_on_idx)
helper_mask[np.diff(event_times)[chase_on_idx] <= 5] = 0
helper_mask[np.diff(event_times)[chase_off_idx-1] <= 5] = 0
marcov_matrix[4, 5] += np.sum(helper_mask)
all_marcov_matrix.append(marcov_matrix)
all_event_counts.append(event_counts)
# plot_transition_matrix(marcov_matrix, loop_labels)
# plot_transition_diagram(marcov_matrix, loop_labels, node_size=event_counts)
# plot_transition_diagram(marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100, loop_labels, node_size=event_counts)
all_marcov_matrix = np.array(all_marcov_matrix)
all_event_counts = np.array(all_event_counts)
collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0)
collective_event_counts = np.sum(all_event_counts, axis=0)
plot_transition_matrix(collective_marcov_matrix, loop_labels)
plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, loop_labels, collective_event_counts)
embed()
quit()
pass
if __name__ == '__main__':
main(sys.argv[1])

View File

@ -240,15 +240,17 @@ def single_kde(event_dt, conv_t, kernal_w = 1, kernal_h = 0.2):
return cp.asnumpy(single_kdes)
def main(base_path):
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr')):
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr'))
trial_summary = pd.read_csv('trial_summary.csv', index_col=0)
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
# trial_summary = trial_summary[chirp_notes['good'] == 1]
trial_mask = chirp_notes['good'] == 1
# ToDo: do chirp on chirp and rise on rise
lose_chrips_centered_on_ag_off_t = []
lose_chrips_centered_on_ag_on_t = []
lose_chrips_centered_on_contact_t = []