# Source: oephys2nix/oephys2nix/stimulus_recreation.py
# Snapshot: 2025-10-21 15:06:53 +02:00 (674 lines, 25 KiB, Python)

import logging
import pathlib
import sys
import matplotlib.pyplot as plt
import nixio
import numpy as np
import rlxnix as rlx
from IPython import embed
from neo.io import OpenEphysBinaryIO
from nixio.exceptions import DuplicateName
from rich.console import Console
from rich.table import Table
from rlxnix.plugins.efish.utils import extract_am
from scipy import signal
from oephys2nix.metadata import create_dict_from_section, create_metadata_from_dict
# Module logger, named after this module so its records can be filtered per module.
log = logging.getLogger(name=__name__)
# One shared rich console, used to print rich tables.
console = Console()
class StimulusToNix:
    """Recreate the relacs repros and stimuli inside the open-ephys nix file.

    The relacs recording provides the order, durations and metadata of all
    repro runs and their stimuli. The ``ttl-line`` data array of the new nix
    file is used to find the open-ephys time of every repro/stimulus, and
    matching tags and multi-tags (plus copies of the relacs metadata sections)
    are written into the new nix file.

    Parameters
    ----------
    open_ephys_path : pathlib.Path
        Path to open-ephys recording
    relacs_nix_path : str
        Path to relacs nix file (opened read-only)
    nix_file : str
        Path to new nix file. It is opened read-write and its first block is
        used, so it must already contain at least one block.

    Attributes
    ----------
    relacs_nix_path : str
        Path to relacs nix file
    nix_file_path : str
        Path to new nix file
    dataset : rlx.Dataset
        Dataset from the relacs file
    relacs_nix_file : nixio.File
        Relacs nix file
    relacs_block : nixio.Block
        First block of the relacs nix file
    relacs_sections : nixio section container
        Top-level sections of the relacs nix file
    neo_data : list of neo.Block
        Lazily read open-ephys data
    fs : float
        Sample rate of the open-ephys recording
    nix_file : nixio.File
        New nix file
    block : nixio.Block
        First block of the new nix file
    threshold : float
        Threshold for detecting TTL pulse onsets on the TTL line
    new_start_jiggle : float
        Half-width (s) of the window in which a TTL pulse is searched around
        an expected time
    ficurve_offset : float
        Offset (s) added to every FICurve trial position
    """

    def __init__(self, open_ephys_path: pathlib.Path, relacs_nix_path: str, nix_file: str):
        self.relacs_nix_path = relacs_nix_path
        self.nix_file_path = nix_file
        self.dataset = rlx.Dataset(relacs_nix_path)
        self.relacs_nix_file = nixio.File.open(relacs_nix_path, nixio.FileMode.ReadOnly)
        self.relacs_block = self.relacs_nix_file.blocks[0]
        self.relacs_sections = self.relacs_nix_file.sections
        # lazy=True: only the sampling rate of the first analog signal is read here.
        self.neo_data = OpenEphysBinaryIO(open_ephys_path).read(lazy=True)
        self.fs = self.neo_data[0].segments[0].analogsignals[0].sampling_rate.magnitude
        self.nix_file = nixio.File(nix_file, nixio.FileMode.ReadWrite)
        self.block = self.nix_file.blocks[0]
        # constants
        # Threshold for TTL peak (onset) detection on the ttl-line.
        self.threshold = 2
        # Search window half-width around an expected TTL pulse.
        self.new_start_jiggle = 0.1
        # NOTE(review): presumably the delay between the TTL pulse and the
        # FICurve stimulus onset — confirm.
        self.ficurve_offset = 0.2

    def _copy_relacs_section(self, name: str, *, replace: bool) -> nixio.Section:
        """Copy the top-level relacs metadata section ``name`` into the new nix file.

        Parameters
        ----------
        name : str
            Name of the section in the relacs nix file.
        replace : bool
            Behaviour if the new file already has a section with this name:
            ``True`` deletes and re-creates it from the relacs section,
            ``False`` reuses the existing section unchanged.

        Returns
        -------
        nixio.Section
            The section in the new nix file.
        """
        sec = self.relacs_sections[name]
        metadata = create_dict_from_section(sec)
        try:
            new_sec = self.nix_file.create_section(sec.name, sec.type)
        except DuplicateName:
            if not replace:
                return self.nix_file.sections[sec.name]
            del self.nix_file.sections[sec.name]
            new_sec = self.nix_file.create_section(sec.name, sec.type)
        create_metadata_from_dict(metadata, new_sec)
        return new_sec

    @staticmethod
    def _create_replacing(container, name: str, create):
        """Call ``create()``; on a name clash delete ``container[name]`` and call it again."""
        try:
            return create()
        except DuplicateName:
            del container[name]
            return create()

    def _get_or_create_group(self, name: str, group_type: str) -> nixio.Group:
        """Return the block group ``name``, creating it with ``group_type`` if missing."""
        try:
            return self.block.create_group(name, group_type)
        except DuplicateName:
            return self.block.groups[name]

    def _attach_references(self, entity) -> None:
        """Add the data arrays of every non-empty reference group to ``entity.references``."""
        for ref_group in self._reference_groups:
            if ref_group.data_arrays:
                entity.references.extend(ref_group.data_arrays)

    def _tag_end(self, name: str) -> float:
        """Return ``position + extent`` (first dimension) of the block tag ``name``."""
        tag = self.block.tags[name]
        return tag.position[0] + tag.extent[0]

    def _detect_ttl_onsets(self, ttl: np.ndarray) -> np.ndarray:
        """Return the sample indices where the TTL trace rises above ``self.threshold``.

        Parameters
        ----------
        ttl : np.ndarray
            TTL trace; it is flattened before detection.

        Returns
        -------
        np.ndarray
            Sorted sample indices of the onsets.
        """
        ttl = np.ravel(ttl)
        # NOTE: np.roll wraps around, so sample 0 counts as an onset whenever
        # the last sample is below and the first sample above the threshold.
        is_onset = (np.roll(ttl, 1) < self.threshold) & (ttl > self.threshold)
        return np.flatnonzero(is_onset)

    def _append_relacs_tag_mtags(self) -> None:
        """Append relacs tags and multi tags to new nix file.

        Every tag/multi-tag of the relacs block is copied with the name prefix
        ``relacs_``. The copy references the data arrays of the ``relacs``
        group and gets a copy of the relacs metadata section of the same name.
        A section that already exists in the new file is reused.
        """
        for t in self.relacs_block.tags:
            log.debug(f"Appending relacs tags {t.name}")
            tag = self.block.create_tag(f"relacs_{t.name}", t.type, position=t.position)
            tag.extent = t.extent
            tag.references.extend(self.block.groups["relacs"].data_arrays)
            tag.metadata = self._copy_relacs_section(t.name, replace=False)
        for t in self.relacs_block.multi_tags:
            log.debug(f"Appending relacs multi-tags {t.name}")
            mtag = self.block.create_multi_tag(
                f"relacs_{t.name}",
                t.type,
                positions=t.positions[:],
                extents=t.extents[:],
            )
            mtag.references.extend(self.block.groups["relacs"].data_arrays)
            mtag.metadata = self._copy_relacs_section(t.name, replace=False)

    def _find_next_ttl(
        self, time_ttl: np.ndarray, peaks_ttl: np.ndarray, lower: float, upper: float
    ) -> np.ndarray:
        """Find the TTL pulse inside the open interval (lower, upper).

        Parameters
        ----------
        time_ttl : np.ndarray
            Time array of the TTL line
        peaks_ttl : np.ndarray
            Detected peak indices on the TTL line
        lower : float
            lower bound for searching the TTL pulse
        upper : float
            upper bound for searching the TTL pulse

        Returns
        -------
        np.ndarray
            Shape (1,) with the time of the pulse, or the mean time when
            several pulses fall into the interval. Empty if none is found.
        """
        peak = time_ttl[peaks_ttl[(time_ttl[peaks_ttl] > lower) & (time_ttl[peaks_ttl] < upper)]]
        if not peak.size > 0:
            log.error("No peaks found")
        elif peak.size > 1:
            if np.all(np.diff(peak) > 0.5):
                log.error("Peaks are further aways than 0.5 seconds")
                log.error(f"Peaks {peak}, Furthest aways: {np.max(peak)}")
            peak = np.mean(peak)
        return np.atleast_1d(peak)

    def _find_peak_ttl_index(
        self, time_ttl: np.ndarray, peaks_ttl: np.ndarray, current_position: np.ndarray
    ) -> np.ndarray:
        """Return the time of the TTL pulse following the one at ``current_position``.

        The current pulse is looked up within ``self.new_start_jiggle`` of
        ``current_position``; if several match, the last one is used.

        Parameters
        ----------
        time_ttl : np.ndarray
            Time array of the TTL line
        peaks_ttl : np.ndarray
            Detected peak indices on the TTL line (sorted)
        current_position : np.ndarray
            Current time of the TTL pulse

        Returns
        -------
        np.ndarray
            Shape (1,) with the time of the next TTL pulse. Empty if the
            current pulse is the last one or no pulse is found near
            ``current_position``.
        """
        peak_times = time_ttl[peaks_ttl]
        matches = np.flatnonzero(
            (peak_times > current_position - self.new_start_jiggle)
            & (peak_times < current_position + self.new_start_jiggle)
        )
        if matches.size == 0:
            log.error(f"No TTL pulse found near the current position {current_position}")
            return np.array([])
        if matches.size > 1:
            log.warning("Multiple current positions taking the last index")
        next_index = matches[-1] + 1
        if next_index == len(peaks_ttl):
            return np.array([])
        # Index with a list to keep the (1,) array shape callers rely on.
        start_next_repro = time_ttl[peaks_ttl[[next_index]]]
        log.debug(f"Start of new repro/trial {start_next_repro}")
        return start_next_repro

    @property
    def _reference_groups(self) -> list[nixio.Group]:
        """Holds the reference groups.

        Returns
        -------
        list[nixio.Group]
            The ``neuronal-data``, ``efish`` and ``relacs`` groups of the block.
        """
        return [
            self.block.groups["neuronal-data"],
            self.block.groups["efish"],
            self.block.groups["relacs"],
        ]

    def _append_mtag(self, repro: rlx.Dataset, positions: np.ndarray, extents: np.ndarray) -> None:
        """Append the multi tag of the current repro to the nix file.

        An existing multi tag and metadata section of the same name are
        replaced. The multi tag references all reference-group data arrays and
        is added to the group named after the repro.

        Parameters
        ----------
        repro : rlx.Dataset
            Current repro run
        positions : np.ndarray
            positions of the multi tags
        extents : np.ndarray
            extents of the multi tags
        """
        name = f"{repro.name}"
        nix_mtag = self._create_replacing(
            self.block.multi_tags,
            name,
            lambda: self.block.create_multi_tag(
                name,
                "relacs.stimulus",
                positions=positions,
                extents=extents,
            ),
        )
        nix_mtag.metadata = self._copy_relacs_section(name, replace=True)
        self._attach_references(nix_mtag)
        self._get_or_create_group(name, repro.type).multi_tags.append(nix_mtag)

    def _append_tag(self, repro: rlx.Dataset, position: np.ndarray, extent: np.ndarray) -> None:
        """Append the tag of the current repro to the nix file.

        An existing tag and metadata section of the same name are replaced.
        The tag references all reference-group data arrays and is added to the
        group named after the repro.

        Parameters
        ----------
        repro : rlx.Dataset
            Current repro run
        position : array-like or float
            start of the tag (flattened)
        extent : array-like or float
            extent of the tag (flattened)
        """
        name = f"{repro.name}"
        position = np.asarray(position).flatten()
        # asarray: extent may be a plain number (e.g. a repro duration).
        extent = np.asarray(extent).flatten()
        nix_tag = self._create_replacing(
            self.block.tags,
            name,
            lambda: self.block.create_tag(name, "relacs.repro_run", position=position),
        )
        nix_tag.extent = extent
        nix_tag.metadata = self._copy_relacs_section(name, replace=True)
        # NOTE: adding refs to tag
        self._attach_references(nix_tag)
        self._get_or_create_group(name, repro.type).tags.append(nix_tag)

    def _align_trials(
        self,
        repro_name: str,
        stimuli,
        time_ttl: np.ndarray,
        peaks_ttl: np.ndarray,
        current_position: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Assign one TTL pulse to every stimulus of a repro, in order.

        Returns
        -------
        positions, extents : np.ndarray
            Shape (n, 1) start times and relacs durations of the aligned
            stimuli. n is smaller than ``len(stimuli)`` if the pulses ran out.
        current_position : np.ndarray
            Time of the pulse following the last aligned stimulus; empty when
            no further pulse exists.
        """
        n_trials = len(stimuli)
        positions = np.zeros((n_trials, 1))
        extents = np.zeros((n_trials, 1))
        n_aligned = 0
        for trial, stimulus in enumerate(stimuli):
            if current_position.size == 0:
                log.error(
                    f"Ran out of TTL pulses in repro {repro_name} "
                    f"after {trial} of {n_trials} trials"
                )
                break
            extents[trial] = stimulus.duration
            positions[trial] = current_position
            n_aligned = trial + 1
            current_position = self._find_peak_ttl_index(time_ttl, peaks_ttl, current_position)
        return positions[:n_aligned], extents[:n_aligned], current_position

    def create_repros_automatically(self) -> None:
        """Create the repros from relacs by aligning them to the TTL pulses.

        TTL pulse onsets are detected on the ``ttl-line`` data array. Starting
        at the first pulse, the relacs repro runs are processed in order:

        * repros shorter than 0.05 s are skipped;
        * repros with stimuli get a multi tag with one position per stimulus
          (every stimulus consumes one TTL pulse) and a tag from the first
          stimulus start to the end of the last stimulus;
        * repros without stimuli get a tag from the end of the previously
          written repro (or from 0 if none was written yet) up to the current
          TTL pulse.

        Processing stops once the TTL pulses are used up. If relacs repros
        remain at that point and the last one is a baseline repro, it is
        appended right after the current repro with its relacs duration.
        """
        ttl_oeph = np.ravel(self.block.data_arrays["ttl-line"][:])
        time_ttl = np.arange(len(ttl_oeph)) / self.fs
        peaks_ttl = self._detect_ttl_onsets(ttl_oeph)
        # WARNING: Check if peaks are duplicates or near each other
        close_peaks = np.flatnonzero(np.diff(peaks_ttl) == 1)
        if close_peaks.size > 0:
            peaks_ttl = np.delete(peaks_ttl, close_peaks)
        if peaks_ttl.size == 0:
            log.error("No TTL pulses found on the ttl-line, no repros created")
            return

        first_onset = time_ttl[peaks_ttl[0]]
        first_peak = self._find_next_ttl(
            time_ttl,
            peaks_ttl,
            first_onset - self.new_start_jiggle,
            first_onset + self.new_start_jiggle,
        )
        if first_peak.size == 0:
            log.error("Could not locate the first TTL pulse, no repros created")
            return
        current_position = np.asarray(first_peak).reshape(1)

        repro_runs = self.dataset.repro_runs()
        # Name of the last repro that got a tag; a stimulus-free repro starts
        # where that one ended.
        previous_repro = None
        for i, repro in enumerate(repro_runs):
            log.debug(repro.name)
            log.debug(f"Current Position {current_position.item()}")
            if repro.duration < 0.05:
                log.warning(f"Skipping repro {repro.name} because it is two short")
                continue

            stimuli = repro.stimuli
            if not stimuli:
                start = 0.0 if previous_repro is None else self._tag_end(previous_repro)
                self._append_tag(repro, start, current_position - start)
                previous_repro = repro.name
                continue

            log.debug("Processing MultiTag")
            position_mtags, extents_mtag, current_position = self._align_trials(
                repro.name, stimuli, time_ttl, peaks_ttl, current_position
            )
            if "FICurve" in repro.name:
                position_mtags += self.ficurve_offset
            self._append_mtag(repro, position_mtags, extents_mtag)
            extent = position_mtags[-1] + extents_mtag[-1] - position_mtags[0]
            self._append_tag(repro, position_mtags[0], extent)
            previous_repro = repro.name

            if current_position.size > 0:
                continue
            # TTL pulses are used up: nothing after this repro can be aligned.
            log.debug("Checking if it is the last repro")
            if i != len(repro_runs) - 1:
                last_repro = repro_runs[-1]
                if "Baseline" in last_repro.name:
                    log.debug("Appending last baseline Repro")
                    self._append_tag(last_repro, self._tag_end(repro.name), last_repro.duration)
                else:
                    log.error("Last Repro was not appended")
                    break
            log.info("Finishing writing")
            log.info("Closing nix files")
            break

    def create_repros_from_config_file(self) -> None:
        """Creates repros from a config file.

        NOT MAINTAINED.

        Needs two attributes that ``__init__`` does not set:

        * ``self.stimulus_config``: ``["stimulus"]["peak_detection"]`` holds
          keyword arguments for ``scipy.signal.find_peaks``; ``["repros"]``
          holds parallel ``"name"``, ``"start"`` and ``"end"`` lists.
        * ``self.cfg``: provides ``open_ephys.samplerate``.

        Raises
        ------
        AttributeError
            If ``stimulus_config`` or ``cfg`` is missing.
        """
        if not (hasattr(self, "stimulus_config") and hasattr(self, "cfg")):
            raise AttributeError(
                "create_repros_from_config_file needs the 'stimulus_config' and 'cfg' "
                "attributes, which are not set by StimulusToNix.__init__"
            )
        ttl_oeph = self.block.data_arrays["ttl-line"][:]
        peaks_ttl = signal.find_peaks(
            ttl_oeph.flatten(),
            **self.stimulus_config["stimulus"]["peak_detection"],
        )[0]
        time_ttl = np.arange(len(ttl_oeph)) / self.cfg.open_ephys.samplerate
        number_of_repros = len(self.stimulus_config["repros"]["name"])
        reference_groups = [
            self.block.groups["neuronal-data"],
            self.block.groups["spike-data"],
            self.block.groups["efish"],
        ]
        for repro_index in range(number_of_repros):
            name = self.stimulus_config["repros"]["name"][repro_index]
            log.debug(name)
            start = np.array(self.stimulus_config["repros"]["start"][repro_index])
            end = np.array(self.stimulus_config["repros"]["end"][repro_index])
            nix_group = self.block.groups[name]
            # start/end: either a one-element list (fixed time) or a
            # [lower, upper] window in which the boundary TTL pulse is searched.
            if start.size > 1:
                start_repro = time_ttl[
                    peaks_ttl[(time_ttl[peaks_ttl] > start[0]) & (time_ttl[peaks_ttl] < start[1])]
                ]
                if start_repro.size > 1:
                    if np.all(np.diff(start_repro) > 0.005):
                        log.error("Wrong end point in end of repro")
                        log.error(f"{name}, {np.max(start_repro)}")
                        sys.exit(1)
                    start_repro = np.mean(start_repro)
                else:
                    start_repro = start_repro[0]
            else:
                start_repro = start[0]
            if end.size > 1:
                end_repro = time_ttl[
                    peaks_ttl[(time_ttl[peaks_ttl] > end[0]) & (time_ttl[peaks_ttl] < end[1])]
                ]
                if end_repro.size > 1:
                    if np.all(np.diff(end_repro) > 0.005):
                        log.error("Wrong end point in end of repro")
                        log.error(f"{name}, {np.max(end_repro)}")
                        sys.exit(1)
                    end_repro = np.mean(end_repro)
                else:
                    end_repro = end_repro[0]
            else:
                end_repro = end[0]
            nix_tag = self.block.create_tag(
                f"{nix_group.name}",
                f"{nix_group.type}",
                position=[start_repro],
            )
            nix_tag.extent = [end_repro - start_repro]
            nix_tag.metadata = self.nix_file.sections[name]
            # NOTE: adding refs to tag
            for ref_group in reference_groups:
                if ref_group.data_arrays:
                    nix_tag.references.extend(ref_group.data_arrays)
            nix_group.tags.append(nix_tag)
            if not self.relacs_block.groups[name].multi_tags:
                log.debug(f"no multitags in repro {name}, skipping")
                continue
            start_repro_multi_tag = start_repro.copy()
            positions = []
            positions.append(np.array(start_repro))
            extents = []
            extents_relacsed_nix = self.relacs_block.groups[name].multi_tags[0].extents[:]
            extents.append(np.array(extents_relacsed_nix[0]))
            # NOTE(review): index_multi_tag is never incremented, so every step
            # uses the first relacs extent.
            index_multi_tag = 0
            while start_repro_multi_tag < end_repro:
                log.debug(f"{start_repro_multi_tag}")
                start_repro_multi_tag = time_ttl[
                    peaks_ttl[
                        (
                            time_ttl[peaks_ttl]
                            > start_repro_multi_tag + extents_relacsed_nix[index_multi_tag]
                        )
                        & (
                            time_ttl[peaks_ttl]
                            < start_repro_multi_tag + extents_relacsed_nix[index_multi_tag] + 2
                        )
                    ]
                ]
                if start_repro_multi_tag.size == 0:
                    log.debug("Did not find any peaks for new start multi tag")
                    log.debug(
                        f"Differenz to end of repro {end_repro - (positions[-1] + extents_relacsed_nix[index_multi_tag])}"
                    )
                    break
                if start_repro_multi_tag.size > 1:
                    if np.all(np.diff(start_repro_multi_tag) > 0.005):
                        log.error("Wrong end point in end of repro")
                        log.error(f"Repro_name: {name}")
                        log.error(f"multitag_index: {index_multi_tag}")
                        log.error(f"max_value {np.max(start_repro_multi_tag)}")
                        sys.exit(1)
                    start_repro_multi_tag = np.mean(start_repro_multi_tag)
                else:
                    start_repro_multi_tag = start_repro_multi_tag[0]
                if end_repro - start_repro_multi_tag < 1:
                    log.debug("Posssible endpoint detected")
                    log.debug(f"Differenz to repro end: {end_repro - start_repro_multi_tag}")
                    break
                positions.append(np.array(start_repro_multi_tag))
                extents.append(extents_relacsed_nix[index_multi_tag])
            positions = np.array(positions).reshape(len(positions), 1)
            extents = np.array(extents)
            if positions.size != extents.size:
                log.error("Calcualted positions and extents do not match ")
                log.error(f"Shape positions {positions.shape}, shape extents {extents.shape}")
            if positions.shape != extents.shape:
                log.error("Shape of Calculated positions and extents do not match")
                log.error(f"Shape positions {positions.shape}, shape extents {extents.shape}")
                embed()
                sys.exit(1)
            nix_mtag = self.block.create_multi_tag(
                f"{nix_group.name}",
                f"{nix_group.type}",
                positions=positions,
                extents=extents,
            )
            nix_mtag.metadata = self.nix_file.sections[name]
            # NOTE: adding refs to mtag
            for ref_group in reference_groups:
                if ref_group.data_arrays:
                    nix_mtag.references.extend(ref_group.data_arrays)
            nix_group.multi_tags.append(nix_mtag)

    def print_table(self) -> None:
        """Print the converted times in a rich table.

        Every relacs repro is matched by name to the repro in the new nix
        file. Repros without a converted counterpart are marked as missing.
        """
        nix_data_set = rlx.Dataset(self.nix_file_path)
        try:
            converted = {run.name: run for run in nix_data_set.repro_runs()}
            table = Table("Repro Name", "start", "stop", "duration")
            for repro_r in self.dataset.repro_runs():
                repro_n = converted.get(repro_r.name)
                if repro_n is None:
                    table.add_row(
                        f"{repro_r.name} -> [red]missing[/red]",
                        f"{repro_r.start_time:.2f}",
                        f"{repro_r.stop_time:.2f}",
                        f"{repro_r.duration:.2f}",
                    )
                    continue
                table.add_row(
                    f"{repro_r.name} -> [red]{repro_n.name}[/red]",
                    f"{repro_r.start_time:.2f} -> [red]{repro_n.start_time:.2f}[/red]",
                    f"{repro_r.stop_time:.2f} -> [red]{repro_n.stop_time:.2f}[/red]",
                    f"{repro_r.duration:.2f} -> [red]{repro_n.duration:.2f}[/red]",
                )
            console.print(table)
        finally:
            nix_data_set.close()

    def checks(self, stimulus_folder: str = "/home/alexander/stimuli/whitenoise/") -> None:
        """Plot relacs and open-ephys traces of converted repros (debugging only).

        For every ``FileStimulus``, ``SAM`` and ``FICurve`` repro of the new
        file that also exists (by name) in the relacs dataset, the stimulus,
        local EOD, global EOD and sinus traces of the first stimulus are
        overlaid. Then the local EOD of all stimuli and their AMs are plotted.

        Parameters
        ----------
        stimulus_folder : str
            Folder with the stimulus files loaded for ``FileStimulus`` repros.
        """
        important_repros = ["FileStimulus", "SAM", "FICurve"]
        nix_data_set = rlx.Dataset(self.nix_file_path)
        try:
            relacs_runs = {run.name: run for run in self.dataset.repro_runs()}
            for repro_n in nix_data_set.repro_runs():
                if repro_n.name.split("_")[0] not in important_repros:
                    continue
                repro_r = relacs_runs.get(repro_n.name)
                if repro_r is None:
                    log.warning(f"Repro {repro_n.name} not found in the relacs dataset")
                    continue
                if "FileStimulus" in repro_n.name:
                    repro_n.stimulus_folder = stimulus_folder
                    repro_r.stimulus_folder = stimulus_folder
                    white_noise, stimulus_time = repro_n.load_stimulus(0, 30_000)
                    white_noise_r, stimulus_time_r = repro_r.load_stimulus(0, 20_000)
                stim = repro_n.stimuli[0]
                stim_r = repro_r.stimuli[0]
                sinus, t = stim.trace_data("sinus")
                sinus_r, t_r = stim_r.trace_data("V-1")
                stimulus_oe, t = stim.trace_data("stimulus")
                stimulus_re, t_r = stim_r.trace_data("GlobalEFieldStimulus")
                local_eod_oe, t = stim.trace_data("local-eod")
                local_eod_re, t_r = stim_r.trace_data("LocalEOD-1")
                global_eod_oe, t = stim.trace_data("global-eod")
                global_eod_re, t_r = stim_r.trace_data("EOD")
                ttl, t = stim.trace_data("ttl-line")
                fig, ax = plt.subplots(4, sharex=True)
                ax[0].plot(t_r, stimulus_re, color="tab:blue", label="relacs")
                ax[0].plot(t, stimulus_oe - np.mean(stimulus_oe), color="tab:red", label="open-ephys")
                ax[0].plot(t, ttl, color="black")
                ax[1].plot(t_r, local_eod_re, color="tab:blue", label="relacs")
                ax[1].plot(t, local_eod_oe, color="tab:red", label="open-ephys")
                ax[2].plot(t_r, global_eod_re, color="tab:blue", label="relacs")
                ax[2].plot(t, global_eod_oe, color="tab:red", label="open-ephys")
                ax[3].plot(t_r, sinus_r, color="tab:blue", label="relacs")
                ax[3].plot(t, sinus, color="tab:red", label="open-ephys")
                plt.legend(loc="upper right")
                # Local EOD of every stimulus, laid out one after another.
                for i, (s, r) in enumerate(zip(repro_n.stimuli, repro_r.stimuli)):
                    v, t = s.trace_data("local-eod")
                    v_r, t_r = r.trace_data("LocalEOD-1")
                    plt.plot(t_r + (i * r.duration), v_r)
                    plt.plot(t + (i * s.duration), v)
                plt.show()
                # AM of the last stimulus, mean-removed and scaled by its peak.
                am = extract_am(v, 30_000)
                am_r = extract_am(v_r, 20_000)
                plt.plot(t, (am - np.mean(am)) / np.max(np.abs(am)))
                plt.plot(t_r, (am_r - np.mean(am_r)) / np.max(np.abs(am_r)))
                v_eod, t_eod = repro_n.trace_data("global-eod")
                v_eodr, t_eodr = repro_r.trace_data("EOD")
        finally:
            nix_data_set.close()

    def plot_stimulus(self) -> None:
        """Plot the relacs stimulus, the open-ephys stimulus and the TTL line.

        Detected TTL onsets are marked in red. Open-ephys traces are plotted
        against ``self.fs``; the relacs stimulus is assumed to be sampled at
        20 kHz (hard-coded).
        """
        ttl_oeph = np.ravel(self.block.data_arrays["ttl-line"][:])
        peaks_ttl = self._detect_ttl_onsets(ttl_oeph)
        stimulus_oeph = self.block.data_arrays["stimulus"]
        stimulus = self.relacs_block.data_arrays["GlobalEFieldStimulus"]
        plt.plot(np.arange(stimulus.size) / 20_000.0, stimulus, label="relacs-stimulus")
        plt.plot(
            np.arange(len(stimulus_oeph)) / self.fs,
            stimulus_oeph[:],
            label="open-ephys",
        )
        plt.plot(
            np.arange(len(ttl_oeph)) / self.fs,
            ttl_oeph,
            label="ttl-line",
        )
        plt.scatter(
            peaks_ttl / self.fs,
            ttl_oeph[peaks_ttl],
            label="detected peaks",
            color="red",
            zorder=100,
        )
        plt.legend(loc="upper right")
        plt.show()

    def close(self):
        """Close the relacs dataset, the new nix file and the relacs nix file."""
        self.dataset.close()
        self.nix_file.close()
        self.relacs_nix_file.close()