292 lines
12 KiB
Python
292 lines
12 KiB
Python
import os
|
|
from parser.CellData import CellData
|
|
from parser.DataParserFactory import DatParser
|
|
import numpy as np
|
|
from fitting.ModelFit import ModelFit, get_best_fit
|
|
# from plottools.axes import labelaxes_params
|
|
import matplotlib.pyplot as plt
|
|
from run_Fitter import iget_start_parameters
|
|
from experiments.FiCurve import FICurve, FICurveCellData, FICurveModel
|
|
from Figures_results import scatter_hist
|
|
from my_util.functions import exponential_function
|
|
# Colour cycle for stacked event plots; sam_tests indexes it modulo its length.
colors = [
    "black",
    "red",
    "blue",
    "orange",
    "green",
]
|
|
|
|
|
|
def main():
    """Script entry point: render the per-cell step-response comparison figures.

    The other analyses in this module (``compare_cell_model_taus``,
    ``export_cell_fi_curves``, ``sam_tests``) are not run by default; call
    them here to enable them.
    """
    step_response_comparison()
    # compare_cell_model_taus()
    # export_cell_fi_curves()
    # sam_tests()


def compare_cell_model_taus(results_dir="results/sam_cells_only_best/", contrast_idx=-2, border=1):
    """Compare onset time constants of cells and their best-fit models.

    For every sub-folder of ``results_dir`` the best fit is loaded. The time
    constant (element 1 of ``calculate_time_constant``'s result) is computed
    for the contrast at ``contrast_idx``, once for the model and once for the
    cell. Pairs where either |tau| >= ``border`` are discarded together, so
    the scatter plot stays matched. The kept pairs are handed to
    ``plot_cell_model_comp_taus``.

    Returns:
        (cell_taus, model_taus): the unfiltered tau lists, in folder order.
    """
    cell_taus = []
    model_taus = []

    for folder in sorted(os.listdir(results_dir)):
        folder_path = os.path.join(results_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        fit = get_best_fit(folder_path)
        print(fit.get_fit_routine_error())
        model = fit.get_model()
        cell_data = fit.get_cell_data()

        fi_model = FICurveModel(model, cell_data.get_fi_contrasts(), cell_data.get_eod_frequency(),
                                save_dir=folder_path)
        model_taus.append(fi_model.calculate_time_constant(contrast_idx)[1])

        fi_cell = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), cell_data.data_path)
        cell_taus.append(fi_cell.calculate_time_constant(contrast_idx)[1])

    # A pair is kept only if BOTH taus are inside the border (NaN compares False and is dropped).
    kept = [(cell_tau, model_tau) for cell_tau, model_tau in zip(cell_taus, model_taus)
            if np.abs(model_tau) < border and np.abs(cell_tau) < border]
    cell_taus_c = [cell_tau for cell_tau, _ in kept]
    model_taus_c = [model_tau for _, model_tau in kept]

    print("pairs removed:", len(model_taus) - len(kept))
    print("model taus outside border:", sum(1 for tau in model_taus if not np.abs(tau) < border))
    print("cell taus outside border:", sum(1 for tau in cell_taus if not np.abs(tau) < border))

    if kept:
        plot_cell_model_comp_taus(cell_taus_c, model_taus_c)
    else:
        print("no tau pairs within border; skipping comparison plot")

    print(model_taus)
    print(cell_taus)
    return cell_taus, model_taus


