import numpy as np
import nixio as nix
import os
# NOTE(review): `repeat`, `Instance`, `embed` and `plt` are not used in this
# module, and `numpy.core` is deprecated in NumPy 2 -- consider removing them.
from numpy.core.fromnumeric import repeat
from traitlets.traitlets import Instance
from chirp_ams import get_signals
from model import simulate, load_models
from IPython import embed
import matplotlib.pyplot as plt
import multiprocessing
from joblib import Parallel, delayed

# Directory that receives one "cell_<id>.nix" file per simulated cell.
data_folder = "data"


def append_settings(section, sec_name, sec_type, settings):
    """Store a (possibly nested) settings dict as a new nix metadata section.

    A sub-section ``sec_name`` of type ``sec_type`` is created below
    ``section``. Nested dicts become nested sub-sections of type
    ``"settings"``, 1-D numpy arrays are stored as lists, and every other
    value is assigned as a property directly.

    Parameters
    ----------
    section : nix Section
        Parent in which the new section is created.
    sec_name : str
        Name of the new section.
    sec_type : str
        nix type of the new section.
    settings : dict
        Key/value pairs to store.
    """
    section = section.create_section(sec_name, sec_type)
    for key, value in settings.items():
        if isinstance(value, dict):
            append_settings(section, key, "settings", value)
        elif isinstance(value, np.ndarray):
            if value.ndim == 1:
                section[key] = list(value)
            # NOTE(review): arrays that are not 1-D are skipped and never
            # stored (presumably because properties hold flat value lists).
        else:
            section[key] = value


def _open_nix_file(filename, overwrite=False):
    """Open ``filename`` for appending, or create/truncate it if it does not
    exist yet or ``overwrite`` is True."""
    if os.path.exists(filename) and not overwrite:
        return nix.File.open(filename, nix.FileMode.ReadWrite)
    return nix.File.open(filename, mode=nix.FileMode.Overwrite,
                         compression=nix.Compression.DeflateNormal)


def _add_sampled_array(block, name, array_type, data, label, unit, interval):
    """Create a float data array in ``block`` with a sampled time dimension
    of step ``interval`` seconds and return it."""
    da = block.create_data_array(name, array_type, dtype=nix.DataType.Float, data=data)
    da.label = label
    da.unit = unit
    dim = da.append_sampled_dimension(interval)
    dim.label = "time"
    dim.unit = "s"
    return da


def save(filename, name, stimulus_settings, model_settings, self_signal, other_signal,
         self_freq, other_freq, complete_stimulus, responses, overwrite=False):
    """Store one stimulus condition and its spike responses as a nix block.

    If a block called ``name`` already exists in the file, a message is
    printed and nothing is written.

    Parameters
    ----------
    filename : str
        Target nix file; opened for appending unless ``overwrite`` is True.
    name : str
        Name of the block and of its metadata section.
    stimulus_settings, model_settings : dict
        Stored as metadata. ``model_settings["deltat"]`` is used as the
        sampling interval (s) of all sampled arrays.
    self_signal, other_signal : array-like
        Currently not stored; kept for interface compatibility.
    self_freq, other_freq : array-like or None
        Optional frequency traces, stored only when not None.
    complete_stimulus : array-like
        The full stimulus trace.
    responses : sequence of array-like
        Spike times (s), one array per repetition.
    overwrite : bool
        Truncate an existing file instead of appending to it.
    """
    nf = _open_nix_file(filename, overwrite)
    try:
        if name in nf.blocks:
            print("Data with this name is already stored! ", name)
            return
        mdata = nf.create_section(name, "nix.simulation")
        append_settings(mdata, "model parameter", "nix.model.settings", model_settings)
        append_settings(mdata, "stimulus parameter", "nix.stimulus.settings", stimulus_settings)

        b = nf.create_block(name, "nix.simulation")
        b.metadata = mdata
        dt = model_settings["deltat"]

        # save stimulus and, if given, the frequency traces
        stim_da = _add_sampled_array(b, "complete stimulus", "nix.timeseries.sampled.stimulus",
                                     complete_stimulus, "voltage", "mV/cm", dt)
        freq_arrays = []
        if self_freq is not None:
            freq_arrays.append(_add_sampled_array(b, "self frequency",
                                                  "nix.timeseries.sampled.frequency",
                                                  self_freq, "frequency", "Hz", dt))
        if other_freq is not None:
            freq_arrays.append(_add_sampled_array(b, "other frequency",
                                                  "nix.timeseries.sampled.frequency",
                                                  other_freq, "frequency", "Hz", dt))

        # save responses
        for i, spike_times in enumerate(responses):
            da = b.create_data_array("response_%i" % i, "nix.timeseries.events.spike_times",
                                     dtype=nix.DataType.Float, data=spike_times)
            da.label = "time"
            da.unit = "s"
            da.append_alias_range_dimension()

        # bind it all together: the tag spans the whole stimulus, references
        # the responses and links stimulus/frequencies as untagged features
        tag = b.create_tag("chirp stimulation", "nix.stimulus", [0.0])
        tag.extent = [stim_da.shape[0] * stim_da.dimensions[0].sampling_interval]
        for da in b.data_arrays:
            if "response" in da.name.lower():
                tag.references.append(da)
        tag.create_feature(stim_da, nix.LinkType.Untagged)
        for freq_da in freq_arrays:
            tag.create_feature(freq_da, nix.LinkType.Untagged)
    finally:
        nf.close()


