# chirp_probing/punit_responses.py
#
# (source listing: 245 lines, 12 KiB, Python)

import numpy as np
import nixio as nix
import argparse
import os
from numpy.core.fromnumeric import repeat
from traitlets.traitlets import Instance
from chirp_ams import get_signals
from model import simulate, load_models
from IPython import embed
import matplotlib.pyplot as plt
import multiprocessing
from joblib import Parallel, delayed
# Folder in which the per-cell NIX files ("cell_<cell>.nix") are written; also the
# default value of the -o/--output_folder command line option.
data_folder = "data"
def append_settings(section, sec_name, sec_type, settings):
    """Add a sub-section ``sec_name`` of type ``sec_type`` to ``section`` and fill it from ``settings``.

    Values are stored according to their type:
      * dict: recursively stored as a nested sub-section of type "settings",
      * 1-D numpy array: stored as a plain list,
      * numpy array with any other number of dimensions: not stored,
      * anything else: stored unchanged.
    """
    child = section.create_section(sec_name, sec_type)
    for key, value in settings.items():
        if isinstance(value, dict):
            append_settings(child, key, "settings", value)
        elif not isinstance(value, np.ndarray):
            child[key] = value
        elif value.ndim == 1:
            child[key] = list(value)
        # numpy arrays that are not one-dimensional are silently skipped
def _add_sampled_series(block, name, da_type, data, label, unit, dt):
    """Store ``data`` in ``block`` as a regularly sampled series with a time axis (step ``dt``, in s)."""
    da = block.create_data_array(name, da_type, dtype=nix.DataType.Float, data=data)
    da.label = label
    da.unit = unit
    dim = da.append_sampled_dimension(dt)
    dim.label = "time"
    dim.unit = "s"
    return da


def save(filename, name, stimulus_settings, model_settings, self_signal, other_signal, self_freq, other_freq,
         complete_stimulus, responses, overwrite=False):
    """Store one simulated stimulus condition and its spike responses in a NIX file.

    A metadata section and a block, both called ``name`` and of type "nix.simulation", are
    created. The section gets the sub-sections "model parameter" and "stimulus parameter".
    The block holds the stimulus, the optional frequency traces, and one spike-time array
    per trial. A tag "chirp stimulation" spans the whole stimulus, references all response
    arrays, and links the stimulus and frequency arrays as untagged features. If a block
    named ``name`` already exists, nothing is written.

    Parameters
    ----------
    filename : str
        Path of the NIX file. An existing file is opened read/write unless ``overwrite`` is
        True, in which case it is recreated. Missing parent folders are created.
    name : str
        Name of the block and of its metadata section.
    stimulus_settings, model_settings : dict
        Settings stored in the metadata. ``model_settings["deltat"]`` is also used as the
        sampling interval of all sampled arrays.
    self_signal, other_signal :
        Not stored; kept for interface compatibility.
    self_freq, other_freq : array-like or None
        Frequency traces of the two signals; skipped when None.
    complete_stimulus : array-like
        The stimulus that was presented.
    responses : sequence of array-like
        Spike times (s) of each trial, stored as "response_<i>".
    overwrite : bool
        Recreate the file even if it exists.
    """
    folder = os.path.dirname(filename)
    if folder:
        os.makedirs(folder, exist_ok=True)
    if os.path.exists(filename) and not overwrite:
        nf = nix.File.open(filename, nix.FileMode.ReadWrite)
    else:
        nf = nix.File.open(filename, mode=nix.FileMode.Overwrite,
                           compression=nix.Compression.DeflateNormal)
    try:  # make sure the file is closed even if writing fails
        if name in nf.blocks:
            print("Data with this name is already stored! ", name)
            return
        mdata = nf.create_section(name, "nix.simulation")
        append_settings(mdata, "model parameter", "nix.model.settings", model_settings)
        append_settings(mdata, "stimulus parameter", "nix.stimulus.settings", stimulus_settings)
        b = nf.create_block(name, "nix.simulation")
        b.metadata = mdata
        dt = model_settings["deltat"]

        # save stimulus and frequency traces
        stim_da = _add_sampled_series(b, "complete stimulus", "nix.timeseries.sampled.stimulus",
                                      complete_stimulus, "voltage", "mV/cm", dt)
        self_freq_da = None
        if self_freq is not None:
            self_freq_da = _add_sampled_series(b, "self frequency", "nix.timeseries.sampled.frequency",
                                               self_freq, "frequency", "Hz", dt)
        other_freq_da = None
        if other_freq is not None:
            other_freq_da = _add_sampled_series(b, "other frequency", "nix.timeseries.sampled.frequency",
                                                other_freq, "frequency", "Hz", dt)

        # save responses, one data array per trial
        for i, trial_spikes in enumerate(responses):
            da = b.create_data_array("response_%i" % i, "nix.timeseries.events.spike_times",
                                     dtype=nix.DataType.Float, data=trial_spikes)
            da.label = "time"
            da.unit = "s"
            da.append_alias_range_dimension()

        # bind it all together
        tag = b.create_tag("chirp stimulation", "nix.stimulus", [0.0])
        tag.extent = [stim_da.shape[0] * stim_da.dimensions[0].sampling_interval]
        for da in b.data_arrays:
            if "response" in da.name.lower():
                tag.references.append(da)
        tag.create_feature(stim_da, nix.LinkType.Untagged)
        if self_freq_da is not None:
            tag.create_feature(self_freq_da, nix.LinkType.Untagged)
        if other_freq_da is not None:
            tag.create_feature(other_freq_da, nix.LinkType.Untagged)
    finally:
        nf.close()
