# P-unit_model/sam_experiments.py
# 2020-08-19 17:18:16 +02:00
#
# 273 lines
# 10 KiB
# Python

from stimuli.SinusAmplitudeModulation import SinusAmplitudeModulationStimulus as SAM
from Baseline import get_baseline_class
from FiCurve import FICurveModel
from models.LIFACnoise import LifacNoiseModel
import numpy as np
import matplotlib.pyplot as plt
import helperFunctions as hF
from CellData import CellData
from ModelFit import ModelFit, get_best_fit
import os
def main():
    """Run the SAM analysis for the hard-coded example cell.

    Previously this function called ``quit()`` right after ``sam_analysis``,
    which made everything below it unreachable and terminated the whole
    interpreter when ``main`` was called from other code. The function now
    simply returns; the formerly unreachable per-frequency comparison is
    kept, callable, in ``_plot_sam_period_comparison``.
    """
    sam_analysis("results/invivo_results/2013-01-08-ad-invivo-1/")


def _plot_sam_period_comparison(fit_path):
    """Plot period-wise firing rates of model and cell for each SAM frequency.

    For every modulation frequency found in the cell's SAM recordings the
    model is simulated with the first recorded contrast for four times the
    mean stimulus duration. The firing-rate traces of model and cell are cut
    into modulation periods and shown in four panels: model cuts, cell cuts,
    their difference, and both mean traces overlaid.

    :param fit_path: folder of the fit results passed to ``get_best_fit``.
    """
    modelfit = get_best_fit(fit_path, use_comparable_error=False)
    cell_path = modelfit.get_cell_path()
    if not os.path.exists(os.path.join(cell_path, "samallspikes1.dat")):
        print("Cell: {} \n Has no measured sam stimuli.".format(cell_path))
        return

    cell_data = CellData(cell_path)
    eod_freq = cell_data.get_eod_frequency()
    model = modelfit.get_model()

    mean_duration = np.mean(cell_data.get_sam_durations())
    contrast = cell_data.get_sam_contrasts()[0]  # are all the same in this test case
    spiketimes = cell_data.get_sam_spiketimes()
    delta_freqs = cell_data.get_sam_delta_frequencies()
    step_size = cell_data.get_sampling_interval()

    # Group the recorded trials by their modulation frequency.
    spikes_dictionary = {}
    for trial_spikes, m_freq in zip(spiketimes, delta_freqs):
        spikes_dictionary.setdefault(m_freq, []).append(trial_spikes)

    for m_freq in sorted(spikes_dictionary):
        period = 1 / float(m_freq)
        # Skip frequencies whose period does not fit twice into the stimulus.
        if mean_duration < 2 * period:
            continue

        stimulus = SAM(eod_freq, contrast / 100, m_freq)
        _, spikes_model = model.simulate(stimulus, mean_duration * 4)
        prob_density_function_model = spiketimes_calculate_pdf(spikes_model, step_size)

        fig, axes = plt.subplots(1, 4)
        cuts = cut_pdf_into_periods(prob_density_function_model, period, step_size)
        for c in cuts:
            axes[0].plot(c, color="gray", alpha=0.2)
        axes[0].set_title("model")
        mean_model = np.mean(cuts, axis=0)
        axes[0].plot(mean_model, color="black")

        means_cell = []
        for spikes_cell in spikes_dictionary[m_freq]:
            prob_density_cell = spiketimes_calculate_pdf(spikes_cell[0], step_size)
            # NOTE(review): this threshold is built from eod_freq / step_size;
            # presumably a number of modulation periods was intended - confirm.
            if len(prob_density_cell) < 3 * (eod_freq / step_size):
                continue
            cuts_cell = cut_pdf_into_periods(prob_density_cell, period, step_size)
            for c in cuts_cell:
                axes[1].plot(c, color="gray", alpha=0.15)
            print(cuts_cell.shape)
            means_cell.append(np.mean(cuts_cell, axis=0))

        if len(means_cell) == 0:
            # Nothing to compare against: release the figure created above.
            plt.close(fig)
            continue

        total_mean_cell = np.mean(np.array(means_cell), axis=0)
        axes[1].set_title("cell")
        axes[1].plot(total_mean_cell, color="black")

        axes[2].set_title("difference")
        axes[2].plot(total_mean_cell - mean_model[:len(total_mean_cell)])

        axes[3].plot(total_mean_cell)
        axes[3].plot(mean_model)
        plt.show()
        plt.close()
