P-unit_model/test.py
2021-05-22 13:10:15 +02:00

161 lines
7.4 KiB
Python

import os
from parser.CellData import CellData
from parser.DataParserFactory import DatParser
import numpy as np
from fitting.ModelFit import ModelFit, get_best_fit
# from plottools.axes import labelaxes_params
import matplotlib.pyplot as plt
from run_Fitter import iget_start_parameters
from experiments.FiCurve import FICurve, FICurveCellData, FICurveModel
# Palette cycled through (by index modulo its length) when sam_tests()
# overlays the selected spike trains as event plots.
colors = [
    "black",
    "red",
    "blue",
    "orange",
    "green",
]
def main():
    """Compare f-I curve time constants of fitted models and recorded cells.

    For every sub-directory of ``results/sam_cells_only_best/`` the best fit
    is loaded and its fit-routine error is printed.  ``calculate_time_constant(-2)``
    is then evaluated on the f-I curve of the fitted model and on the f-I
    curve of the recorded cell.  Values whose magnitude is 0.15 or larger are
    left out of the two histograms (the number of dropped values is printed);
    the complete, unfiltered lists are printed after the figure is closed.
    """
    # ------------------------------------------------------------------
    # Disabled: export every cell's f-I curve as an image and as a CSV file.
    # ------------------------------------------------------------------
    # results_dir = "data/final/"
    # for folder in sorted(os.listdir(results_dir)):
    #     folder_path = os.path.join(results_dir, folder)
    #
    #     if not os.path.isdir(folder_path):
    #         continue
    #
    #     cell_data = CellData(folder_path)
    #     cell_name = cell_data.get_cell_name()
    #
    #     fi_cell = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), cell_data.data_path)
    #
    #     fi_cell.plot_fi_curve(title=cell_name, save_path="temp/cell_fi_curves_images/" + cell_name + "_")
    #
    #     steady_state = fi_cell.get_f_inf_frequencies()
    #     onset = fi_cell.get_f_zero_frequencies()
    #     baseline = fi_cell.get_f_baseline_frequencies()
    #     contrasts = fi_cell.stimulus_values
    #
    #     headers = ["contrasts", "f_baseline", "f_steady_state", "f_onset"]
    #     with open("temp/cell_fi_curves_csvs/" + cell_name + ".csv", 'w') as f:
    #         for i in range(len(headers)):
    #             if i == 0:
    #                 f.write(headers[i])
    #             else:
    #                 f.write("," + headers[i])
    #         f.write("\n")
    #
    #         for i in range(len(contrasts)):
    #             f.write(str(contrasts[i]) + ",")
    #             f.write(str(baseline[i]) + ",")
    #             f.write(str(steady_state[i]) + ",")
    #             f.write(str(onset[i]) + "\n")
    # quit()

    cell_taus = []
    model_taus = []

    results_dir = "results/sam_cells_only_best/"
    candidate_paths = (os.path.join(results_dir, entry)
                       for entry in sorted(os.listdir(results_dir)))
    for fit_dir in candidate_paths:
        # Only directories hold fit results; skip stray files.
        if not os.path.isdir(fit_dir):
            continue

        best_fit = get_best_fit(fit_dir)
        print(best_fit.get_fit_routine_error())
        fitted_model = best_fit.get_model()
        recorded_cell = best_fit.get_cell_data()

        # Time constant of the model's f-I curve.
        # NOTE(review): the argument -2 presumably selects the second-to-last
        # contrast -- confirm against FICurve.calculate_time_constant.
        model_curve = FICurveModel(fitted_model,
                                   recorded_cell.get_fi_contrasts(),
                                   recorded_cell.get_eod_frequency())
        model_taus.append(model_curve.calculate_time_constant(-2))

        # Same measure taken from the recorded cell data.
        cell_curve = FICurveCellData(recorded_cell,
                                     recorded_cell.get_fi_contrasts(),
                                     recorded_cell.data_path)
        cell_taus.append(cell_curve.calculate_time_constant(-2))

    # Results of an earlier run, kept to skip the slow loop above when needed:
    # model_taus = [0.008227050473746214, 339.82706244279075, 0.010807838358313856, 0.01115826226335211, 0.007413613528371537, 0.013213123673467943, 0.010808781901437248, 0.0014254019917934319, 0.015448860984264491, 0.014413888046967265, 0.029301687421672096, 255.82969629640462, 0.00457130444591641, 0.009463250852321902, 0.007755615618900141, 0.009110183466482135, 0.007225102891006319, 0.0024319255218167336, 0.017420779742227246, 0.027195130905873905, 0.00934661249103802, 0.07158177921097474, 0.004866423936911278, 0.0008792730042370866, 0.00820470663372859, 0.05135988132772797, -945.8805502129879, -625.3981095962032, 0.00045249542468299257, 0.10198296886109447, 0.02992101543230009, 715.8802825637086, 0.0074281010613263775, 0.002038042609377947, 0.0055331475878047445, 0.010965819934792512, 0.00916015878530846, -123.0502556160885, 0.013734214511572751, 0.004193114169578979, 0.011103783836162914, 0.018070119202374276]
    # cell_taus = [0.0035588022114672975, 0.005541599918212267, 0.007848670525682807, 0.008147461940299978, 0.005948699597158819, 0.0024739217090879104, 0.0038303906688137847, 0.00300889313116284, 0.014167509501882801, 0.009459132581703281, 0.005226151863380407, 772.607757547133, 0.0016936075127979523, 0.008768601246126134, 0.0036987681597240958, 0.009306705661392982, 0.004808427175831087, 0.005419130192821167, 0.0028735071877832733, 0.005983916198767454, 0.004369124640159074, 0.020115307489662095, 468.1810372271939, 0.0012946259647070454, 0.0021810924044437753, 259.6701021041893, 2891.7659169677813, -2155.469810882238, 0.0027895996432137117, 0.01503608591999554, 1138.5941497875147, -0.009831620851536924, 0.004657794528111363, -0.007131468820451661, -0.0221455330638256, -589.1530734507537, -506.6077728634018, -0.0028166760486066605, 359.3395355603788, -0.003053762369811596, 0.00465946355831796, 0.01675427242298042]

    def within_limit(taus):
        # Keep only values with magnitude below 0.15; larger ones are
        # excluded from the histograms.
        return [tau for tau in taus if np.abs(tau) < 0.15]

    model_taus_c = within_limit(model_taus)
    cell_taus_c = within_limit(cell_taus)

    print("model removed:", len(model_taus) - len(model_taus_c))
    print("cell removed:", len(cell_taus) - len(cell_taus_c))

    # Side-by-side histograms on shared axes so both distributions can be
    # compared directly.
    _, (ax_model, ax_cell) = plt.subplots(1, 2, sharex="all", sharey="all")
    ax_model.hist(model_taus_c)
    ax_model.set_title("Model taus")
    ax_cell.hist(cell_taus_c)
    ax_cell.set_title("Cell taus")
    plt.show()
    plt.close()

    print(model_taus)
    print(cell_taus)
    # sam_tests()
