savety
This commit is contained in:
parent
988d80f7bd
commit
3a4ece9902
93
ethogram.py
93
ethogram.py
@ -2,6 +2,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib.gridspec as gridspec
|
import matplotlib.gridspec as gridspec
|
||||||
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import scipy.stats as scp
|
import scipy.stats as scp
|
||||||
@ -13,21 +14,36 @@ glob_colors = ['#BA2D22', '#53379B', '#F47F17', '#3673A4', '#AAB71B', '#DC143C',
|
|||||||
|
|
||||||
|
|
||||||
def plot_transition_matrix(matrix, labels):
|
def plot_transition_matrix(matrix, labels):
|
||||||
fig, ax = plt.subplots()
|
fig = plt.figure(figsize=(20/2.54, 20/2.54))
|
||||||
|
#gs = gridspec.GridSpec(1, 2, left=0.1, bottom=0.1, right=0.9, top=0.95, wspace=0.1, width_ratios=[8, 1])
|
||||||
|
gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.925, top=0.95)
|
||||||
|
ax = fig.add_subplot(gs[0, 0])
|
||||||
|
|
||||||
|
divider = make_axes_locatable(ax)
|
||||||
|
cax = divider.append_axes('right', size='5%', pad=0.05)
|
||||||
|
|
||||||
|
# cax = fig.add_subplot(gs[0, 1])
|
||||||
im = ax.imshow(matrix)
|
im = ax.imshow(matrix)
|
||||||
ax.set_xticks(list(range(len(matrix))))
|
ax.set_xticks(list(range(len(matrix))))
|
||||||
ax.set_yticks(list(range(len(matrix))))
|
ax.set_yticks(list(range(len(matrix))))
|
||||||
ax.set_xticklabels(labels)
|
ax.set_xticklabels(labels, rotation=45)
|
||||||
ax.set_yticklabels(labels)
|
ax.set_yticklabels(labels)
|
||||||
fig.colorbar(im)
|
|
||||||
plt.show()
|
fig.colorbar(im, cax=cax, orientation='vertical')
|
||||||
|
|
||||||
|
ax.tick_params(labelsize=10)
|
||||||
|
cax.tick_params(labelsize=10)
|
||||||
|
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'event_counts' + '.png'), dpi=300)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_transition_diagram(matrix, labels, node_size, ax, threshold=5,
|
||||||
|
color_by_origin=False, color_by_target=False, title=''):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = 'rdm'):
|
|
||||||
matrix[matrix <= threshold] = 0
|
matrix[matrix <= threshold] = 0
|
||||||
matrix = np.around(matrix, decimals=1)
|
matrix = np.around(matrix, decimals=1)
|
||||||
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
|
||||||
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
|
||||||
Graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
|
Graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
|
||||||
|
|
||||||
node_labels = dict(zip(Graph, labels))
|
node_labels = dict(zip(Graph, labels))
|
||||||
@ -42,24 +58,24 @@ def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = '
|
|||||||
|
|
||||||
# ToDo: nodes
|
# ToDo: nodes
|
||||||
nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size, ax=ax, alpha=0.5, node_color=np.array(glob_colors)[:len(node_size)])
|
nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size, ax=ax, alpha=0.5, node_color=np.array(glob_colors)[:len(node_size)])
|
||||||
nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels)
|
nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels, ax=ax)
|
||||||
# google networkx drawing to get better graphs with networkx
|
# google networkx drawing to get better graphs with networkx
|
||||||
# nx.draw(Graph, pos=positions, node_size=node_size, label=labels, with_labels=True, ax=ax)
|
# nx.draw(Graph, pos=positions, node_size=node_size, label=labels, with_labels=True, ax=ax)
|
||||||
# # ToDo: edges
|
# # ToDo: edges
|
||||||
edge_width = np.array([x / 5 for x in [*edge_labels.values()]])
|
edge_width = np.array([x / 5 for x in [*edge_labels.values()]])
|
||||||
|
if color_by_origin:
|
||||||
edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]]
|
edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]]
|
||||||
|
elif color_by_target:
|
||||||
|
edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 1]]
|
||||||
|
else:
|
||||||
|
edge_colors = 'k'
|
||||||
|
|
||||||
|
|
||||||
edge_width[edge_width >= 6] = 6
|
edge_width[edge_width >= 6] = 6
|
||||||
# nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
|
|
||||||
# arrows=True, arrowsize=20,
|
|
||||||
# min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0",
|
|
||||||
# ax=ax)
|
|
||||||
# nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.5, edge_labels=edge_labels, ax=ax, rotate=False)
|
|
||||||
|
|
||||||
nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
|
nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
|
||||||
arrows=True, arrowsize=20,
|
arrows=True, arrowsize=20,
|
||||||
min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.025",
|
min_target_margin=25, min_source_margin=25, connectionstyle="arc3, rad=0.025",
|
||||||
ax=ax, edge_color=edge_colors)
|
ax=ax, edge_color=edge_colors)
|
||||||
nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=True)
|
nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=True)
|
||||||
|
|
||||||
@ -70,11 +86,11 @@ def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = '
|
|||||||
|
|
||||||
ax.set_xlim(-1.3, 1.3)
|
ax.set_xlim(-1.3, 1.3)
|
||||||
ax.set_ylim(-1.3, 1.3)
|
ax.set_ylim(-1.3, 1.3)
|
||||||
# plt.title(title)
|
ax.set_title(title, fontsize=12)
|
||||||
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', save_str + '.png'), dpi=300)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
def main(base_path):
|
def main(base_path):
|
||||||
|
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')):
|
||||||
|
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'markov'))
|
||||||
|
|
||||||
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
|
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
|
||||||
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
||||||
# trial_summary = trial_summary[chirp_notes['good'] == 1]
|
# trial_summary = trial_summary[chirp_notes['good'] == 1]
|
||||||
@ -130,7 +146,9 @@ def main(base_path):
|
|||||||
event_times = np.array(event_times)[time_sorter]
|
event_times = np.array(event_times)[time_sorter]
|
||||||
event_labels = np.array(event_labels)[time_sorter]
|
event_labels = np.array(event_labels)[time_sorter]
|
||||||
|
|
||||||
|
### create marcov_matrix 1: which beh 2 is triggered by beh. 1 ?
|
||||||
marcov_matrix = np.zeros((len(loop_labels)+1, len(loop_labels)+1))
|
marcov_matrix = np.zeros((len(loop_labels)+1, len(loop_labels)+1))
|
||||||
|
### create marcov_matrix 2: beh 2 is triggered by which beh. 1 ?
|
||||||
|
|
||||||
for enu_ori, label_ori in enumerate(loop_labels):
|
for enu_ori, label_ori in enumerate(loop_labels):
|
||||||
for enu_tar, label_tar in enumerate(loop_labels):
|
for enu_tar, label_tar in enumerate(loop_labels):
|
||||||
@ -164,14 +182,43 @@ def main(base_path):
|
|||||||
collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0)
|
collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0)
|
||||||
collective_event_counts = np.sum(all_event_counts, axis=0)
|
collective_event_counts = np.sum(all_event_counts, axis=0)
|
||||||
|
|
||||||
|
|
||||||
plot_transition_matrix(collective_marcov_matrix, loop_labels)
|
plot_transition_matrix(collective_marcov_matrix, loop_labels)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
||||||
|
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||||
|
|
||||||
plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100,
|
plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100,
|
||||||
loop_labels, collective_event_counts, threshold=5, save_str='markov_all')
|
loop_labels, collective_event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]')
|
||||||
|
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_destination' + '.png'), dpi=300)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
||||||
|
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||||
|
plot_transition_diagram(collective_marcov_matrix / collective_event_counts * 100,
|
||||||
|
loop_labels, collective_event_counts, ax, threshold=5, color_by_target=True, title='target triggered by origin [%]')
|
||||||
|
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_origin' + '.png'), dpi=300)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
for i, (marcov_matrix, event_counts) in enumerate(zip(all_marcov_matrix, all_event_counts)):
|
||||||
|
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
||||||
|
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||||
|
|
||||||
|
plot_transition_diagram(
|
||||||
|
marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100,
|
||||||
|
loop_labels, event_counts, ax, threshold=5, color_by_origin=True,
|
||||||
|
title='origin triggers target [%]')
|
||||||
|
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_destination' + '.png'), dpi=300)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
||||||
|
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||||
|
plot_transition_diagram(marcov_matrix / event_counts * 100,
|
||||||
|
loop_labels, event_counts, ax, threshold=5, color_by_target=True,
|
||||||
|
title='target triggered by origin [%]')
|
||||||
|
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_origin' + '.png'), dpi=300)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
# for i in range(len(all_marcov_matrix)):
|
|
||||||
# plot_transition_diagram(
|
|
||||||
# all_marcov_matrix[i] / all_event_counts[i].reshape(len(all_event_counts[i]), 1) * 100,
|
|
||||||
# loop_labels, all_event_counts[i], threshold=5)
|
|
||||||
embed()
|
embed()
|
||||||
quit()
|
quit()
|
||||||
pass
|
pass
|
||||||
|
@ -13,9 +13,23 @@ from event_time_correlations import load_and_converete_boris_events, kde, gauss
|
|||||||
female_color, male_color = '#e74c3c', '#3498db'
|
female_color, male_color = '#e74c3c', '#3498db'
|
||||||
|
|
||||||
def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''):
|
def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''):
|
||||||
|
# ToDo: finish this !!!
|
||||||
iei = []
|
iei = []
|
||||||
|
weighted_mean_iei = []
|
||||||
|
median_iei = []
|
||||||
for i in range(len(event_times)):
|
for i in range(len(event_times)):
|
||||||
iei.append(np.diff(event_times[i]))
|
trial_iei = np.diff(event_times[i][event_times[i] <= 3600*3])
|
||||||
|
iei.append(trial_iei)
|
||||||
|
|
||||||
|
if len(trial_iei) == 0:
|
||||||
|
weighted_mean_iei.append(np.nan)
|
||||||
|
median_iei.append(np.nan)
|
||||||
|
else:
|
||||||
|
weighted_mean_iei.append(np.sum((trial_iei) * trial_iei) / np.sum(trial_iei))
|
||||||
|
median_iei.append(np.median(trial_iei))
|
||||||
|
|
||||||
|
weighted_mean_iei = np.array(weighted_mean_iei)
|
||||||
|
median_iei = np.array(median_iei)
|
||||||
|
|
||||||
fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
|
fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
|
||||||
gs = gridspec.GridSpec(2, 2, left=0.1, bottom=0.1, right=0.95, top=0.9)
|
gs = gridspec.GridSpec(2, 2, left=0.1, bottom=0.1, right=0.95, top=0.9)
|
||||||
@ -41,13 +55,16 @@ def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''):
|
|||||||
color, linestyle = female_color, '-'
|
color, linestyle = female_color, '-'
|
||||||
sp = 3
|
sp = 3
|
||||||
|
|
||||||
|
|
||||||
conv_y = np.arange(0, np.percentile(np.hstack(iei), 80), .5)
|
conv_y = np.arange(0, np.percentile(np.hstack(iei), 80), .5)
|
||||||
kde_array = kde(iei[i], conv_y, kernal_w=kernal_w, kernal_h=1)
|
kde_array = kde(iei[i], conv_y, kernal_w=kernal_w, kernal_h=1)
|
||||||
|
|
||||||
# kde_array /= np.sum(kde_array)
|
# kde_array /= np.sum(kde_array)
|
||||||
ax[sp].plot(conv_y, kde_array, zorder=2, color=color, linestyle=linestyle, lw=2)
|
ax[sp].plot(conv_y, kde_array, zorder=2, color=color, linestyle=linestyle, lw=2)
|
||||||
|
|
||||||
|
# ax_m = ax[0].twinx()
|
||||||
|
# ax_m.boxplot([weighted_mean_iei[(win_sex == 'm') & (win_sex == 'm') & ~np.isnan(weighted_mean_iei)],
|
||||||
|
# median_iei[(win_sex == 'm') & (win_sex == 'm') & ~np.isnan(median_iei)]], sym='', vert=False)
|
||||||
|
|
||||||
ax[0].set_xlim(conv_y[0], conv_y[-1])
|
ax[0].set_xlim(conv_y[0], conv_y[-1])
|
||||||
ax[0].set_ylabel('KDE', fontsize=12)
|
ax[0].set_ylabel('KDE', fontsize=12)
|
||||||
ax[2].set_ylabel('KDE', fontsize=12)
|
ax[2].set_ylabel('KDE', fontsize=12)
|
||||||
@ -170,7 +187,6 @@ def relative_rate_progression(all_event_t, title=''):
|
|||||||
|
|
||||||
|
|
||||||
def main(base_path):
|
def main(base_path):
|
||||||
# ToDo: for chirp and rise analysis different datasets!!!
|
|
||||||
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta')):
|
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta')):
|
||||||
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta'))
|
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta'))
|
||||||
|
|
||||||
@ -178,7 +194,6 @@ def main(base_path):
|
|||||||
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr'))
|
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr'))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
|
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
|
||||||
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
||||||
trial_mask = chirp_notes['good'] == 1
|
trial_mask = chirp_notes['good'] == 1
|
||||||
|
Loading…
Reference in New Issue
Block a user