def sam_analysis(fit_path):
    """Compare cell and model responses to SAM stimuli, one figure per frequency.

    Loads the best fit stored under ``fit_path`` and the cell it belongs to.
    For every unique modulation frequency, each recorded trial is converted
    to a firing-rate trace, cut into modulation periods and averaged; the
    model is simulated with the same contrast for four times the trial
    duration and treated the same way. Each figure shows the per-trial means
    of the cell (panel 0), of the model (panel 1) and the difference of the
    overall means relative to the model mean (panel 2).

    :param fit_path: folder of the fit results passed to ``get_best_fit``.
    """
    modelfit = get_best_fit(fit_path)
    cell_path = modelfit.get_cell_path()
    if not os.path.exists(os.path.join(cell_path, "samallspikes1.dat")):
        print("Cell: {} \n Has no measured sam stimuli.".format(cell_path))
        return

    cell_data = CellData(cell_path)
    model = modelfit.get_model()

    # To analyse a hand-picked parameter set instead of the best fit:
    # parameters = {'delta_a': 0.08820130374685671, 'refractory_period': 0.0006, 'a_zero': 15, 'step_size': 5e-05,
    #               'v_base': 0, 'noise_strength': 0.03622523883042496, 'v_zero': 0, 'threshold': 1,
    #               'input_scaling': 77.75367190909581, 'tau_a': 0.07623731247799118, 'v_offset': -10.546875,
    #               'mem_tau': 0.008741976196676469, 'dend_tau': 0.0012058986118892773}
    # model = LifacNoiseModel(parameters)
    # cell_data = CellData("./data/test_data/2012-12-13-an-invivo-1/")

    eod_freq = cell_data.get_eod_frequency()
    step_size = cell_data.get_sampling_interval()
    durations = cell_data.get_sam_durations()
    u_durations = np.unique(durations)
    contrasts = cell_data.get_sam_contrasts()
    spiketimes = cell_data.get_sam_spiketimes()
    delta_freqs = cell_data.get_sam_delta_frequencies()

    for mod_freq in sorted(np.unique(delta_freqs)):
        period = 1 / mod_freq
        # TODO problem of cutting the pdf as in some cases the pdf is shorter than 1 modulation frequency period!
        if period > durations[0] / 4:
            print("skipped mod_freq: {}".format(mod_freq))
            print("Duration: {} while mod_freq period: {:.2f}".format(durations[0], period))
            print("Maybe long enough duration? unique durations:", u_durations)
            continue

        cell_means = []
        model_means = []
        for i, trial_freq in enumerate(delta_freqs):
            if trial_freq != mod_freq:
                continue

            if len(spiketimes[i]) > 1:
                print("There are more spiketimes in one 'point'! Only the first was used! ")

            cell_pdf = spiketimes_calculate_pdf(spiketimes[i][0], step_size)
            cell_cuts = cut_pdf_into_periods(cell_pdf, period, step_size, factor=1.1, use_all=True)
            cell_means.append(np.mean(cell_cuts, axis=0))

            # Simulate the model with this trial's contrast (given in percent).
            stimulus = SAM(eod_freq, contrasts[i] / 100, mod_freq)
            _, spikes_model = model.simulate(stimulus, durations[i] * 4)
            model_pdf = spiketimes_calculate_pdf(spikes_model, step_size)
            model_cuts = cut_pdf_into_periods(model_pdf, period, step_size, factor=1.1)
            model_means.append(np.mean(model_cuts, axis=0))

        mean_cell = np.mean(cell_means, axis=0)
        mean_model = np.mean(model_means, axis=0)

        fig, axes = plt.subplots(1, 4)
        for trace in cell_means:
            axes[0].plot(trace, color="grey", alpha=0.2)
        axes[0].plot(mean_cell, color="black")
        axis_cell = axes[0].axis()

        for trace in model_means:
            axes[1].plot(trace, color="grey", alpha=0.2)
        axes[1].plot(mean_model, color="black")
        axis_model = axes[1].axis()

        # Give both panels the same y-range so they can be compared by eye.
        ylim_top = max(axis_cell[3], axis_model[3])
        axes[1].set_ylim(0, ylim_top)
        axes[0].set_ylim(0, ylim_top)

        axes[2].plot((mean_model - mean_cell) / mean_model)
        plt.title("modulation frequency: {}".format(mod_freq))
        plt.show()
        plt.close()
def generate_pdf(model, stimulus, trials=4, sim_length=3, kernel_width=0.005):
    """Return the trial-averaged firing rate of ``model`` for ``stimulus``.

    Each trial is simulated with ``model.simulate_slow``, its spikes are
    binned into a binary train with the model's ``step_size`` and the train
    is convolved with a Gaussian kernel. The per-trial rates are averaged
    with ``hF.calculate_mean_of_frequency_traces``.

    :param model: object providing ``get_parameters()`` (with a
        ``"step_size"`` entry) and ``simulate_slow(stimulus, total_time_s=...)``.
    :param stimulus: stimulus handed unchanged to ``model.simulate_slow``.
    :param trials: number of simulated repetitions.
    :param sim_length: simulated time per trial.
    :param kernel_width: sigma passed to ``gaussian_kernel``.
    :return: the mean rate trace.
    """
    step_size = model.get_parameters()["step_size"]
    n_samples = int(sim_length / step_size)
    # The kernel only depends on the arguments, so build it once.
    kernel = gaussian_kernel(kernel_width, step_size)

    trials_rate_list = []
    for _ in range(trials):
        _, spikes = model.simulate_slow(stimulus, total_time_s=sim_length)
        spike_indices = (np.asarray(spikes, dtype=float) / step_size).astype(int)
        # A spike at the very end of the window (or float rounding in
        # n_samples) would index past the train; ignore such spikes.
        inside = (spike_indices >= 0) & (spike_indices < n_samples)
        binary = np.zeros(n_samples)
        binary[spike_indices[inside]] = 1
        trials_rate_list.append(np.convolve(binary, kernel, mode='same'))

    times = [np.arange(0, sim_length, step_size) for _ in range(trials)]
    _, mean_rate = hF.calculate_mean_of_frequency_traces(times, trials_rate_list, step_size)
    return mean_rate