def sam_tests():
    """Plot SAM recordings of one cell together with selected spike trains.

    Walks through ``./data/final_sam/`` (printing every entry name) and
    processes only folders that contain ``samspikes1.dat`` and whose name
    contains ``2018-05-08-aa-invivo-1``.  For each SAM stimulus with an
    absolute delta frequency of at most 50, a two-row figure is shown: the
    local EOD trace on top and the v1 trace below.  The spike trains chosen
    by ``get_x_best(average_spike_height(...))`` are drawn as event plots
    above the v1 trace, each one in its own color from ``colors``.
    """
    data_folder = "./data/final_sam/"
    wanted_cell = "2018-05-08-aa-invivo-1"

    for cell in sorted(os.listdir(data_folder)):
        print(cell)
        cell_folder = os.path.join(data_folder, cell)

        # Skip folders without SAM spike data and every cell except the one
        # currently under inspection.
        has_sam_spikes = os.path.exists(os.path.join(cell_folder, "samspikes1.dat"))
        if not has_sam_spikes or wanted_cell not in cell:
            continue

        cell_data = CellData(cell_folder)
        sampling_rate = int(round(1 / cell_data.get_sampling_interval()))
        sam_spikes = cell_data.get_sam_spiketimes()
        delta_freqs = cell_data.get_sam_delta_frequencies()
        # Only the time, v1 and local EOD traces are used below.
        time_traces, v1_traces, _eod_traces, local_eod_traces, _stimulus_traces = \
            cell_data.get_sam_traces()
        print(len(time_traces))

        for i, delta_f in enumerate(delta_freqs):
            if abs(delta_f) > 50:
                continue

            times = time_traces[i]
            v1 = v1_traces[i]

            _, (ax_eod, ax_v1) = plt.subplots(2, 1, sharex="all")
            ax_eod.plot(times, local_eod_traces[i])
            ax_eod.set_title("Local EOD - dF {}".format(delta_f))
            ax_v1.plot(times, v1)
            ax_v1.set_title("v1 trace")

            # Rank the spike trains by their mean v1 value at the spike times
            # and stack the best ones above the trace, 1.5 units apart.
            ah_spike = average_spike_height(sam_spikes, v1, sampling_rate)
            for j, idx in enumerate(get_x_best(ah_spike)):
                ax_v1.eventplot(sam_spikes[idx],
                                lineoffsets=max(v1 + 1.5 * (j + 1)),
                                colors=colors[j % len(colors)])

            plt.show()
            plt.close()
