[stimulus_recreation] adding script for recreating stimulus

This commit is contained in:
wendtalexander 2025-10-07 16:07:18 +02:00
parent 1af6fa4f7b
commit e61172227e

View File

@ -0,0 +1,522 @@
import logging
import pathlib
import sys
import tomllib
import matplotlib.pyplot as plt
import nixio
import numpy as np
import rlxnix as rlx
from IPython import embed
from neo.io import OpenEphysBinaryIO
from nixio.exceptions import DuplicateName
from rich.console import Console
from rich.table import Table
from rlxnix.plugins.efish.utils import extract_am
from scipy import signal
from oephys2nix.logging import setup_logging
from oephys2nix.metadata import create_dict_from_section, create_metadata_from_dict
log = logging.getLogger(__name__)
setup_logging(log, level="DEBUG")
console = Console()
class StimulusToNix:
def __init__(self, open_ephys_path: pathlib.Path, relacs_nix_path: str, nix_file: str):
self.dataset = rlx.Dataset(relacs_nix_path)
self.relacs_nix_file = nixio.File.open(relacs_nix_path, nixio.FileMode.ReadOnly)
self.relacs_block = self.relacs_nix_file.blocks[0]
self.relacs_sections = self.relacs_nix_file.sections
self.neo_data = OpenEphysBinaryIO(open_ephys_path).read(lazy=True)
self.fs = self.neo_data[0].segments[0].analogsignals[0].sampling_rate.magnitude
self.nix_file_path = nix_file
self.nix_file = nixio.File(nix_file, nixio.FileMode.ReadWrite)
self.block = self.nix_file.blocks[0]
# Threshold for TTL peak
# constants
self.threshold = 2
self.new_start_jiggle = 0.1
def _append_relacs_tag_mtags(self):
for t in self.relacs_block.tags:
log.debug(f"Appending relacs tags {t.name}")
tag = self.block.create_tag(f"relacs_{t.name}", t.type, position=t.position)
tag.extent = t.extent
tag.references.extend(self.block.groups["relacs"].data_arrays)
sec = self.relacs_sections[f"{t.name}"]
d = create_dict_from_section(sec)
try:
new_sec = self.nix_file.create_section(sec.name, sec.type)
create_metadata_from_dict(d, new_sec)
except DuplicateName:
pass
tag.metadata = self.nix_file.sections[sec.name]
for t in self.relacs_block.multi_tags:
log.debug(f"Appending relacs multi-tags {t.name}")
mtag = self.block.create_multi_tag(
f"relacs_{t.name}",
t.type,
positions=t.positions[:],
extents=t.extents[:],
)
mtag.references.extend(self.block.groups["relacs"].data_arrays)
sec = self.relacs_sections[f"{t.name}"]
d = create_dict_from_section(sec)
try:
new_sec = self.nix_file.create_section(sec.name, sec.type)
create_metadata_from_dict(d, new_sec)
except DuplicateName:
pass
mtag.metadata = self.nix_file.sections[sec.name]
def _find_peak_ttl(self, time_ttl, peaks_ttl, lower, upper):
peak = time_ttl[peaks_ttl[(time_ttl[peaks_ttl] > lower) & (time_ttl[peaks_ttl] < upper)]]
if not peak.size > 0:
log.error("No peaks found")
elif peak.size > 1:
if np.all(np.diff(peak) > 0.5):
log.error("Peaks are further aways than 0.5 seconds")
log.error(f"Peaks {peak}, Furthest aways: {np.max(peak)}")
peak = np.mean(peak)
return peak
def _find_peak_ttl_index(self, time_ttl, peaks_ttl, current_position):
new_repro_start_index = peaks_ttl[
(time_ttl[peaks_ttl] > current_position - self.new_start_jiggle)
& (time_ttl[peaks_ttl] < current_position + self.new_start_jiggle)
]
if new_repro_start_index.size > 1:
log.warning("Multiple current positions taking the last index")
new_repro_start_index = new_repro_start_index[-1]
if np.where(new_repro_start_index == peaks_ttl)[0] + 1 == len(peaks_ttl):
return np.array([])
else:
next_repro_start = peaks_ttl[np.where(new_repro_start_index == peaks_ttl)[0] + 1]
start_next_repro = time_ttl[next_repro_start]
log.debug(f"Start of new repro/trial {start_next_repro}")
return start_next_repro
@property
def _reference_groups(self) -> list[nixio.Group]:
return [
self.block.groups["neuronal-data"],
self.block.groups["efish"],
self.block.groups["relacs"],
]
def _append_mtag(self, repro, positions, extents):
try:
nix_mtag = self.block.create_multi_tag(
f"{repro.name}",
"relacs.stimulus",
positions=positions,
extents=extents,
)
except DuplicateName:
del self.block.multi_tags[repro.name]
nix_mtag = self.block.create_multi_tag(
f"{repro.name}",
"relacs.stimulus",
positions=positions,
extents=extents,
)
sec = self.relacs_sections[f"{repro.name}"]
d = create_dict_from_section(sec)
try:
new_sec = self.nix_file.create_section(sec.name, sec.type)
create_metadata_from_dict(d, new_sec)
except DuplicateName:
del self.nix_file.sections[sec.name]
new_sec = self.nix_file.create_section(sec.name, sec.type)
create_metadata_from_dict(d, new_sec)
nix_mtag.metadata = self.nix_file.sections[repro.name]
[
nix_mtag.references.extend(ref_groups.data_arrays)
for ref_groups in self._reference_groups
if ref_groups.data_arrays
]
try:
nix_group = self.block.create_group(repro.name, repro.type)
except DuplicateName:
nix_group = self.nix_file.blocks[0].groups[repro.name]
nix_group.multi_tags.append(nix_mtag)
def _append_tag(self, repro, position, extent):
try:
nix_tag = self.block.create_tag(
f"{repro.name}",
"relacs.repro_run",
position=np.array(position).flatten(),
)
nix_tag.extent = extent.flatten()
except DuplicateName:
del self.block.tags[repro.name]
nix_tag = self.block.create_tag(
f"{repro.name}",
"relacs.repro_run",
position=np.array(position).flatten(),
)
nix_tag.extent = extent.flatten()
sec = self.relacs_sections[f"{repro.name}"]
d = create_dict_from_section(sec)
try:
new_sec = self.nix_file.create_section(sec.name, sec.type)
create_metadata_from_dict(d, new_sec)
except DuplicateName:
del self.nix_file.sections[sec.name]
new_sec = self.nix_file.create_section(sec.name, sec.type)
create_metadata_from_dict(d, new_sec)
nix_tag.metadata = self.nix_file.sections[repro.name]
# NOTE: adding refs to tag
[
nix_tag.references.extend(ref_groups.data_arrays)
for ref_groups in self._reference_groups
if ref_groups.data_arrays
]
try:
nix_group = self.block.create_group(repro.name, repro.type)
except DuplicateName:
nix_group = self.nix_file.blocks[0].groups[repro.name]
nix_group.tags.append(nix_tag)
def create_repros_automatically(self):
ttl_oeph = self.block.data_arrays["ttl-line"][:]
time_ttl = np.arange(len(ttl_oeph)) / self.fs
time_index = np.arange(len(ttl_oeph))
peaks_ttl = time_index[
(np.roll(ttl_oeph, 1) < self.threshold) & (ttl_oeph > self.threshold)
]
# WARNING:Check if peaks are duplicates or near each other
close_peaks = np.where(np.diff(peaks_ttl) == 1)[0]
if close_peaks.size > 0:
peaks_ttl = np.delete(peaks_ttl, close_peaks)
first_peak = self._find_peak_ttl(
time_ttl,
peaks_ttl,
time_ttl[peaks_ttl[0]] - self.new_start_jiggle,
time_ttl[peaks_ttl[0]] + self.new_start_jiggle,
)
current_position = np.asarray(first_peak.reshape(1))
for i, repro in enumerate(self.dataset.repro_runs()):
log.debug(repro.name)
log.debug(f"Current Position {current_position.item()}")
if repro.duration < 1.0:
log.warning(f"Skipping repro {repro.name} because it is two short")
continue
if repro.stimuli:
log.debug("Processing MultiTag")
repetition = len(repro.stimuli)
extents_mtag = np.zeros((repetition, 1))
position_mtags = np.zeros((repetition, 1))
for trial, stimulus in enumerate(repro.stimuli):
duration_trial = stimulus.duration
extents_mtag[trial] = duration_trial
position_mtags[trial] = current_position
current_position = self._find_peak_ttl_index(
time_ttl, peaks_ttl, current_position
)
# current_position = position_mtags[-1]
self._append_mtag(repro, position_mtags, extents_mtag)
extent = position_mtags[-1] + extents_mtag[-1] - position_mtags[0]
self._append_tag(repro, position_mtags[0], extent)
if not current_position.size > 0:
log.debug("Checking if it is the last repro")
if i != len(self.dataset.repro_runs()) - 1:
if "Baseline" in self.dataset.repro_runs()[-1].name:
log.debug("Appending last baseline Repro")
last_position = (
self.block.tags[repro.name].position[0]
+ self.block.tags[repro.name].extent[0]
)
lastrepro = self.dataset.repro_runs()[-1]
self._append_tag(lastrepro, last_position, lastrepro.duration)
else:
log.error("Last Repro was not appended")
break
log.info("Finishing writing")
log.info("Closing nix files")
break
else:
if i == 0 and "BaselineActivity" in repro.name:
self._append_tag(repro, 0.0, current_position)
continue
last_repro_name = self.dataset.repro_runs()[i - 1].name
last_repro_position = (
self.block.groups[last_repro_name].tags[0].position[0]
+ self.block.groups[last_repro_name].tags[0].extent[0]
)
self._append_tag(
repro,
last_repro_position.reshape(-1, 1),
(current_position - last_repro_position).reshape(-1, 1),
)
# self.close()
def create_repros_from_config_file(self):
ttl_oeph = self.block.data_arrays["ttl-line"][:]
peaks_ttl = signal.find_peaks(
ttl_oeph.flatten(),
**self.stimulus_config["stimulus"]["peak_detection"],
)[0]
time_ttl = np.arange(len(ttl_oeph)) / self.cfg.open_ephys.samplerate
number_of_repros = len(self.stimulus_config["repros"]["name"])
referencs_groups = [
self.block.groups["neuronal-data"],
self.block.groups["spike-data"],
self.block.groups["efish"],
]
for repro_index in range(number_of_repros):
name = self.stimulus_config["repros"]["name"][repro_index]
log.debug(name)
start = np.array(self.stimulus_config["repros"]["start"][repro_index])
end = np.array(self.stimulus_config["repros"]["end"][repro_index])
nix_group = self.block.groups[name]
if start.size > 1:
start_repro = time_ttl[
peaks_ttl[(time_ttl[peaks_ttl] > start[0]) & (time_ttl[peaks_ttl] < start[1])]
]
if start_repro.size > 1:
if np.all(np.diff(start_repro) > 0.005):
log.error("Wrong end point in end of repro")
log.error(f"{name[repro_index]}, {np.max(start_repro)}")
exit(1)
start_repro = np.mean(start_repro)
else:
start_repro = start_repro[0]
else:
start_repro = start[0]
if end.size > 1:
end_repro = time_ttl[
peaks_ttl[(time_ttl[peaks_ttl] > end[0]) & (time_ttl[peaks_ttl] < end[1])]
]
if end_repro.size > 1:
if np.all(np.diff(end_repro) > 0.005):
log.error("Wrong end point in end of repro")
log.error(f"{name[repro_index]}, {np.max(end_repro)}")
exit(1)
end_repro = np.mean(end_repro)
else:
end_repro = end[0]
nix_tag = self.block.create_tag(
f"{nix_group.name}",
f"{nix_group.type}",
position=[start_repro],
)
nix_tag.extent = [end_repro - start_repro]
nix_tag.metadata = self.nix_file.sections[name]
# NOTE: adding refs to tag
[
nix_tag.references.extend(ref_groups.data_arrays)
for ref_groups in referencs_groups
if ref_groups.data_arrays
]
nix_group.tags.append(nix_tag)
if not self.relacs_block.groups[name].multi_tags:
log.debug(f"no multitags in repro {name}, skipping")
continue
start_repro_multi_tag = start_repro.copy()
positions = []
positions.append(np.array(start_repro))
extents = []
extents_relacsed_nix = self.relacs_block.groups[name].multi_tags[0].extents[:]
extents.append(np.array(extents_relacsed_nix[0]))
index_multi_tag = 0
while start_repro_multi_tag < end_repro:
log.debug(f"{start_repro_multi_tag}")
start_repro_multi_tag = time_ttl[
peaks_ttl[
(
time_ttl[peaks_ttl]
> start_repro_multi_tag + extents_relacsed_nix[index_multi_tag]
)
& (
time_ttl[peaks_ttl]
< start_repro_multi_tag + extents_relacsed_nix[index_multi_tag] + 2
)
]
]
if start_repro_multi_tag.size == 0:
log.debug("Did not find any peaks for new start multi tag")
log.debug(
f"Differenz to end of repro {end_repro - (positions[-1] + extents_relacsed_nix[index_multi_tag])}"
)
break
if start_repro_multi_tag.size > 1:
if np.all(np.diff(start_repro_multi_tag) > 0.005):
log.error("Wrong end point in end of repro")
log.error(f"Repro_name: {name[repro_index]}")
log.error(f"multitag_index: {index_multi_tag}")
log.error(f"max_value {np.max(start_repro_multi_tag)}")
sys.exit(1)
start_repro_multi_tag = np.mean(start_repro_multi_tag)
else:
start_repro_multi_tag = start_repro_multi_tag[0]
if end_repro - start_repro_multi_tag < 1:
log.debug("Posssible endpoint detected")
log.debug(f"Differenz to repro end: {end_repro - start_repro_multi_tag}")
break
positions.append(np.array(start_repro_multi_tag))
extents.append(extents_relacsed_nix[index_multi_tag])
positions = np.array(positions).reshape(len(positions), 1)
extents = np.array(extents)
if positions.size != extents.size:
log.error("Calcualted positions and extents do not match ")
log.error(f"Shape positions {positions.shape}, shape extents {extents.shape}")
if positions.shape != extents.shape:
log.error("Shape of Calculated positions and extents do not match")
log.error(f"Shape positions {positions.shape}, shape extents {extents.shape}")
embed()
sys.exit(1)
nix_mtag = self.block.create_multi_tag(
f"{nix_group.name}",
f"{nix_group.type}",
positions=positions,
extents=extents,
)
nix_mtag.metadata = self.nix_file.sections[name]
# NOTE: adding refs to mtag
[
nix_mtag.references.extend(ref_groups.data_arrays)
for ref_groups in referencs_groups
if ref_groups.data_arrays
]
nix_group.multi_tags.append(nix_mtag)
def print_table(self):
nix_data_set = rlx.Dataset(self.nix_file_path)
table = Table("Repro Name", "start", "stop", "duration")
for repro_r, repro_n in zip(self.dataset.repro_runs(), nix_data_set.repro_runs()):
table.add_row(
f"{repro_r.name} -> [red]{repro_n.name}[/red]",
f"{repro_r.start_time:.2f} -> [red]{repro_n.start_time:.2f}[/red]",
f"{repro_r.stop_time:.2f} -> [red]{repro_n.stop_time:.2f}[/red]",
f"{repro_r.duration:.2f} -> [red]{repro_n.duration:.2f}[/red]",
)
console.print(table)
nix_data_set.close()
def checks(self):
important_repros = ["FileStimulus", "SAM", "FICurve"]
nix_data_set = rlx.Dataset(self.nix_file_path)
for repro_r, repro_n in zip(self.dataset.repro_runs(), nix_data_set.repro_runs()):
if repro_n.name.split("_")[0] not in important_repros:
continue
if "SAM" in repro_n.name:
embed()
exit()
else:
continue
if "FileStimulus" in repro_n.name:
repro_n.stimulus_folder = "/home/alexander/stimuli/whitenoise/"
repro_r.stimulus_folder = "/home/alexander/stimuli/whitenoise/"
white_noise, stimulus_time = repro_n.load_stimulus(0, 30_000)
white_noise_r, stimulus_time_r = repro_n.load_stimulus(0, 20_000)
stim = repro_n.stimuli[0]
stim_r = repro_r.stimuli[0]
v, t = stim.trace_data("local-eod")
v_r, t_r = stim_r.trace_data("LocalEOD-1")
plt.plot(t_r, v_r, "o--", color="tab:blue", label="relacs")
plt.plot(t, v, "o--", color="tab:red", label="open-ephys")
# plt.plot(stimulus_time, white_noise, "o--", color="tab:orange", label="stimulus")
# plt.plot(new_t, new_v, "go--", label="resampled")
for i, (s, r) in enumerate(zip(repro_n.stimuli, repro_r.stimuli)):
v, t = s.trace_data("local-eod")
v_r, t_r = r.trace_data("LocalEOD-1")
plt.plot(t_r + (i * r.duration), v_r)
plt.plot(t + (i * s.duration), v)
plt.legend(loc="upper right")
plt.show()
am = extract_am(v, 30_000)
am_r = extract_am(v_r, 20_000)
plt.plot(t, am - np.mean(am) / np.max(np.abs(am)))
plt.plot(t_r, am_r - np.mean(am_r) / np.max(np.abs(am_r)))
v_eod, t_eod = repro_n.trace_data("global-eod")
v_eodr, t_eodr = repro_n.trace_data("EOD")
def plot_stimulus(self):
ttl_oeph = self.block.data_arrays["ttl-line"][:]
time_index = np.arange(len(ttl_oeph))
peaks_ttl = time_index[(np.roll(ttl_oeph, 1) < 2) & (ttl_oeph > 2)]
stimulus_oeph = self.block.data_arrays["stimulus"]
stimulus = self.relacs_block.data_arrays["GlobalEFieldStimulus"]
plt.plot(np.arange(stimulus.size) / 20_000.0, stimulus, label="relacs-stimulus")
plt.plot(
np.arange(len(stimulus_oeph)) / 30_000.0,
stimulus_oeph[:],
label="open-ephys",
)
plt.plot(
np.arange(len(ttl_oeph)) / 30_000.0,
ttl_oeph[:],
label="ttl-line",
)
plt.scatter(
np.arange(len(ttl_oeph))[peaks_ttl] / 30_000.0,
ttl_oeph[peaks_ttl],
label="detected peaks",
color="red",
zorder=100,
)
plt.legend(loc="upper right")
plt.show()
def close(self):
self.dataset.close()
self.nix_file.close()
self.relacs_nix_file.close()