def save_baseline_response(filename, block_name, spikes, model_settings, overwrite=False):
    """Store baseline spike times and the model settings as a nix block.

    If a block called ``block_name`` already exists in the file, a message
    is printed and nothing is written.
    """
    nf = _open_nix_file(filename, overwrite)
    try:
        if block_name in nf.blocks:
            print("Data with this name is already stored! ", block_name)
            return
        mdata = nf.create_section(block_name, "nix.simulation")
        append_settings(mdata, "model parameter", "nix.model.settings", model_settings)
        b = nf.create_block(block_name, "nix.simulation")
        b.metadata = mdata
        da = b.create_data_array("baseline_response", "nix.timeseries.events.spike_times",
                                 dtype=nix.DataType.Float, data=spikes)
        da.label = "time"
        da.unit = "s"
        da.append_alias_range_dimension()
    finally:
        nf.close()


def _simulate_after_pre_stimulus(signal, pre_time, pre_stim, cell_params):
    """Simulate ``pre_stim`` followed by ``signal`` and return the spike
    times that fall after the pre-stimulus, shifted by ``pre_time[-1]``.

    A random ``v_zero`` (presumably the model's initial value) is drawn and
    written into ``cell_params`` in place before simulating.
    """
    cell_params["v_zero"] = np.random.rand(1)[0]
    spikes = simulate(np.hstack((pre_stim, signal)), **cell_params)
    # NOTE(review): times are measured from the last pre-stimulus sample,
    # one step before `signal` starts -- confirm against simulate().
    offset = pre_time[-1]
    return spikes[spikes > offset] - offset


def get_baseline_response(model_params, duration=10):
    """Simulate the response to an unperturbed sine at the cell's EODf.

    The baseline of ``duration`` seconds is preceded by the pre-stimulus
    from ``get_pre_stimulus``, sampled with the model's ``deltat`` so the
    spike-time offset matches the simulated trace. Returns spike times
    relative to the end of the pre-stimulus.
    """
    eodf = model_params["EODf"]
    dt = model_params["deltat"]
    cell_params = model_params.copy()
    del cell_params["cell"]
    del cell_params["EODf"]
    time, pre_stim = get_pre_stimulus(eodf, dt=dt)
    baseline_time = np.arange(0.0, duration, dt)
    eod = np.sin(baseline_time * eodf * 2 * np.pi)
    return _simulate_after_pre_stimulus(eod, time, pre_stim, cell_params)


def get_pre_stimulus(eodf, duration=2, df=20, contrast=20, dt=1./20000):
    """Return ``(time, stim)``: a sine at ``eodf`` plus a Hanning-windowed
    sine at ``eodf + df`` whose amplitude is ``contrast`` percent of it."""
    time = np.arange(0.0, duration, dt)
    eod = np.sin(time * eodf * 2 * np.pi)
    eod_2 = np.sin(time * (eodf + df) * 2 * np.pi) * contrast/100
    eod_2 *= np.hanning(len(eod_2))
    stim = eod + eod_2
    return time, stim


