diff --git a/oephys2nix/stimulus_recreation.py b/oephys2nix/stimulus_recreation.py new file mode 100644 index 0000000..956d0bf --- /dev/null +++ b/oephys2nix/stimulus_recreation.py @@ -0,0 +1,522 @@ +import logging +import pathlib +import sys +import tomllib + +import matplotlib.pyplot as plt +import nixio +import numpy as np +import rlxnix as rlx +from IPython import embed +from neo.io import OpenEphysBinaryIO +from nixio.exceptions import DuplicateName +from rich.console import Console +from rich.table import Table +from rlxnix.plugins.efish.utils import extract_am +from scipy import signal + +from oephys2nix.logging import setup_logging +from oephys2nix.metadata import create_dict_from_section, create_metadata_from_dict + +log = logging.getLogger(__name__) +setup_logging(log, level="DEBUG") +console = Console() + + +class StimulusToNix: + def __init__(self, open_ephys_path: pathlib.Path, relacs_nix_path: str, nix_file: str): + self.dataset = rlx.Dataset(relacs_nix_path) + + self.relacs_nix_file = nixio.File.open(relacs_nix_path, nixio.FileMode.ReadOnly) + self.relacs_block = self.relacs_nix_file.blocks[0] + self.relacs_sections = self.relacs_nix_file.sections + + self.neo_data = OpenEphysBinaryIO(open_ephys_path).read(lazy=True) + self.fs = self.neo_data[0].segments[0].analogsignals[0].sampling_rate.magnitude + + self.nix_file_path = nix_file + self.nix_file = nixio.File(nix_file, nixio.FileMode.ReadWrite) + self.block = self.nix_file.blocks[0] + + # Threshold for TTL peak + # constants + self.threshold = 2 + self.new_start_jiggle = 0.1 + + def _append_relacs_tag_mtags(self): + for t in self.relacs_block.tags: + log.debug(f"Appending relacs tags {t.name}") + tag = self.block.create_tag(f"relacs_{t.name}", t.type, position=t.position) + tag.extent = t.extent + tag.references.extend(self.block.groups["relacs"].data_arrays) + + sec = self.relacs_sections[f"{t.name}"] + d = create_dict_from_section(sec) + + try: + new_sec = self.nix_file.create_section(sec.name, sec.type) + create_metadata_from_dict(d, new_sec) + except DuplicateName: + pass + + tag.metadata = self.nix_file.sections[sec.name] + + for t in self.relacs_block.multi_tags: + log.debug(f"Appending relacs multi-tags {t.name}") + mtag = self.block.create_multi_tag( + f"relacs_{t.name}", + t.type, + positions=t.positions[:], + extents=t.extents[:], + ) + mtag.references.extend(self.block.groups["relacs"].data_arrays) + + sec = self.relacs_sections[f"{t.name}"] + d = create_dict_from_section(sec) + try: + new_sec = self.nix_file.create_section(sec.name, sec.type) + create_metadata_from_dict(d, new_sec) + except DuplicateName: + pass + mtag.metadata = self.nix_file.sections[sec.name] + + def _find_peak_ttl(self, time_ttl, peaks_ttl, lower, upper): + peak = time_ttl[peaks_ttl[(time_ttl[peaks_ttl] > lower) & (time_ttl[peaks_ttl] < upper)]] + + if not peak.size > 0: + log.error("No peaks found") + elif peak.size > 1: + if np.all(np.diff(peak) > 0.5): + log.error("Peaks are further aways than 0.5 seconds") + log.error(f"Peaks {peak}, Furthest aways: {np.max(peak)}") + peak = np.mean(peak) + return peak + + def _find_peak_ttl_index(self, time_ttl, peaks_ttl, current_position): + new_repro_start_index = peaks_ttl[ + (time_ttl[peaks_ttl] > current_position - self.new_start_jiggle) + & (time_ttl[peaks_ttl] < current_position + self.new_start_jiggle) + ] + + if new_repro_start_index.size > 1: + log.warning("Multiple current positions taking the last index") + new_repro_start_index = new_repro_start_index[-1] + + if np.where(new_repro_start_index == peaks_ttl)[0] + 1 == len(peaks_ttl): + return np.array([]) + else: + next_repro_start = peaks_ttl[np.where(new_repro_start_index == peaks_ttl)[0] + 1] + start_next_repro = time_ttl[next_repro_start] + log.debug(f"Start of new repro/trial {start_next_repro}") + return start_next_repro + + @property + def _reference_groups(self) -> list[nixio.Group]: + return [ + self.block.groups["neuronal-data"], + self.block.groups["efish"], + self.block.groups["relacs"], + ] + + def _append_mtag(self, repro, positions, extents): + try: + nix_mtag = self.block.create_multi_tag( + f"{repro.name}", + "relacs.stimulus", + positions=positions, + extents=extents, + ) + except DuplicateName: + del self.block.multi_tags[repro.name] + nix_mtag = self.block.create_multi_tag( + f"{repro.name}", + "relacs.stimulus", + positions=positions, + extents=extents, + ) + sec = self.relacs_sections[f"{repro.name}"] + d = create_dict_from_section(sec) + + try: + new_sec = self.nix_file.create_section(sec.name, sec.type) + create_metadata_from_dict(d, new_sec) + except DuplicateName: + del self.nix_file.sections[sec.name] + new_sec = self.nix_file.create_section(sec.name, sec.type) + create_metadata_from_dict(d, new_sec) + + nix_mtag.metadata = self.nix_file.sections[repro.name] + + [ + nix_mtag.references.extend(ref_groups.data_arrays) + for ref_groups in self._reference_groups + if ref_groups.data_arrays + ] + + try: + nix_group = self.block.create_group(repro.name, repro.type) + except DuplicateName: + nix_group = self.nix_file.blocks[0].groups[repro.name] + nix_group.multi_tags.append(nix_mtag) + + def _append_tag(self, repro, position, extent): + try: + nix_tag = self.block.create_tag( + f"{repro.name}", + "relacs.repro_run", + position=np.array(position).flatten(), + ) + nix_tag.extent = extent.flatten() + except DuplicateName: + del self.block.tags[repro.name] + nix_tag = self.block.create_tag( + f"{repro.name}", + "relacs.repro_run", + position=np.array(position).flatten(), + ) + nix_tag.extent = extent.flatten() + + sec = self.relacs_sections[f"{repro.name}"] + d = create_dict_from_section(sec) + + try: + new_sec = self.nix_file.create_section(sec.name, sec.type) + create_metadata_from_dict(d, new_sec) + except DuplicateName: + del self.nix_file.sections[sec.name] + new_sec = self.nix_file.create_section(sec.name, sec.type) + create_metadata_from_dict(d, new_sec) + + nix_tag.metadata = self.nix_file.sections[repro.name] + + # NOTE: adding refs to tag + [ + nix_tag.references.extend(ref_groups.data_arrays) + for ref_groups in self._reference_groups + if ref_groups.data_arrays + ] + try: + nix_group = self.block.create_group(repro.name, repro.type) + except DuplicateName: + nix_group = self.nix_file.blocks[0].groups[repro.name] + nix_group.tags.append(nix_tag) + + def create_repros_automatically(self): + ttl_oeph = self.block.data_arrays["ttl-line"][:] + time_ttl = np.arange(len(ttl_oeph)) / self.fs + time_index = np.arange(len(ttl_oeph)) + + peaks_ttl = time_index[ + (np.roll(ttl_oeph, 1) < self.threshold) & (ttl_oeph > self.threshold) + ] + # WARNING:Check if peaks are duplicates or near each other + close_peaks = np.where(np.diff(peaks_ttl) == 1)[0] + if close_peaks.size > 0: + peaks_ttl = np.delete(peaks_ttl, close_peaks) + + first_peak = self._find_peak_ttl( + time_ttl, + peaks_ttl, + time_ttl[peaks_ttl[0]] - self.new_start_jiggle, + time_ttl[peaks_ttl[0]] + self.new_start_jiggle, + ) + + current_position = np.asarray(first_peak.reshape(1)) + for i, repro in enumerate(self.dataset.repro_runs()): + log.debug(repro.name) + log.debug(f"Current Position {current_position.item()}") + if repro.duration < 1.0: + log.warning(f"Skipping repro {repro.name} because it is two short") + continue + + if repro.stimuli: + log.debug("Processing MultiTag") + repetition = len(repro.stimuli) + extents_mtag = np.zeros((repetition, 1)) + position_mtags = np.zeros((repetition, 1)) + + for trial, stimulus in enumerate(repro.stimuli): + duration_trial = stimulus.duration + extents_mtag[trial] = duration_trial + position_mtags[trial] = current_position + + current_position = self._find_peak_ttl_index( + time_ttl, peaks_ttl, current_position + ) + + # current_position = position_mtags[-1] + self._append_mtag(repro, position_mtags, extents_mtag) + extent = position_mtags[-1] + extents_mtag[-1] - position_mtags[0] + self._append_tag(repro, position_mtags[0], extent) + + if not current_position.size > 0: + log.debug("Checking if it is the last repro") + if i != len(self.dataset.repro_runs()) - 1: + if "Baseline" in self.dataset.repro_runs()[-1].name: + log.debug("Appending last baseline Repro") + last_position = ( + self.block.tags[repro.name].position[0] + + self.block.tags[repro.name].extent[0] + ) + lastrepro = self.dataset.repro_runs()[-1] + self._append_tag(lastrepro, last_position, lastrepro.duration) + else: + log.error("Last Repro was not appended") + break + + log.info("Finishing writing") + log.info("Closing nix files") + break + else: + if i == 0 and "BaselineActivity" in repro.name: + self._append_tag(repro, 0.0, current_position) + continue + + last_repro_name = self.dataset.repro_runs()[i - 1].name + last_repro_position = ( + self.block.groups[last_repro_name].tags[0].position[0] + + self.block.groups[last_repro_name].tags[0].extent[0] + ) + self._append_tag( + repro, + last_repro_position.reshape(-1, 1), + (current_position - last_repro_position).reshape(-1, 1), + ) + # self.close() + + def create_repros_from_config_file(self): + ttl_oeph = self.block.data_arrays["ttl-line"][:] + peaks_ttl = signal.find_peaks( + ttl_oeph.flatten(), + **self.stimulus_config["stimulus"]["peak_detection"], + )[0] + time_ttl = np.arange(len(ttl_oeph)) / self.cfg.open_ephys.samplerate + + number_of_repros = len(self.stimulus_config["repros"]["name"]) + + referencs_groups = [ + self.block.groups["neuronal-data"], + self.block.groups["spike-data"], + self.block.groups["efish"], + ] + + for repro_index in range(number_of_repros): + name = self.stimulus_config["repros"]["name"][repro_index] + log.debug(name) + start = np.array(self.stimulus_config["repros"]["start"][repro_index]) + end = np.array(self.stimulus_config["repros"]["end"][repro_index]) + nix_group = self.block.groups[name] + if start.size > 1: + start_repro = time_ttl[ + peaks_ttl[(time_ttl[peaks_ttl] > start[0]) & (time_ttl[peaks_ttl] < start[1])] + ] + if start_repro.size > 1: + if np.all(np.diff(start_repro) > 0.005): + log.error("Wrong end point in end of repro") + log.error(f"{name[repro_index]}, {np.max(start_repro)}") + exit(1) + start_repro = np.mean(start_repro) + else: + start_repro = start_repro[0] + else: + start_repro = start[0] + + if end.size > 1: + end_repro = time_ttl[ + peaks_ttl[(time_ttl[peaks_ttl] > end[0]) & (time_ttl[peaks_ttl] < end[1])] + ] + if end_repro.size > 1: + if np.all(np.diff(end_repro) > 0.005): + log.error("Wrong end point in end of repro") + log.error(f"{name[repro_index]}, {np.max(end_repro)}") + exit(1) + end_repro = np.mean(end_repro) + else: + end_repro = end[0] + + nix_tag = self.block.create_tag( + f"{nix_group.name}", + f"{nix_group.type}", + position=[start_repro], + ) + nix_tag.extent = [end_repro - start_repro] + nix_tag.metadata = self.nix_file.sections[name] + + # NOTE: adding refs to tag + [ + nix_tag.references.extend(ref_groups.data_arrays) + for ref_groups in referencs_groups + if ref_groups.data_arrays + ] + nix_group.tags.append(nix_tag) + + if not self.relacs_block.groups[name].multi_tags: + log.debug(f"no multitags in repro {name}, skipping") + continue + + start_repro_multi_tag = start_repro.copy() + positions = [] + positions.append(np.array(start_repro)) + extents = [] + extents_relacsed_nix = self.relacs_block.groups[name].multi_tags[0].extents[:] + extents.append(np.array(extents_relacsed_nix[0])) + index_multi_tag = 0 + while start_repro_multi_tag < end_repro: + log.debug(f"{start_repro_multi_tag}") + + start_repro_multi_tag = time_ttl[ + peaks_ttl[ + ( + time_ttl[peaks_ttl] + > start_repro_multi_tag + extents_relacsed_nix[index_multi_tag] + ) + & ( + time_ttl[peaks_ttl] + < start_repro_multi_tag + extents_relacsed_nix[index_multi_tag] + 2 + ) + ] + ] + + if start_repro_multi_tag.size == 0: + log.debug("Did not find any peaks for new start multi tag") + log.debug( + f"Differenz to end of repro {end_repro - (positions[-1] + extents_relacsed_nix[index_multi_tag])}" + ) + break + + if start_repro_multi_tag.size > 1: + if np.all(np.diff(start_repro_multi_tag) > 0.005): + log.error("Wrong end point in end of repro") + log.error(f"Repro_name: {name[repro_index]}") + log.error(f"multitag_index: {index_multi_tag}") + log.error(f"max_value {np.max(start_repro_multi_tag)}") + sys.exit(1) + start_repro_multi_tag = np.mean(start_repro_multi_tag) + else: + start_repro_multi_tag = start_repro_multi_tag[0] + if end_repro - start_repro_multi_tag < 1: + log.debug("Posssible endpoint detected") + log.debug(f"Differenz to repro end: {end_repro - start_repro_multi_tag}") + break + + positions.append(np.array(start_repro_multi_tag)) + extents.append(extents_relacsed_nix[index_multi_tag]) + + positions = np.array(positions).reshape(len(positions), 1) + extents = np.array(extents) + if positions.size != extents.size: + log.error("Calcualted positions and extents do not match ") + log.error(f"Shape positions {positions.shape}, shape extents {extents.shape}") + if positions.shape != extents.shape: + log.error("Shape of Calculated positions and extents do not match") + log.error(f"Shape positions {positions.shape}, shape extents {extents.shape}") + embed() + sys.exit(1) + + nix_mtag = self.block.create_multi_tag( + f"{nix_group.name}", + f"{nix_group.type}", + positions=positions, + extents=extents, + ) + nix_mtag.metadata = self.nix_file.sections[name] + + # NOTE: adding refs to mtag + [ + nix_mtag.references.extend(ref_groups.data_arrays) + for ref_groups in referencs_groups + if ref_groups.data_arrays + ] + nix_group.multi_tags.append(nix_mtag) + + def print_table(self): + nix_data_set = rlx.Dataset(self.nix_file_path) + table = Table("Repro Name", "start", "stop", "duration") + for repro_r, repro_n in zip(self.dataset.repro_runs(), nix_data_set.repro_runs()): + table.add_row( + f"{repro_r.name} -> [red]{repro_n.name}[/red]", + f"{repro_r.start_time:.2f} -> [red]{repro_n.start_time:.2f}[/red]", + f"{repro_r.stop_time:.2f} -> [red]{repro_n.stop_time:.2f}[/red]", + f"{repro_r.duration:.2f} -> [red]{repro_n.duration:.2f}[/red]", + ) + + console.print(table) + nix_data_set.close() + + def checks(self): + important_repros = ["FileStimulus", "SAM", "FICurve"] + nix_data_set = rlx.Dataset(self.nix_file_path) + + for repro_r, repro_n in zip(self.dataset.repro_runs(), nix_data_set.repro_runs()): + if repro_n.name.split("_")[0] not in important_repros: + continue + if "SAM" in repro_n.name: + embed() + exit() + else: + continue + + if "FileStimulus" in repro_n.name: + repro_n.stimulus_folder = "/home/alexander/stimuli/whitenoise/" + repro_r.stimulus_folder = "/home/alexander/stimuli/whitenoise/" + white_noise, stimulus_time = repro_n.load_stimulus(0, 30_000) + white_noise_r, stimulus_time_r = repro_n.load_stimulus(0, 20_000) + + stim = repro_n.stimuli[0] + stim_r = repro_r.stimuli[0] + v, t = stim.trace_data("local-eod") + v_r, t_r = stim_r.trace_data("LocalEOD-1") + plt.plot(t_r, v_r, "o--", color="tab:blue", label="relacs") + plt.plot(t, v, "o--", color="tab:red", label="open-ephys") + # plt.plot(stimulus_time, white_noise, "o--", color="tab:orange", label="stimulus") + # plt.plot(new_t, new_v, "go--", label="resampled") + for i, (s, r) in enumerate(zip(repro_n.stimuli, repro_r.stimuli)): + v, t = s.trace_data("local-eod") + v_r, t_r = r.trace_data("LocalEOD-1") + plt.plot(t_r + (i * r.duration), v_r) + plt.plot(t + (i * s.duration), v) + plt.legend(loc="upper right") + plt.show() + + am = extract_am(v, 30_000) + am_r = extract_am(v_r, 20_000) + + plt.plot(t, am - np.mean(am) / np.max(np.abs(am))) + plt.plot(t_r, am_r - np.mean(am_r) / np.max(np.abs(am_r))) + v_eod, t_eod = repro_n.trace_data("global-eod") + v_eodr, t_eodr = repro_n.trace_data("EOD") + + def plot_stimulus(self): + ttl_oeph = self.block.data_arrays["ttl-line"][:] + + time_index = np.arange(len(ttl_oeph)) + peaks_ttl = time_index[(np.roll(ttl_oeph, 1) < 2) & (ttl_oeph > 2)] + + stimulus_oeph = self.block.data_arrays["stimulus"] + stimulus = self.relacs_block.data_arrays["GlobalEFieldStimulus"] + + plt.plot(np.arange(stimulus.size) / 20_000.0, stimulus, label="relacs-stimulus") + plt.plot( + np.arange(len(stimulus_oeph)) / 30_000.0, + stimulus_oeph[:], + label="open-ephys", + ) + plt.plot( + np.arange(len(ttl_oeph)) / 30_000.0, + ttl_oeph[:], + label="ttl-line", + ) + plt.scatter( + np.arange(len(ttl_oeph))[peaks_ttl] / 30_000.0, + ttl_oeph[peaks_ttl], + label="detected peaks", + color="red", + zorder=100, + ) + plt.legend(loc="upper right") + plt.show() + + def close(self): + self.dataset.close() + self.nix_file.close() + self.relacs_nix_file.close()