This commit is contained in:
Till Raab 2023-06-19 14:33:32 +02:00
parent 988d80f7bd
commit 3a4ece9902
2 changed files with 90 additions and 28 deletions

View File

@ -2,6 +2,7 @@ import os
import sys
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import pandas as pd
import scipy.stats as scp
@ -13,21 +14,36 @@ glob_colors = ['#BA2D22', '#53379B', '#F47F17', '#3673A4', '#AAB71B', '#DC143C',
def plot_transition_matrix(matrix, labels):
fig, ax = plt.subplots()
fig = plt.figure(figsize=(20/2.54, 20/2.54))
#gs = gridspec.GridSpec(1, 2, left=0.1, bottom=0.1, right=0.9, top=0.95, wspace=0.1, width_ratios=[8, 1])
gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.925, top=0.95)
ax = fig.add_subplot(gs[0, 0])
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
# cax = fig.add_subplot(gs[0, 1])
im = ax.imshow(matrix)
ax.set_xticks(list(range(len(matrix))))
ax.set_yticks(list(range(len(matrix))))
ax.set_xticklabels(labels)
ax.set_xticklabels(labels, rotation=45)
ax.set_yticklabels(labels)
fig.colorbar(im)
plt.show()
fig.colorbar(im, cax=cax, orientation='vertical')
ax.tick_params(labelsize=10)
cax.tick_params(labelsize=10)
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'event_counts' + '.png'), dpi=300)
plt.close()
def plot_transition_diagram(matrix, labels, node_size, ax, threshold=5,
color_by_origin=False, color_by_target=False, title=''):
def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = 'rdm'):
matrix[matrix <= threshold] = 0
matrix = np.around(matrix, decimals=1)
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
Graph = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
node_labels = dict(zip(Graph, labels))
@ -42,24 +58,24 @@ def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = '
# ToDo: nodes
nx.draw_networkx_nodes(Graph, pos=positions, node_size=node_size, ax=ax, alpha=0.5, node_color=np.array(glob_colors)[:len(node_size)])
nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels)
nx.draw_networkx_labels(Graph, pos=positions2, labels=node_labels, ax=ax)
# google networkx drawing to get better graphs with networkx
# nx.draw(Graph, pos=positions, node_size=node_size, label=labels, with_labels=True, ax=ax)
# # ToDo: edges
edge_width = np.array([x / 5 for x in [*edge_labels.values()]])
edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]]
if color_by_origin:
edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 0]]
elif color_by_target:
edge_colors = np.array(glob_colors)[np.array([*edge_labels.keys()], dtype=int)[:, 1]]
else:
edge_colors = 'k'
edge_width[edge_width >= 6] = 6
# nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
# arrows=True, arrowsize=20,
# min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.0",
# ax=ax)
# nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.5, edge_labels=edge_labels, ax=ax, rotate=False)
nx.draw_networkx_edges(Graph, pos=positions, node_size=node_size, width=edge_width,
arrows=True, arrowsize=20,
min_target_margin=25, min_source_margin=10, connectionstyle="arc3, rad=0.025",
min_target_margin=25, min_source_margin=25, connectionstyle="arc3, rad=0.025",
ax=ax, edge_color=edge_colors)
nx.draw_networkx_edge_labels(Graph, positions, label_pos=0.2, edge_labels=edge_labels, ax=ax, rotate=True)
@ -70,11 +86,11 @@ def plot_transition_diagram(matrix, labels, node_size, threshold=5, save_str = '
ax.set_xlim(-1.3, 1.3)
ax.set_ylim(-1.3, 1.3)
# plt.title(title)
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', save_str + '.png'), dpi=300)
plt.close()
ax.set_title(title, fontsize=12)
def main(base_path):
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')):
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'markov'))
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
# trial_summary = trial_summary[chirp_notes['good'] == 1]
@ -130,7 +146,9 @@ def main(base_path):
event_times = np.array(event_times)[time_sorter]
event_labels = np.array(event_labels)[time_sorter]
### create marcov_matrix 1: which beh 2 is triggered by beh. 1 ?
marcov_matrix = np.zeros((len(loop_labels)+1, len(loop_labels)+1))
### create marcov_matrix 2: beh 2 is triggered by which beh. 1 ?
for enu_ori, label_ori in enumerate(loop_labels):
for enu_tar, label_tar in enumerate(loop_labels):
@ -164,14 +182,43 @@ def main(base_path):
collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0)
collective_event_counts = np.sum(all_event_counts, axis=0)
plot_transition_matrix(collective_marcov_matrix, loop_labels)
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100,
loop_labels, collective_event_counts, threshold=5, save_str='markov_all')
loop_labels, collective_event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]')
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_destination' + '.png'), dpi=300)
plt.close()
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
plot_transition_diagram(collective_marcov_matrix / collective_event_counts * 100,
loop_labels, collective_event_counts, ax, threshold=5, color_by_target=True, title='target triggered by origin [%]')
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_origin' + '.png'), dpi=300)
plt.close()
for i, (marcov_matrix, event_counts) in enumerate(zip(all_marcov_matrix, all_event_counts)):
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
plot_transition_diagram(
marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100,
loop_labels, event_counts, ax, threshold=5, color_by_origin=True,
title='origin triggers target [%]')
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_destination' + '.png'), dpi=300)
plt.close()
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
plot_transition_diagram(marcov_matrix / event_counts * 100,
loop_labels, event_counts, ax, threshold=5, color_by_target=True,
title='target triggered by origin [%]')
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_origin' + '.png'), dpi=300)
plt.close()
# for i in range(len(all_marcov_matrix)):
# plot_transition_diagram(
# all_marcov_matrix[i] / all_event_counts[i].reshape(len(all_event_counts[i]), 1) * 100,
# loop_labels, all_event_counts[i], threshold=5)
embed()
quit()
pass