def simulate_responses(stimulus_params, model_params, repeats=10, deltaf=20):
    """Simulate and store responses for every contrast and condition.

    For each contrast, the "other" and "self" conditions are simulated
    ``repeats`` times and saved to ``data_folder/cell_<cell>.nix``. The
    "self" condition additionally stores responses to the self signal alone
    under a "no-other" block name.
    """
    cell_params = model_params.copy()
    cell = cell_params["cell"]
    filename = os.path.join(data_folder, "cell_%s.nix" % cell)
    del cell_params["cell"]
    del cell_params["EODf"]
    conditions = ["other", "self"]
    pre_time, pre_stim = get_pre_stimulus(stimulus_params["eodfs"]["self"],
                                          dt=model_params["deltat"])

    for contrast in stimulus_params["contrasts"]:
        params = stimulus_params.copy()
        params["contrast"] = contrast
        del params["contrasts"]
        del params["chirp_frequency"]
        for condition in conditions:
            print("\tcontrast: %s, condition: %s" % (contrast, condition), " " * 10, end="\r")
            block_name = "contrast_%.3f_condition_%s_deltaf_%i" % (contrast, condition, deltaf)
            params["condition"] = condition
            _, self_signal, self_freq, other_signal, other_freq = get_signals(**params)
            full_signal = (self_signal + other_signal)
            spikes = []
            no_other_spikes = []
            for _ in range(repeats):
                spikes.append(_simulate_after_pre_stimulus(full_signal, pre_time, pre_stim,
                                                           cell_params))
                if condition == "self":
                    no_other_spikes.append(_simulate_after_pre_stimulus(self_signal, pre_time,
                                                                        pre_stim, cell_params))
            if condition == "self":
                name = "contrast_%.3f_condition_no-other_deltaf_%i" % (contrast, deltaf)
                save(filename, name, params, cell_params, self_signal, None, self_freq, None,
                     self_signal, no_other_spikes)
            save(filename, block_name, params, cell_params, self_signal, other_signal,
                 self_freq, other_freq, full_signal, spikes)
    print("\n")


def simulate_cell(cell_id, models):
    """Simulate baseline and chirp responses of ``models[cell_id]`` for all
    difference frequencies and write them to ``data_folder``."""
    deltafs = [-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200]  # Hz, difference frequency between self and other
    stimulus_params = {
        "eodfs": {"self": 0.0, "other": 0.0},  # eod frequency in Hz, to be overwritten
        "contrasts": [20, 10, 5, 2.5, 1.25, 0.625, 0.3125],
        "chirp_size": 100,  # Hz, frequency excursion
        "chirp_duration": 0.015,  # s, chirp duration
        "chirp_amplitude_dip": 0.05,  # %, amplitude drop during chirp
        "chirp_frequency": 5,  # Hz, how often does the fish chirp
        "duration": 5.,  # s, total duration of simulation
        "dt": 1,  # s, stepsize of the simulation, to be overwritten
    }
    model_params = models[cell_id]
    baseline_spikes = get_baseline_response(model_params, duration=30)
    filename = os.path.join(data_folder, "cell_%s.nix" % model_params["cell"])
    save_baseline_response(filename, "baseline response", baseline_spikes, model_params)
    print("Cell: %s" % model_params["cell"])

    # these do not depend on deltaf, so compute them once
    stimulus_params["dt"] = model_params["deltat"]
    stimulus_params["chirp_times"] = np.arange(stimulus_params["chirp_duration"],
                                               stimulus_params["duration"] - stimulus_params["chirp_duration"],
                                               1./stimulus_params["chirp_frequency"])
    for deltaf in deltafs:
        stimulus_params["eodfs"] = {"self": model_params["EODf"],
                                    "other": model_params["EODf"] + deltaf}
        print("\t Deltaf: %i" % deltaf)
        simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf)


def main(max_cells=10, reserved_cores=6):
    """Simulate the first ``max_cells`` models (all if None) in parallel.

    ``reserved_cores`` CPUs are left free; at least one worker is always
    used (joblib rejects ``n_jobs=0`` and reads negative values as "all but
    N CPUs").
    """
    models = load_models("models.csv")
    os.makedirs(data_folder, exist_ok=True)
    num_cores = max(1, multiprocessing.cpu_count() - reserved_cores)
    num_cells = len(models[:max_cells])
    Parallel(n_jobs=num_cores)(delayed(simulate_cell)(cell_id, models)
                               for cell_id in range(num_cells))


if __name__ == "__main__":
    main()