def save_baseline_response(filename, block_name, spikes, model_settings, overwrite=False):
    """Store the baseline (spontaneous) spike times of a model cell in a NIX file.

    A metadata section and a block, both called ``block_name`` and of type "nix.simulation",
    are created. The section gets a "model parameter" sub-section. The spike times go into
    the data array "baseline_response". If a block named ``block_name`` already exists,
    nothing is written.

    Parameters
    ----------
    filename : str
        Path of the NIX file. An existing file is opened read/write unless ``overwrite`` is
        True, in which case it is recreated. Missing parent folders are created.
    block_name : str
        Name of the block and of its metadata section.
    spikes : array-like
        Spike times in seconds.
    model_settings : dict
        Model parameters stored in the metadata.
    overwrite : bool
        Recreate the file even if it exists.
    """
    folder = os.path.dirname(filename)
    if folder:
        os.makedirs(folder, exist_ok=True)
    if os.path.exists(filename) and not overwrite:
        nf = nix.File.open(filename, nix.FileMode.ReadWrite)
    else:
        nf = nix.File.open(filename, mode=nix.FileMode.Overwrite,
                           compression=nix.Compression.DeflateNormal)
    try:  # make sure the file is closed even if writing fails
        if block_name in nf.blocks:
            print("Data with this name is already stored! ", block_name)
            return
        mdata = nf.create_section(block_name, "nix.simulation")
        append_settings(mdata, "model parameter", "nix.model.settings", model_settings)
        b = nf.create_block(block_name, "nix.simulation")
        b.metadata = mdata
        da = b.create_data_array("baseline_response", "nix.timeseries.events.spike_times",
                                 dtype=nix.DataType.Float, data=spikes)
        da.label = "time"
        da.unit = "s"
        da.append_alias_range_dimension()
    finally:
        nf.close()
def get_baseline_response(model_params, duration=10):
    """Simulate the baseline response of a model cell driven only by its own EOD.

    The EOD sine is preceded by the lead-in stimulus of ``get_pre_stimulus``. Spikes during
    that lead-in are discarded, and the remaining spike times are shifted so that they count
    from the last lead-in sample.

    Parameters
    ----------
    model_params : dict
        Model parameters; must contain "cell", "EODf" (Hz) and "deltat" (s). All entries
        except "cell" and "EODf" are passed to ``simulate``, together with a random "v_zero".
    duration : float
        Duration of the baseline stimulus in seconds.

    Returns
    -------
    Spike times in seconds, relative to the end of the lead-in stimulus.
    """
    eodf = model_params["EODf"]
    dt = model_params["deltat"]
    cell_params = model_params.copy()
    del cell_params["cell"]
    del cell_params["EODf"]
    # Sample the lead-in with the model's step size (as simulate_responses does). The
    # default dt of get_pre_stimulus is only correct for models with deltat == 1/20000 s.
    time, pre_stim = get_pre_stimulus(eodf, dt=dt)
    baseline_time = np.arange(0.0, duration, dt)
    eod = np.sin(baseline_time * eodf * 2 * np.pi)
    stim = np.hstack((pre_stim, eod))
    cell_params["v_zero"] = np.random.rand(1)[0]  # random start value for the model's v_zero
    spikes = simulate(stim, **cell_params)
    # drop the lead-in spikes and re-reference the times to the last lead-in sample
    spikes = spikes[spikes > time[-1]] - time[-1]
    return spikes