def export_cell_fi_curves(results_dir="data/final/", image_dir="temp/cell_fi_curves_images/",
                          csv_dir="temp/cell_fi_curves_csvs/"):
    """Plot each cell's f-I curve and export its frequencies as CSV.

    For every sub-folder of ``results_dir``, an f-I curve figure is saved with
    the prefix ``image_dir + <cell_name> + "_"``. A ``<cell_name>.csv`` file is
    written to ``csv_dir`` with the columns contrasts, f_baseline,
    f_steady_state and f_onset.
    """
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(csv_dir, exist_ok=True)
    headers = ["contrasts", "f_baseline", "f_steady_state", "f_onset"]

    for folder in sorted(os.listdir(results_dir)):
        folder_path = os.path.join(results_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        cell_data = CellData(folder_path)
        cell_name = cell_data.get_cell_name()

        fi_cell = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), cell_data.data_path)
        fi_cell.plot_fi_curve(title=cell_name, save_path=image_dir + cell_name + "_")

        steady_state = fi_cell.get_f_inf_frequencies()
        onset = fi_cell.get_f_zero_frequencies()
        baseline = fi_cell.get_f_baseline_frequencies()
        contrasts = fi_cell.stimulus_values

        with open(os.path.join(csv_dir, cell_name + ".csv"), "w") as f:
            f.write(",".join(headers) + "\n")
            for row in zip(contrasts, baseline, steady_state, onset):
                f.write(",".join(str(value) for value in row) + "\n")
|
|
|
|
|
|
def step_response_comparison(results_dir="results/sam_cells_only_best/", save_dir="figures/tau_images/"):
    """Save, per fitted cell, a 2x2 figure of step responses with exponential tau fits.

    Each sub-folder of ``results_dir`` must be loadable by ``get_best_fit``.
    The top row shows the cell's mean frequency traces and the bottom row the
    model's. The columns are the 2nd and 3rd highest f-I contrasts. Cell and
    model traces are both fitted with the cell object's
    ``__calculate_time_constant_internal__``, so they share one fitting
    routine. Figures are written to ``<save_dir>/<cell_name>_tau.png``.
    """
    os.makedirs(save_dir, exist_ok=True)

    for folder in sorted(os.listdir(results_dir)):
        folder_path = os.path.join(results_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        fit = get_best_fit(folder_path)
        print(fit.get_fit_routine_error())
        model = fit.get_model()
        cell_data = fit.get_cell_data()

        fi_model = FICurveModel(model, cell_data.get_fi_contrasts(), cell_data.get_eod_frequency(),
                                save_dir=folder_path)
        fi_cell = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), cell_data.data_path)

        contrasts = cell_data.get_fi_contrasts()

        # Cell traces: the recording starts get_recording_times()[0] (a negative time) before onset.
        cell_traces = cell_data.get_mean_fi_curve_isi_frequencies()
        cell_baselines = fi_cell.get_f_baseline_frequencies()
        cell_pre_duration = -1 * cell_data.get_recording_times()[0]
        cell_sampling_interval = cell_data.get_sampling_interval()

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        for col, contrast_idx in enumerate((-2, -3)):
            _plot_tau_panel(axes[0, col], fi_cell, contrasts, cell_traces, cell_baselines,
                            cell_sampling_interval, cell_pre_duration, contrast_idx)
        axes[0, 0].set_title("2nd highest Contrast: {:.2f}".format(contrasts[-2]))
        axes[0, 1].set_title("3rd highest Contrast: {:.2f}".format(contrasts[-3]))
        axes[0, 0].set_ylabel("Cell")

        # Model traces.
        model_traces = fi_model.mean_frequency_traces
        model_baselines = fi_model.get_f_baseline_frequencies()
        # NOTE(review): assumes model traces start 0.5 s before stimulus onset — TODO confirm in FICurveModel.
        model_pre_duration = 0.5
        model_sampling_interval = fi_model.model.get_sampling_interval()

        for col, contrast_idx in enumerate((-2, -3)):
            _plot_tau_panel(axes[1, col], fi_cell, contrasts, model_traces, model_baselines,
                            model_sampling_interval, model_pre_duration, contrast_idx)
        axes[1, 0].set_ylabel("Model")

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, cell_data.get_cell_name() + "_tau.png"))
        plt.close(fig)


