savety
This commit is contained in:
parent
988d80f7bd
commit
3a4ece9902
95
ethogram.py
95
ethogram.py
@ -2,6 +2,7 @@ import os
|
||||
import sys
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.gridspec as gridspec
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import scipy.stats as scp
|
||||
@ -13,21 +14,36 @@ glob_colors = ['#BA2D22', '#53379B', '#F47F17', '#3673A4', '#AAB71B', '#DC143C',
|
||||
|
||||
|
||||
def plot_transition_matrix(matrix, labels):
|
||||
fig, ax = plt.subplots()
|
||||
fig = plt.figure(figsize=(20/2.54, 20/2.54))
|
||||
#gs = gridspec.GridSpec(1, 2, left=0.1, bottom=0.1, right=0.9, top=0.95, wspace=0.1, width_ratios=[8, 1])
|
||||
gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.925, top=0.95)
|
||||
ax = fig.add_subplot(gs[0, 0])
|
||||
|
||||
divider = make_axes_locatable(ax)
|
||||
cax = divider.append_axes('right', size='5%', pad=0.05)
|
||||
|
||||
# cax = fig.add_subplot(gs[0, 1])
|
||||
im = ax.imshow(matrix)
|
||||
ax.set_xticks(list(range(len(matrix))))
|
||||
ax.set_yticks(list(range(len(matrix))))
|
||||
ax.set_xticklabels(labels)
|
||||
ax.set_xticklabels(labels, rotation=45)
|
||||
ax.set_yticklabels(labels)
|
||||
fig.colorbar(im)
|
||||
plt.show()
|
||||
|
||||
fig.colorbar(im, cax=cax, orientation='vertical')
|
||||
|
||||
ax.tick_params(labelsize=10)
|
||||
cax.tick_params(labelsize=10)
|
||||
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'event_counts' + '.png'), dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_transition_diagram(matrix, labels, node_size, ax, threshold=5,
|
||||
color_by_origin=False, color_by_target=False, title=''):
|
||||
|
||||
|
||||
|
||||
def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = 'rdm'):
|
||||
matrix[matrix <= threshold] = 0
|
||||
matrix = np.around(matrix, decimals=1)
|
||||
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
||||
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||
Graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
|
||||
|
||||
node_labels = dict(zip(Graph, labels))
|
||||
@ -42,24 +58,24 @@ def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = '
|
||||
|
||||
# ToDo: nodes
|
||||
nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size, ax=ax, alpha=0.5, node_color=np.array(glob_colors)[:len(node_size)])
|
||||
nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels)
|
||||
nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels, ax=ax)
|
||||
# google networkx drawing to get better graphs with networkx
|
||||
# nx.draw(Graph, pos=positions, node_size=node_size, label=labels, with_labels=True, ax=ax)
|
||||
# # ToDo: edges
|
||||
edge_width = np.array([x / 5 for x in [*edge_labels.values()]])
|
||||
edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]]
|
||||
if color_by_origin:
|
||||
edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]]
|
||||
elif color_by_target:
|
||||
edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 1]]
|
||||
else:
|
||||
edge_colors = 'k'
|
||||
|
||||
|
||||
edge_width[edge_width >= 6] = 6
|
||||
# nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
|
||||
# arrows=True, arrowsize=20,
|
||||
# min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0",
|
||||
# ax=ax)
|
||||
# nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.5, edge_labels=edge_labels, ax=ax, rotate=False)
|
||||
|
||||
nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
|
||||
arrows=True, arrowsize=20,
|
||||
min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.025",
|
||||
min_target_margin=25, min_source_margin=25, connectionstyle="arc3, rad=0.025",
|
||||
ax=ax, edge_color=edge_colors)
|
||||
nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=True)
|
||||
|
||||
@ -70,11 +86,11 @@ def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = '
|
||||
|
||||
ax.set_xlim(-1.3, 1.3)
|
||||
ax.set_ylim(-1.3, 1.3)
|
||||
# plt.title(title)
|
||||
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', save_str + '.png'), dpi=300)
|
||||
plt.close()
|
||||
|
||||
ax.set_title(title, fontsize=12)
|
||||
def main(base_path):
|
||||
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')):
|
||||
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'markov'))
|
||||
|
||||
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
|
||||
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
||||
# trial_summary = trial_summary[chirp_notes['good'] == 1]
|
||||
@ -130,7 +146,9 @@ def main(base_path):
|
||||
event_times = np.array(event_times)[time_sorter]
|
||||
event_labels = np.array(event_labels)[time_sorter]
|
||||
|
||||
### create marcov_matrix 1: which beh 2 is triggered by beh. 1 ?
|
||||
marcov_matrix = np.zeros((len(loop_labels)+1, len(loop_labels)+1))
|
||||
### create marcov_matrix 2: beh 2 is triggered by which beh. 1 ?
|
||||
|
||||
for enu_ori, label_ori in enumerate(loop_labels):
|
||||
for enu_tar, label_tar in enumerate(loop_labels):
|
||||
@ -164,14 +182,43 @@ def main(base_path):
|
||||
collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0)
|
||||
collective_event_counts = np.sum(all_event_counts, axis=0)
|
||||
|
||||
|
||||
plot_transition_matrix(collective_marcov_matrix, loop_labels)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
||||
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||
|
||||
plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100,
|
||||
loop_labels, collective_event_counts, threshold=5, save_str='markov_all')
|
||||
loop_labels, collective_event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]')
|
||||
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_destination' + '.png'), dpi=300)
|
||||
plt.close()
|
||||
|
||||
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
||||
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||
plot_transition_diagram(collective_marcov_matrix / collective_event_counts * 100,
|
||||
loop_labels, collective_event_counts, ax, threshold=5, color_by_target=True, title='target triggered by origin [%]')
|
||||
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_origin' + '.png'), dpi=300)
|
||||
plt.close()
|
||||
|
||||
for i, (marcov_matrix, event_counts) in enumerate(zip(all_marcov_matrix, all_event_counts)):
|
||||
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
||||
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||
|
||||
plot_transition_diagram(
|
||||
marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100,
|
||||
loop_labels, event_counts, ax, threshold=5, color_by_origin=True,
|
||||
title='origin triggers target [%]')
|
||||
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_destination' + '.png'), dpi=300)
|
||||
plt.close()
|
||||
|
||||
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
|
||||
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||
plot_transition_diagram(marcov_matrix / event_counts * 100,
|
||||
loop_labels, event_counts, ax, threshold=5, color_by_target=True,
|
||||
title='target triggered by origin [%]')
|
||||
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_origin' + '.png'), dpi=300)
|
||||
plt.close()
|
||||
|
||||
# for i in range(len(all_marcov_matrix)):
|
||||
# plot_transition_diagram(
|
||||
# all_marcov_matrix[i] / all_event_counts[i].reshape(len(all_event_counts[i]), 1) * 100,
|
||||
# loop_labels, all_event_counts[i], threshold=5)
|
||||
embed()
|
||||
quit()
|
||||
pass
|
||||
|
@ -13,9 +13,23 @@ from event_time_correlations import load_and_converete_boris_events, kde, gauss
|
||||
female_color, male_color = '#e74c3c', '#3498db'
|
||||
|
||||
def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''):
|
||||
# ToDo: finish this !!!
|
||||
iei = []
|
||||
weighted_mean_iei = []
|
||||
median_iei = []
|
||||
for i in range(len(event_times)):
|
||||
iei.append(np.diff(event_times[i]))
|
||||
trial_iei = np.diff(event_times[i][event_times[i] <= 3600*3])
|
||||
iei.append(trial_iei)
|
||||
|
||||
if len(trial_iei) == 0:
|
||||
weighted_mean_iei.append(np.nan)
|
||||
median_iei.append(np.nan)
|
||||
else:
|
||||
weighted_mean_iei.append(np.sum((trial_iei) * trial_iei) / np.sum(trial_iei))
|
||||
median_iei.append(np.median(trial_iei))
|
||||
|
||||
weighted_mean_iei = np.array(weighted_mean_iei)
|
||||
median_iei = np.array(median_iei)
|
||||
|
||||
fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
|
||||
gs = gridspec.GridSpec(2, 2, left=0.1, bottom=0.1, right=0.95, top=0.9)
|
||||
@ -41,13 +55,16 @@ def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''):
|
||||
color, linestyle = female_color, '-'
|
||||
sp = 3
|
||||
|
||||
|
||||
conv_y = np.arange(0, np.percentile(np.hstack(iei), 80), .5)
|
||||
kde_array = kde(iei[i], conv_y, kernal_w=kernal_w, kernal_h=1)
|
||||
|
||||
# kde_array /= np.sum(kde_array)
|
||||
ax[sp].plot(conv_y, kde_array, zorder=2, color=color, linestyle=linestyle, lw=2)
|
||||
|
||||
# ax_m = ax[0].twinx()
|
||||
# ax_m.boxplot([weighted_mean_iei[(win_sex == 'm') & (win_sex == 'm') & ~np.isnan(weighted_mean_iei)],
|
||||
# median_iei[(win_sex == 'm') & (win_sex == 'm') & ~np.isnan(median_iei)]], sym='', vert=False)
|
||||
|
||||
ax[0].set_xlim(conv_y[0], conv_y[-1])
|
||||
ax[0].set_ylabel('KDE', fontsize=12)
|
||||
ax[2].set_ylabel('KDE', fontsize=12)
|
||||
@ -170,7 +187,6 @@ def relative_rate_progression(all_event_t, title=''):
|
||||
|
||||
|
||||
def main(base_path):
|
||||
# ToDo: for chirp and rise analysis different datasets!!!
|
||||
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta')):
|
||||
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta'))
|
||||
|
||||
@ -178,7 +194,6 @@ def main(base_path):
|
||||
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr'))
|
||||
|
||||
|
||||
|
||||
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
|
||||
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
||||
trial_mask = chirp_notes['good'] == 1
|
||||
|
Loading…
Reference in New Issue
Block a user