def get_pre_stimulus(eodf, duration=2, df=20, contrast=20, dt=1./20000):
    """Create the lead-in stimulus that is played before the actual stimulus.

    Parameters
    ----------
    eodf : float
        Frequency of the own EOD in Hz (unit-amplitude sine).
    duration : float
        Duration in seconds.
    df : float
        Frequency difference (Hz) of the second, foreign sine relative to ``eodf``.
    contrast : float
        Amplitude of the foreign sine in percent of the own EOD amplitude.
    dt : float
        Sampling interval in seconds.

    Returns
    -------
    time : np.ndarray
        Sample times 0, dt, 2*dt, ... below ``duration``.
    stim : np.ndarray
        Own EOD plus the foreign sine, whose amplitude is shaped by a Hann window
        (``np.hanning``) spanning the whole duration.
    """
    t = np.arange(0.0, duration, dt)
    own_eod = np.sin(t * eodf * 2 * np.pi)
    window = np.hanning(len(t))
    foreign_eod = np.sin(t * (eodf + df) * 2 * np.pi) * contrast / 100 * window
    return t, own_eod + foreign_eod
def simulate_responses(stimulus_params, model_params, repeats=10, deltaf=20):
    """Simulate the chirp responses of one model cell for every contrast and store them.

    For each contrast in ``stimulus_params["contrasts"]``, the conditions "other" and "self"
    are passed to ``get_signals`` (as ``condition``). The cell is driven ``repeats`` times
    by the sum of self and other signal, each trial preceded by the lead-in stimulus of
    ``get_pre_stimulus``. In the "self" condition, every trial is followed by a second run
    driven by the self signal alone; these runs are stored as condition "no-other". All
    results go to "<data_folder>/cell_<cell>.nix" via ``save``.

    Parameters
    ----------
    stimulus_params : dict
        Arguments for ``get_signals`` plus the keys "contrasts", "chirp_frequency" (both
        removed before calling ``get_signals``), "chirp_size" and "eodfs" (its "self" entry
        sets the lead-in frequency).
    model_params : dict
        Model parameters; "cell" and "EODf" are removed before the rest is passed to
        ``simulate``; "deltat" sets the lead-in sampling interval.
    repeats : int
        Number of trials per condition.
    deltaf : int
        Difference frequency; only used in the stored block names.
    """
    sim_params = model_params.copy()
    cell_name = sim_params["cell"]
    filename = os.path.join(data_folder, "cell_%s.nix" % cell_name)
    for key in ("cell", "EODf"):
        del sim_params[key]
    chirp_size = stimulus_params["chirp_size"]
    pre_time, pre_stim = get_pre_stimulus(stimulus_params["eodfs"]["self"], dt=model_params["deltat"])
    onset = pre_time[-1]

    def run_trial(signal):
        # One trial: fresh random v_zero, lead-in + signal, lead-in spikes dropped and the
        # remaining spike times re-referenced to the last lead-in sample.
        sim_params["v_zero"] = np.random.rand(1)[0]
        spike_times = simulate(np.hstack((pre_stim, signal)), **sim_params)
        return spike_times[spike_times > onset] - onset

    for contrast in stimulus_params["contrasts"]:
        params = stimulus_params.copy()
        params["contrast"] = contrast
        for key in ("contrasts", "chirp_frequency"):
            del params[key]
        for condition in ("other", "self"):
            print("\tcontrast: %s, condition: %s" % (contrast, condition), " " * 10, end="\r")
            block_name = "contrast_%.3f_condition_%s_deltaf_%i_chirpsize_%i" % (contrast, condition, deltaf,
                                                                               chirp_size)
            params["condition"] = condition
            _time, self_signal, self_freq, other_signal, other_freq = get_signals(**params)
            full_signal = self_signal + other_signal
            also_self_only = condition == "self"
            spikes, no_other_spikes = [], []
            for _trial in range(repeats):
                spikes.append(run_trial(full_signal))
                if also_self_only:
                    no_other_spikes.append(run_trial(self_signal))
            if also_self_only:
                no_other_name = "contrast_%.3f_condition_no-other_deltaf_%i_chirpsize_%i" % (contrast, deltaf,
                                                                                            chirp_size)
                save(filename, no_other_name, params, sim_params, self_signal, None, self_freq, None,
                     self_signal, no_other_spikes)
            save(filename, block_name, params, sim_params, self_signal, other_signal, self_freq, other_freq,
                 full_signal, spikes)
    print("\n")