def _plot_tau_panel(ax, fi_cell, contrasts, mean_frequencies, baseline_freqs,
                    sampling_interval, pre_duration, contrast_idx):
    """Fit tau for one contrast, plot trace (-0.05..0.15 s) plus fit on ``ax``, and return the fit params."""
    tau_params, tau_x = fi_cell.__calculate_time_constant_internal__(contrasts[contrast_idx],
                                                                      mean_frequencies[contrast_idx],
                                                                      baseline_freqs[contrast_idx],
                                                                      sampling_interval,
                                                                      pre_duration, plot=False, plot_data=True)

    trace = np.array(mean_frequencies[contrast_idx])
    # Time axis relative to stimulus onset (t = 0).
    x_values = np.arange(len(trace)) * sampling_interval - pre_duration
    index_range = (x_values > -0.05) & (x_values < 0.15)

    plot_freq_with_tau_fit(ax, x_values[index_range], trace[index_range], tau_x, tau_params)
    ax.set_xlabel("Time [s]; tau: {:.3f}".format(tau_params[1]))
    ax.set_ylim((0, 800))
    return tau_params
|
|
|
|
|
|
def plot_freq_with_tau_fit(ax, time, freq, tau_x, tau_params):
    """Draw a frequency trace and its exponential fit onto the same axes.

    ``freq`` is plotted against ``time``. The fit curve is
    ``exponential_function`` evaluated at ``tau_x`` with the first three
    entries of ``tau_params``; index 1 is the value callers label as tau.
    """
    ax.plot(time, freq)

    first, tau, third = tau_params[0], tau_params[1], tau_params[2]
    fitted_curve = exponential_function(tau_x, first, tau, third)
    ax.plot(tau_x, fitted_curve)
|
|
|
|
|
|
def tau_histogram_bins(cell_taus, model_taus, num_of_bins=20):
    """Return ``num_of_bins + 1`` evenly spaced bin edges spanning all cell and model taus.

    Uses ``np.linspace``, so the edge count is exact. ``np.arange`` with a
    float step can overshoot by one edge. If every value is identical, the
    range is padded so the edges still increase strictly.

    Raises:
        ValueError: if either sequence is empty.
    """
    if len(cell_taus) == 0 or len(model_taus) == 0:
        raise ValueError("cannot build tau bins: cell_taus and model_taus must both be non-empty")

    minimum = min(min(cell_taus), min(model_taus))
    maximum = max(max(cell_taus), max(model_taus))
    if maximum == minimum:
        pad = 0.5 * abs(minimum) if minimum != 0 else 0.5
        minimum -= pad
        maximum += pad
    return np.linspace(minimum, maximum, num_of_bins + 1)


def plot_cell_model_comp_taus(cell_taus, model_taus, save_path="figures/tau_images/fit_tau_comparison.pdf"):
    """Scatter cell vs. model taus, with a histogram above, and save the figure to ``save_path``.

    ``cell_taus`` and ``model_taus`` are paired element-wise (seconds,
    according to the axis labels). Both must be non-empty.
    """
    fig = plt.figure(figsize=(3, 4))
    gs = fig.add_gridspec(2, 1, height_ratios=[3, 7],
                          left=0.1, right=0.95, bottom=0.1, top=0.9,
                          wspace=0.4, hspace=0.2)
    bins = tau_histogram_bins(cell_taus, model_taus, num_of_bins=20)

    ax = fig.add_subplot(gs[1, 0])
    ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
    scatter_hist(cell_taus, model_taus, ax, ax_histx, "Tau Comparison", bins)  # , cmap, cell_bursting)
    ax.set_xlabel(r"Cell [s]")
    ax.set_ylabel(r"Model [s]")
    ax_histx.set_ylabel("Count")

    plt.tight_layout()
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path, transparent=True)
    plt.close(fig)