View File

@ -13,9 +13,23 @@ from event_time_correlations import load_and_converete_boris_events, kde, gauss
female_color, male_color = '#e74c3c', '#3498db'
def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''):
# ToDo: finish this !!!
iei = []
weighted_mean_iei = []
median_iei = []
for i in range(len(event_times)):
iei.append(np.diff(event_times[i]))
trial_iei = np.diff(event_times[i][event_times[i] <= 3600*3])
iei.append(trial_iei)
if len(trial_iei) == 0:
weighted_mean_iei.append(np.nan)
median_iei.append(np.nan)
else:
weighted_mean_iei.append(np.sum((trial_iei) * trial_iei) / np.sum(trial_iei))
median_iei.append(np.median(trial_iei))
weighted_mean_iei = np.array(weighted_mean_iei)
median_iei = np.array(median_iei)
fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
gs = gridspec.GridSpec(2, 2, left=0.1, bottom=0.1, right=0.95, top=0.9)
@ -41,13 +55,16 @@ def iei_analysis(event_times, win_sex, lose_sex, kernal_w, title=''):
color, linestyle = female_color, '-'
sp = 3
conv_y = np.arange(0, np.percentile(np.hstack(iei), 80), .5)
kde_array = kde(iei[i], conv_y, kernal_w=kernal_w, kernal_h=1)
# kde_array /= np.sum(kde_array)
ax[sp].plot(conv_y, kde_array, zorder=2, color=color, linestyle=linestyle, lw=2)
# ax_m = ax[0].twinx()
# ax_m.boxplot([weighted_mean_iei[(win_sex == 'm') & (win_sex == 'm') & ~np.isnan(weighted_mean_iei)],
# median_iei[(win_sex == 'm') & (win_sex == 'm') & ~np.isnan(median_iei)]], sym='', vert=False)
ax[0].set_xlim(conv_y[0], conv_y[-1])
ax[0].set_ylabel('KDE', fontsize=12)
ax[2].set_ylabel('KDE', fontsize=12)
@ -170,7 +187,6 @@ def relative_rate_progression(all_event_t, title=''):
def main(base_path):
# ToDo: for chirp and rise analysis different datasets!!!
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta')):
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_meta'))
@ -178,7 +194,6 @@ def main(base_path):
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'event_time_corr'))
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
trial_mask = chirp_notes['good'] == 1