def simulate_cell(cell_id, models, args):
    """Simulate the baseline and chirp responses of one model cell and store them.

    First a 30 s baseline response is simulated and saved as block "baseline response".
    Then ``simulate_responses`` runs for every combination of chirp size and difference
    frequency. Everything is written to "<output folder>/cell_<cell>.nix".

    Parameters
    ----------
    cell_id : int
        Index of the model in ``models``.
    models : sequence of dict
        Model parameter sets (from ``load_models``). The selected entry must provide at
        least "cell", "EODf" and "deltat".
    args : argparse.Namespace
        Parsed command line of ``main``. Uses ``deltafs``, ``chirpsizes``, ``contrasts``
        and ``trials``, plus ``output_folder`` when present.
    """
    global data_folder
    # Honor -o/--output_folder (it was previously parsed but never used). simulate_responses
    # builds its file names from the module-level ``data_folder``, so the name is updated
    # here, in the process that simulates this cell. With joblib worker processes, a value
    # set only in main() may not reach the workers.
    data_folder = getattr(args, "output_folder", data_folder)
    if data_folder:
        os.makedirs(data_folder, exist_ok=True)

    deltafs = args.deltafs  # Hz, difference frequency between self and other
    chirp_sizes = args.chirpsizes  # Hz, the chirp size, i.e. the frequency excursion
    stimulus_params = {"eodfs": {"self": 0.0, "other": 0.0},  # eod frequency in Hz, to be overwritten
                       "contrasts": args.contrasts,
                       "chirp_size": 100,  # Hz, frequency excursion
                       "chirp_duration": 0.015,  # s, chirp duration
                       "chirp_amplitude_dip": 0.05,  # %, amplitude drop during chirp
                       "chirp_frequency": 5,  # Hz, how often does the fish chirp
                       "duration": 5.,  # s, total duration of simulation
                       "dt": 1,  # s, stepsize of the simulation, to be overwritten
                       }
    model_params = models[cell_id]
    baseline_spikes = get_baseline_response(model_params, duration=30)
    filename = os.path.join(data_folder, "cell_%s.nix" % model_params["cell"])
    save_baseline_response(filename, "baseline response", baseline_spikes, model_params)
    print("Cell: %s" % model_params["cell"])

    # Step size and chirp times depend on neither chirp size nor deltaf: set them once.
    stimulus_params["dt"] = model_params["deltat"]
    stimulus_params["chirp_times"] = np.arange(stimulus_params["chirp_duration"],
                                               stimulus_params["duration"] - stimulus_params["chirp_duration"],
                                               1. / stimulus_params["chirp_frequency"])
    for cs in chirp_sizes:
        stimulus_params["chirp_size"] = cs
        for deltaf in deltafs:
            stimulus_params["eodfs"] = {"self": model_params["EODf"], "other": model_params["EODf"] + deltaf}
            print("\t Deltaf: %i" % deltaf)
            simulate_responses(stimulus_params, model_params, repeats=args.trials, deltaf=deltaf)
def main():
    """Parse the command line and simulate P-unit responses for randomly chosen model cells.

    Models are read from "models.csv". ``--number`` of them are picked in random order and
    simulated in parallel (``--jobs`` processes) via ``simulate_cell``.
    """
    parser = argparse.ArgumentParser(description="Simulate P-unit responses using the model parameters from the models.csv file. Calling it without any arguments works with the defaults, may need some time.")
    parser.add_argument("-n", "--number", type=int, default=20, help="Number of simulated neurons. Randomly chosen from model list. Defaults to 20")
    parser.add_argument("-t", "--trials", type=int, default=25, help="Number of stimulus repetitions, trials. Defaults to 25")
    parser.add_argument("-dfs", "--deltafs", type=float, nargs="+", default=[-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200],
                        help="List of difference frequencies. Defaults to [-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200]")
    parser.add_argument("-cs", "--chirpsizes", type=float, nargs="+", default=[40, 60, 100],
                        help="List of chirp sizes. Defaults to [40, 60, 100]")
    parser.add_argument("-ct", "--contrasts", type=float, nargs="+", default=[20, 10, 5, 2.5, 1.25, 0.625, 0.3125],
                        help="List of foreign fish contrasts. Defaults to [20, 10, 5, 2.5, 1.25, 0.625, 0.3125].")
    parser.add_argument("-o", "--output_folder", type=str, default=data_folder, help="Where to store the data. Defaults to %s"%os.path.join(".", data_folder))
    parser.add_argument("-j", "--jobs", type=int, default=max(1, int(np.floor(multiprocessing.cpu_count() * 0.5))), help="Number of parallel processes (simulations) defaults to half of the available cores.")
    args = parser.parse_args()
    if args.output_folder:
        os.makedirs(args.output_folder, exist_ok=True)

    models = load_models("models.csv")
    num_models = args.number
    if args.number > len(models):
        print("INFO: number of cells larger than number of available models. Reset to max number of models.")
        num_models = len(models)
    indices = list(range(len(models)))
    np.random.shuffle(indices)
    # Honor -n/--number: simulate only the first num_models of the shuffled model indices.
    selected = indices[:max(num_models, 0)]
    Parallel(n_jobs=args.jobs)(delayed(simulate_cell)(cell_id, models, args) for cell_id in selected)


if __name__ == "__main__":
    main()