def average_spike_height(spike_trains, local_eod, sampling_rate):
    """Return the mean trace value at the spike times of every spike train.

    :param spike_trains: iterable of entries whose first element
        (``entry[0]``) is a sequence of spike times.
    :param local_eod: sampled trace; spike times are converted to sample
        indices into this trace.
    :param sampling_rate: samples per time unit of the spike times; a spike
        time ``s`` maps to the index ``int(s * sampling_rate)``.
    :return: list with one mean value per spike train.  The entry is ``nan``
        when a train has no spike that falls inside the trace.
    """
    # Convert once instead of once per spike train.
    trace = np.asarray(local_eod)

    average_height = []
    for spikes_train in spike_trains:
        # ``np.int`` was removed from NumPy; the builtin ``int`` is the
        # equivalent dtype.
        indices = (np.asarray(spikes_train[0]) * sampling_rate).astype(int)
        # Drop indices beyond the trace and negative ones, which would
        # otherwise silently wrap around to the end of the trace.
        valid = indices[(indices >= 0) & (indices < len(trace))]
        if valid.size:
            average_height.append(np.mean(trace[valid]))
        else:
            # Same value np.mean would give for no samples, without the
            # "mean of empty slice" warning.
            average_height.append(np.nan)
    return average_height
def get_x_best(average_heights, x=5):
    """Return the indices of the ``x`` largest values, largest first.

    The first ``x`` entries are taken as candidates; every later entry that
    is strictly greater than the current smallest candidate replaces it.
    The surviving candidates are sorted in descending order (equal heights
    are ordered by descending index) and their heights are printed.

    NaN heights are not treated specially.

    :param average_heights: sequence of comparable values.
    :param x: maximum number of indices to return.
    :return: list of at most ``x`` indices into ``average_heights``; empty
        when ``average_heights`` is empty or ``x`` is not positive.
    """
    if x <= 0:
        # Nothing can be selected; avoids calling min() on an empty list.
        return []

    biggest_idx = []
    biggest_heights = []
    for i, height in enumerate(average_heights):
        if len(biggest_idx) < x:
            biggest_idx.append(i)
            biggest_heights.append(height)
        elif height > min(biggest_heights):
            # Overwrite the (first) smallest candidate with the new entry.
            mini = np.argmin(biggest_heights)
            biggest_heights[mini] = height
            biggest_idx[mini] = i

    # Building the two lists separately also works for zero candidates,
    # where unpacking zip(*...) would fail.
    ranked = sorted(zip(biggest_heights, biggest_idx), reverse=True)
    biggest_heights = [height for height, _ in ranked]
    biggest_idx = [idx for _, idx in ranked]

    print(biggest_heights)
    return biggest_idx
# Run the time-constant comparison only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()