def spiketimes_calculate_pdf(spikes, step_size, kernel_width=0.005):
    """Convert spike times into a firing-rate trace by Gaussian smoothing.

    The spike times are binned into a binary train sampled with
    ``step_size`` that ends at the latest spike, and the train is convolved
    with ``gaussian_kernel(kernel_width, step_size)``.

    :param spikes: sequence of spike times; does not need to be sorted.
    :param step_size: sampling interval of the returned trace.
    :param kernel_width: sigma passed to ``gaussian_kernel``.
    :return: 1-D array. Because ``np.convolve(..., mode='same')`` is used,
        its length is that of the longer of spike train and kernel. An empty
        array is returned when ``spikes`` is empty.
    """
    spike_indices = (np.asarray(spikes, dtype=float) / step_size).astype(int)
    if spike_indices.size == 0:
        # No spike means no defined extent of the trace.
        return np.zeros(0)

    # Size the train by the latest spike, not by the last list element, so
    # that unsorted input cannot index out of range.
    binary = np.zeros(spike_indices.max() + 1)
    binary[spike_indices] = 1

    kernel = gaussian_kernel(kernel_width, step_size)
    return np.convolve(binary, kernel, mode='same')
def cut_pdf_into_periods(pdf, period, step_size, factor=1.5, use_all=False):
    """Cut a rate trace into equally long segments, one per stimulus period.

    Segment ``i`` starts ``i`` periods into the trace and is
    ``int(period / step_size * factor)`` samples long, so segments overlap
    for ``factor > 1``. The fractional part of the period length in samples
    is accumulated so that segment starts do not drift. The first period
    (``i == 0``) is never returned.

    :param pdf: 1-D rate trace.
    :param period: period length, in the same unit as ``step_size``.
    :param step_size: sampling interval of ``pdf``.
    :param factor: segment length as a multiple of the period.
    :param use_all: accepted for compatibility; currently not used.
    :return: 2-D array with one segment per row. When the trace is shorter
        than one period the whole trace is returned as the only row.
    :raises RuntimeError: if no complete segment after the first period fits.
    """
    samples_per_period = period / float(step_size)
    if samples_per_period > len(pdf):
        # Same content as before, but as an array like every other return.
        return np.array([pdf])

    idx_period_length = int(samples_per_period)
    offset_per_step = samples_per_period - idx_period_length
    cut_length = int(samples_per_period * factor)

    def cut_start(i):
        # Whole-sample period steps plus the accumulated fractional part.
        return i * idx_period_length + int(offset_per_step * i)

    num_of_cuts = int(len(pdf) / (idx_period_length + offset_per_step))
    remainder = len(pdf) - (num_of_cuts * idx_period_length + num_of_cuts * offset_per_step)
    if remainder < cut_length - idx_period_length:
        num_of_cuts -= 1
    # For large factors one step back is not enough: drop trailing segments
    # until the last one lies completely inside the trace.
    while num_of_cuts > 1 and cut_start(num_of_cuts - 1) + cut_length > len(pdf):
        num_of_cuts -= 1

    if num_of_cuts <= 1:
        raise RuntimeError("Probability density function to short to cut.")

    cuts = np.zeros((num_of_cuts - 1, cut_length))
    for i in range(1, num_of_cuts):
        start_idx = cut_start(i)
        cuts[i - 1] = pdf[start_idx:start_idx + cut_length]
    return cuts
def gaussian_kernel(sigma, dt):
    """Return a Gaussian with standard deviation ``sigma`` sampled every ``dt``.

    The samples cover ``[-4 * sigma, 4 * sigma)`` and are scaled by
    ``1 / (sqrt(2 * pi) * sigma)``.

    :param sigma: standard deviation of the Gaussian.
    :param dt: spacing of the sample points.
    :return: 1-D array with the kernel values.
    """
    # np.arange excludes the upper bound; with a float step the number of
    # samples can vary by one due to rounding.
    support = np.arange(-4. * sigma, 4. * sigma, dt)
    kernel = np.exp(-0.5 * np.square(support / sigma))
    kernel /= np.sqrt(2. * np.pi)
    kernel /= sigma
    return kernel
# Only run the analysis when this file is executed as a script.
if __name__ == '__main__':
    main()