|
|
|
|
|
|
def sam_tests(data_folder="./data/final_sam/", cell_filter="2018-05-08-aa-invivo-1", max_abs_delta_freq=50):
    """Interactively plot SAM recordings: local EOD, v1 trace and the top spike trains.

    Only cell folders under ``data_folder`` that contain ``samspikes1.dat``
    are considered. Of those, only names containing ``cell_filter`` are
    plotted; ``None`` selects all of them. One figure is shown per SAM
    stimulus with |dF| <= ``max_abs_delta_freq``. The spike trains ranked
    highest by ``get_x_best`` (on ``average_spike_height`` of the v1 trace)
    are stacked as event rows above the trace.
    """
    for cell in sorted(os.listdir(data_folder)):
        print(cell)
        cell_folder = os.path.join(data_folder, cell)
        if not os.path.exists(os.path.join(cell_folder, "samspikes1.dat")):
            continue
        if cell_filter is not None and cell_filter not in cell:
            continue

        cell_data = CellData(cell_folder)
        sampling_rate = int(round(1 / cell_data.get_sampling_interval()))
        sam_spikes = cell_data.get_sam_spiketimes()
        delta_freqs = cell_data.get_sam_delta_frequencies()

        time_traces, v1_traces, _eod_traces, local_eod_traces, _stimulus_traces = cell_data.get_sam_traces()
        print(len(time_traces))

        for i, delta_freq in enumerate(delta_freqs):
            if abs(delta_freq) > max_abs_delta_freq:
                continue

            fig, axes = plt.subplots(2, 1, sharex="all")
            axes[0].plot(time_traces[i], local_eod_traces[i])
            axes[0].set_title("Local EOD - dF {}".format(delta_freq))
            axes[1].plot(time_traces[i], v1_traces[i])
            axes[1].set_title("v1 trace")

            ah_spike = average_spike_height(sam_spikes, v1_traces[i], sampling_rate)
            # np.max works for lists and arrays; the old max(trace + offset) failed on lists.
            trace_top = np.max(v1_traces[i])
            for j, idx in enumerate(get_x_best(ah_spike)):
                # Stack event rows 1.5 units apart above the trace maximum.
                axes[1].eventplot(sam_spikes[idx], lineoffsets=trace_top + 1.5 * (j + 1),
                                  colors=colors[j % len(colors)])
            plt.show()
            plt.close(fig)
|
|
|
|
|
|
|
|
def average_spike_height(spike_trains, local_eod, sampling_rate):
    """Return, for each spike train, the mean trace value at its spike samples.

    Args:
        spike_trains: sequence whose entries are indexed with ``[0]``; that
            element is used as spike times in seconds. (Presumably the first
            trial of each recording — TODO confirm against
            ``CellData.get_sam_spiketimes``.)
        local_eod: 1-D trace sampled at ``sampling_rate``. ``sam_tests``
            passes a v1 trace here.
        sampling_rate: samples per second, used to turn spike times into
            sample indices (truncated toward zero).

    Returns:
        list with one mean per spike train. The entry is NaN when none of
        that train's spikes fall inside the trace.
    """
    trace = np.asarray(local_eod)  # convert once, not once per spike train
    average_height = []
    for spikes_train in spike_trains:
        # int cast truncates toward zero; replaces the removed np.int alias.
        indices = (np.asarray(spikes_train[0], dtype=float) * sampling_rate).astype(int)
        # Drop out-of-range samples; negative indices would otherwise wrap to the trace end.
        indices = indices[(indices >= 0) & (indices < len(trace))]
        if indices.size == 0:
            average_height.append(np.nan)
        else:
            average_height.append(np.mean(trace[indices]))
    return average_height
|
|
|
|
|
|
def get_x_best(average_heights, x=5):
    """Return the indices of the ``x`` largest values, ordered by descending value.

    A single pass keeps the running top ``x``. A new value replaces the
    current minimum only when it is strictly larger. The selected
    (height, index) pairs are sorted descending, so ties are ordered by
    higher index first. The selected heights are printed as a diagnostic.
    Returns ``[]`` for empty input or ``x <= 0``; the old code crashed there.
    """
    biggest_idx = []
    biggest_heights = []

    if x > 0:  # with x <= 0 nothing is collected and min([]) would raise
        for i, height in enumerate(average_heights):
            if len(biggest_idx) < x:
                biggest_idx.append(i)
                biggest_heights.append(height)
            elif height > min(biggest_heights):
                mini = np.argmin(biggest_heights)
                biggest_heights[mini] = height
                biggest_idx[mini] = i

    if biggest_idx:  # zip(*sorted(...)) cannot unpack an empty selection
        ranked = sorted(zip(biggest_heights, biggest_idx), reverse=True)
        biggest_heights = [height for height, _ in ranked]
        biggest_idx = [i for _, i in ranked]

    print(biggest_heights)
    return biggest_idx
|
|
|
|
|
|
# Run the analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|