examples are nice... make a function, clean it up

This commit is contained in:
Till Raab 2023-08-09 13:05:00 +02:00
parent 4c3bd818b9
commit 48314b363f

View File

@ -2,6 +2,8 @@ import os
import sys import sys
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -87,6 +89,7 @@ def plot_transition_diagram(matrix, labels, node_size, ax, threshold=5,
ax.set_xlim(-1.3, 1.3) ax.set_xlim(-1.3, 1.3)
ax.set_ylim(-1.3, 1.3) ax.set_ylim(-1.3, 1.3)
ax.set_title(title, fontsize=12) ax.set_title(title, fontsize=12)
def main(base_path): def main(base_path):
if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')): if not os.path.exists(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')):
os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'markov')) os.makedirs(os.path.join(os.path.split(__file__)[0], 'figures', 'markov'))
@ -103,10 +106,20 @@ def main(base_path):
# agonistic categorie plot # agonistic categorie plot
fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54)) fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54))
gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.95, top=0.95) gs = gridspec.GridSpec(2, 1, left=0.1, bottom=0.1, right=0.9, top=0.95, height_ratios=[1, 4], hspace=0)
ax = fig.add_subplot(gs[0, 0]) ax = fig.add_subplot(gs[1, 0])
ax_spec = fig.add_subplot(gs[0, 0], sharex=ax)
plt.setp(ax_spec.get_xticklabels(), visible=False)
for i in range(1, 5):
ax.fill_between([0, 4], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey')
ax.fill_between([5, 10], np.array([-.2, -.2]) + i, np.array([.2, .2]) + i, color='tab:grey')
fill_dots = np.arange(4, 5.1, 0.125)
ax.plot(fill_dots, np.ones_like(fill_dots)*i, '.', color='tab:grey', markersize=3)
got_examples = [False, False, False, False] got_examples = [False, False, False, False]
example_skips = [3, 4, 1, 0] example_skips = [3, 4, 3, 0]
for index, trial in trial_summary.iterrows(): for index, trial in trial_summary.iterrows():
trial_path = os.path.join(base_path, trial['recording']) trial_path = os.path.join(base_path, trial['recording'])
@ -209,7 +222,8 @@ def main(base_path):
rise_before = True rise_before = True
if np.any( ((chase_off_time - chirp_times[1]) < chirp_dt) & ((chirp_times[1] - chase_off_time) < max_dt)): if np.any( ((chase_off_time - chirp_times[1]) < chirp_dt) & ((chirp_times[1] - chase_off_time) < max_dt)):
chirp_time_oi = chirp_times[1][((chase_off_time - chirp_times[1]) < chirp_dt) & ((chirp_times[1] - chase_off_time) < max_dt)] # chirp_time_oi = chirp_times[1][((chase_off_time - chirp_times[1]) < chirp_dt) & ((chirp_times[1] - chase_off_time) < max_dt)]
chirp_time_oi = chirp_times[1][((chase_off_time - chirp_times[1]) < chase_dur) & ((chirp_times[1] - chase_off_time) < max_dt)]
chirp_arround_end = True chirp_arround_end = True
if rise_before: if rise_before:
@ -227,9 +241,6 @@ def main(base_path):
if chase_dur > 10: if chase_dur > 10:
if np.any((chirp_time_oi - chase_off_time) < 0) and np.any((chirp_time_oi - chase_off_time) > 0): if np.any((chirp_time_oi - chase_off_time) < 0) and np.any((chirp_time_oi - chase_off_time) > 0):
if example_skips[int(agonitic_categorie[enu] - 1)] == 0: if example_skips[int(agonitic_categorie[enu] - 1)] == 0:
ax.fill_between([0, 10],
np.array([-.2, -.2]) + agonitic_categorie[enu],
np.array([.2, .2]) + agonitic_categorie[enu], color='tab:grey')
for ct in chirp_time_oi: for ct in chirp_time_oi:
ax.plot([ct - chase_off_time + 10, ct - chase_off_time + 10], ax.plot([ct - chase_off_time + 10, ct - chase_off_time + 10],
[agonitic_categorie[enu] - .2, agonitic_categorie[enu] + .2], color='k', lw=2) [agonitic_categorie[enu] - .2, agonitic_categorie[enu] + .2], color='k', lw=2)
@ -239,65 +250,100 @@ def main(base_path):
got_examples[0] = True got_examples[0] = True
else: else:
example_skips[int(agonitic_categorie[enu] - 1)] -= 1 example_skips[int(agonitic_categorie[enu] - 1)] -= 1
elif agonitic_categorie[enu] == 2 and not got_examples[1]: elif agonitic_categorie[enu] == 2 and not got_examples[1]:
if chase_dur > 10: if chase_dur > 10:
if example_skips[int(agonitic_categorie[enu] - 1)] == 0: if example_skips[int(agonitic_categorie[enu] - 1)] == 0:
ax.fill_between([0, 10],
np.array([-.2, -.2]) + agonitic_categorie[enu],
np.array([.2, .2]) + agonitic_categorie[enu], color='tab:grey')
for rt in rise_times_oi: for rt in rise_times_oi:
ax.plot([rt - chase_on_time, rt - chase_on_time], ax.plot([rt - chase_on_time, rt - chase_on_time],
[agonitic_categorie[enu] - .2, agonitic_categorie[enu] + .2], color='firebrick', lw=2) [agonitic_categorie[enu] - .2, agonitic_categorie[enu] + .2], color='firebrick', lw=2)
got_examples[1] = True got_examples[1] = True
else: else:
example_skips[int(agonitic_categorie[enu] - 1)] -= 1 example_skips[int(agonitic_categorie[enu] - 1)] -= 1
elif agonitic_categorie[enu] == 3 and not got_examples[2]: elif agonitic_categorie[enu] == 3 and not got_examples[2]:
if chase_dur > 10: if chase_dur > 10:
if np.any((chirp_time_oi - chase_off_time) < 0) and np.any((chirp_time_oi - chase_off_time) > 0): if np.any((chirp_time_oi - chase_off_time) < 0) and np.any((chirp_time_oi - chase_off_time) > 0):
if example_skips[int(agonitic_categorie[enu] - 1)] == 0: if example_skips[int(agonitic_categorie[enu] - 1)] == 0:
ax.fill_between([0, 10],
np.array([-.2, -.2]) + agonitic_categorie[enu],
np.array([.2, .2]) + agonitic_categorie[enu], color='tab:grey')
for ct in chirp_time_oi: for ct in chirp_time_oi:
ax.plot([ct - chase_off_time + 10, ct - chase_off_time + 10], ax.plot([ct - chase_off_time + 10, ct - chase_off_time + 10],
[agonitic_categorie[enu] - .2, agonitic_categorie[enu] + .2], color='k', lw=2) [agonitic_categorie[enu] - .2, agonitic_categorie[enu] + .2], color='k', lw=2)
got_examples[2] = True got_examples[2] = True
else: else:
example_skips[int(agonitic_categorie[enu] - 1)] -= 1 example_skips[int(agonitic_categorie[enu] - 1)] -= 1
elif agonitic_categorie[enu] == 4 and not got_examples[3]:
if chase_dur > 10:
ax.fill_between([0, 10],
np.array([-.2, -.2]) + agonitic_categorie[enu],
np.array([.2, .2]) + agonitic_categorie[enu], color='tab:grey')
got_examples[3] = True
else: else:
pass pass
for i in range(4): ### agonistic categories
stacked_agonistic_categories = np.hstack(all_agonistic_categorie)
stacked_all_chase_durs = np.hstack(all_chase_durs)
ax.plot([0, 0], [0, 5], '--', color='k', lw=1) pct_each_categorie = np.zeros(4)
ax.plot([10, 10], [0, 5], '--', color='k', lw=1) for enu, cat in enumerate(range(1, 5)):
ax.set_ylim(0.5, 4.5) pct_each_categorie[enu] = len(stacked_agonistic_categories[stacked_agonistic_categories == cat]) / len(stacked_agonistic_categories)
# example plot
for enu, cat_pct in enumerate(pct_each_categorie):
ax.text(15.2, enu+1, f'{cat_pct*100:.1f}' + ' $\%$', clip_on=False, fontsize=14, ha='left', va='center')
ax.plot([0, 0], [0.8, 5], '--', color='k', lw=1)
ax.plot([10, 10], [0.8, 5], '--', color='k', lw=1)
ax.set_ylim(0.25, 4.5)
ax.set_xlim(-5, 15)
ax.set_yticks([1, 2, 3, 4])
# ax.set_yticklabels([r'rise$_{pre}$ $&$ chirp$_{end}$', r'only rise$_{pre}$', r'only chirp$_{end}$', 'no communication'])
ax.set_yticklabels(['A ', 'B ', 'C ', 'D '])
ax.invert_yaxis()
ax.set_xlabel('time [s]', fontsize=12)
ax.tick_params(axis='y', labelsize=20)
ax.tick_params(axis = 'x', labelsize=10)
legend_elements = [Line2D([0], [0], color='firebrick', lw=2, label=r'rise$_{lose}$'),
Line2D([0], [0], color='k', lw=2, label=r'chirp$_{lose}$'),
Patch(facecolor='tab:grey', edgecolor='w', label= 'chase event')]
ax.legend(handles=legend_elements, loc='upper right', ncol=3, bbox_to_anchor=(1, 1), frameon=False, fontsize=10, facecolor='white')
ax.spines[['right', 'top']].set_visible(False)
plt.show() plt.show()
embed()
quit() # bar plot
fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54))
ax.bar(np.arange(4),
[len(stacked_agonistic_categories[stacked_agonistic_categories == 1]),
len(stacked_agonistic_categories[stacked_agonistic_categories == 2]),
len(stacked_agonistic_categories[stacked_agonistic_categories == 3]),
len(stacked_agonistic_categories[stacked_agonistic_categories == 4])])
ax.set_xticks(np.arange(4))
ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _'])
plt.show()
# pct
pct_agon_categorie = np.zeros(shape=(len(all_agonistic_categorie), 4))
for enu, agonitic_categorie in enumerate(all_agonistic_categorie):
for cat in np.arange(4):
pct_agon_categorie[enu, cat] = len(agonitic_categorie[agonitic_categorie == cat+1]) / len(agonitic_categorie)
fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54))
ax.bar(np.arange(4), pct_agon_categorie.mean(0))
ax.errorbar(np.arange(4), pct_agon_categorie.mean(0), yerr=pct_agon_categorie.std(0), fmt='', color='k', linestyle='None')
ax.set_xticks(np.arange(4))
ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _'])
plt.show()
### marcov models
all_marcov_matrix = np.array(all_marcov_matrix) all_marcov_matrix = np.array(all_marcov_matrix)
all_event_counts = np.array(all_event_counts) all_event_counts = np.array(all_event_counts)
collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0) collective_marcov_matrix = np.sum(all_marcov_matrix, axis=0)
collective_event_counts = np.sum(all_event_counts, axis=0) collective_event_counts = np.sum(all_event_counts, axis=0)
plot_transition_matrix(collective_marcov_matrix, loop_labels) plot_transition_matrix(collective_marcov_matrix, loop_labels)
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
plot_transition_diagram(collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100, plot_transition_diagram(
collective_marcov_matrix / collective_event_counts.reshape(len(collective_event_counts), 1) * 100,
loop_labels, collective_event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]') loop_labels, collective_event_counts, ax, threshold=5, color_by_origin=True, title='origin triggers target [%]')
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_destination' + '.png'), dpi=300) plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_destination' + '.png'), dpi=300)
plt.close() plt.close()
@ -305,7 +351,8 @@ def main(base_path):
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
plot_transition_diagram(collective_marcov_matrix / collective_event_counts * 100, plot_transition_diagram(collective_marcov_matrix / collective_event_counts * 100,
loop_labels, collective_event_counts, ax, threshold=5, color_by_target=True, title='target triggered by origin [%]') loop_labels, collective_event_counts, ax, threshold=5, color_by_target=True,
title='target triggered by origin [%]')
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_origin' + '.png'), dpi=300) plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', 'markov_origin' + '.png'), dpi=300)
plt.close() plt.close()
@ -317,7 +364,8 @@ def main(base_path):
marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100, marcov_matrix / event_counts.reshape(len(event_counts), 1) * 100,
loop_labels, event_counts, ax, threshold=5, color_by_origin=True, loop_labels, event_counts, ax, threshold=5, color_by_origin=True,
title='origin triggers target [%]') title='origin triggers target [%]')
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_destination' + '.png'), dpi=300) plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_destination' + '.png'),
dpi=300)
plt.close() plt.close()
fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54)) fig, ax = plt.subplots(figsize=(21 / 2.54, 19 / 2.54))
@ -325,58 +373,12 @@ def main(base_path):
plot_transition_diagram(marcov_matrix / event_counts * 100, plot_transition_diagram(marcov_matrix / event_counts * 100,
loop_labels, event_counts, ax, threshold=5, color_by_target=True, loop_labels, event_counts, ax, threshold=5, color_by_target=True,
title='target triggered by origin [%]') title='target triggered by origin [%]')
plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_origin' + '.png'), dpi=300) plt.savefig(os.path.join(os.path.split(__file__)[0], 'figures', 'markov', f'markov_{i}_origin' + '.png'),
dpi=300)
plt.close() plt.close()
embed() embed()
quit() quit()
### agonistic categories
# stacked
stacked_agonistic_categories = np.hstack(all_agonistic_categorie)
stacked_all_chase_durs = np.hstack(all_chase_durs)
# idx_cat_4 = np.where(stacked_agonistic_categories == 4)[0]
# idx_cat_4 = idx_cat_4[np.argsort(stacked_all_chase_durs[idx_cat_4])]
# idx_cat_3 = np.where(stacked_agonistic_categories == 3)[0]
# idx_cat_3 = idx_cat_3[np.argsort(stacked_all_chase_durs[idx_cat_3])]
# idx_cat_2 = np.where(stacked_agonistic_categories == 2)[0]
# idx_cat_2 = idx_cat_2[np.argsort(stacked_all_chase_durs[idx_cat_2])]
# idx_cat_1 = np.where(stacked_agonistic_categories == 1)[0]
# idx_cat_1 = idx_cat_1[np.argsort(stacked_all_chase_durs[idx_cat_1])]
#
# fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54))
# ax.plot(10 - stacked_all_chase_durs[idx_cat_4], np.arange(len(idx_cat_4)))
# ax.plot(10 - stacked_all_chase_durs[idx_cat_3], np.arange(len(idx_cat_4), len(idx_cat_3) + len(idx_cat_4)))
# ax.plot(10 - stacked_all_chase_durs[idx_cat_2], np.arange(len(idx_cat_3) + len(idx_cat_4), len(idx_cat_2) + len(idx_cat_3) + len(idx_cat_4)))
# ax.plot(10 - stacked_all_chase_durs[idx_cat_1], np.arange(len(idx_cat_2) + len(idx_cat_3) + len(idx_cat_4), len(idx_cat_1) + len(idx_cat_2) + len(idx_cat_3) + len(idx_cat_4)))
# ax.set_xlim(0, 10)
# plt.show()
fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54))
ax.bar(np.arange(4),
[len(stacked_agonistic_categories[stacked_agonistic_categories == 1]),
len(stacked_agonistic_categories[stacked_agonistic_categories == 2]),
len(stacked_agonistic_categories[stacked_agonistic_categories == 3]),
len(stacked_agonistic_categories[stacked_agonistic_categories == 4])])
ax.set_xticks(np.arange(4))
ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _'])
plt.show()
# pct
pct_agon_categorie = np.zeros(shape=(len(all_agonistic_categorie), 4))
for enu, agonitic_categorie in enumerate(all_agonistic_categorie):
for cat in np.arange(4):
pct_agon_categorie[enu, cat] = len(agonitic_categorie[agonitic_categorie == cat+1]) / len(agonitic_categorie)
fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54))
ax.bar(np.arange(4), pct_agon_categorie.mean(0))
ax.errorbar(np.arange(4), pct_agon_categorie.mean(0), yerr=pct_agon_categorie.std(0), fmt='', color='k', linestyle='None')
ax.set_xticks(np.arange(4))
ax.set_xticklabels([r'rise$_{pre}$ + chirp$_{end}$', r'rise$_{pre}$ + _', r'_ + chirp$_{end}$', '_ + _'])
plt.show()
pass pass