oephys2nix/doc/util.py

210 lines
5.7 KiB
Python

import numpy as np
import plotly.graph_objects as go
import scipy.signal as signal
from plotly.subplots import make_subplots
def _clipped_trace(repro, name, x_lim):
    """Return ``(time, data)`` of trace *name* from *repro*, keeping samples with time < *x_lim*."""
    data, time = repro.trace_data(name)
    data = np.asarray(data)
    time = np.asarray(time)
    mask = time < x_lim
    return time[mask], data[mask]


def trial_plot(repro_d, repro_r, x_lim: float = 1.0):
    """Plot an open-ephys trial and the matching relacs trial on a shared time axis.

    The figure has five stacked subplots sharing the x axis: the open-ephys
    TTL line, then stimulus, local EOD, global EOD and sinus. Rows 2-5 overlay
    the relacs trace (blue) and the open-ephys trace (red).

    Parameters
    ----------
    repro_d : object
        Open-ephys side. Must provide ``trace_data(name) -> (data, time)`` for
        "ttl-line", "stimulus", "local-eod", "global-eod" and "sinus".
    repro_r : object
        Relacs side. Must provide ``trace_data(name) -> (data, time)`` for
        "GlobalEFieldStimulus", "LocalEOD-1", "EOD" and "V-1".
    x_lim : float
        Only samples with time < ``x_lim`` are plotted. Also used as the upper
        x-axis limit.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = make_subplots(
        rows=5,
        cols=1,
        shared_xaxes=True,
        subplot_titles=(
            "TTL-Line",
            "Stimulus",
            "Local EOD",
            "Global EOD",
            "Sinus",
        ),
    )

    # Row 1: open-ephys TTL line only.
    t_ttl, ttl = _clipped_trace(repro_d, "ttl-line", x_lim)
    fig.add_trace(
        go.Scattergl(x=t_ttl, y=ttl, name="ttl-line", line_color="magenta"),
        row=1,
        col=1,
    )

    # Row 2: stimulus. The open-ephys trace is mean-subtracted before plotting
    # (presumably to remove a DC offset so it overlays the relacs trace; confirm).
    t_re, stimulus_re = _clipped_trace(repro_r, "GlobalEFieldStimulus", x_lim)
    t_oe, stimulus_oe = _clipped_trace(repro_d, "stimulus", x_lim)
    fig.add_trace(
        go.Scattergl(x=t_re, y=stimulus_re, name="stimulus (relacs)", line_color="blue"),
        row=2,
        col=1,
    )
    fig.add_trace(
        go.Scattergl(
            x=t_oe,
            y=stimulus_oe - np.mean(stimulus_oe),
            name="stimulus (open-ephys)",
            line_color="red",
        ),
        row=2,
        col=1,
    )

    # Rows 3-5: relacs (blue) / open-ephys (red) overlays. Legend entries are
    # suppressed because the colours are already introduced by the stimulus row.
    for row, relacs_name, oephys_name in (
        (3, "LocalEOD-1", "local-eod"),
        (4, "EOD", "global-eod"),
        (5, "V-1", "sinus"),
    ):
        t_re, y_re = _clipped_trace(repro_r, relacs_name, x_lim)
        t_oe, y_oe = _clipped_trace(repro_d, oephys_name, x_lim)
        fig.add_trace(
            go.Scattergl(x=t_re, y=y_re, showlegend=False, line_color="blue"),
            row=row,
            col=1,
        )
        fig.add_trace(
            go.Scattergl(x=t_oe, y=y_oe, showlegend=False, line_color="red"),
            row=row,
            col=1,
        )

    fig.update_layout(
        template="plotly_dark",
        height=800,  # figure height in pixels
        legend=dict(
            bgcolor="rgba(0,0,0,0)",  # fully transparent (or use "#1f2630" to match bg)
            bordercolor="#444",
            borderwidth=0,
            font=dict(color="#e5ecf6"),  # matches plotly_dark foreground
            orientation="h",
            yanchor="bottom",
            y=1.05,
            xanchor="right",
            x=0.72,
        ),
    )
    # Label the shared x axis on the bottom (5th) subplot, where its ticks are drawn.
    fig.update_xaxes(title_text="Time (s)", row=5, col=1)
    fig.update_xaxes(range=[0, x_lim])
    return fig
def plot_line_comparision(
    time_relacs,
    time_oephys,
    data_relacs,
    data_oephys,
    labels,
    x_lim: float = 1.0,
    x_range: tuple[float, float] = (0.0, 0.01),
):
    """Overlay one relacs trace and one open-ephys trace in a single figure.

    Both traces are drawn as lines with markers: relacs in blue, open-ephys
    in red.

    Parameters
    ----------
    time_relacs, data_relacs : array_like
        Relacs sample times and values (same length).
    time_oephys, data_oephys : array_like
        Open-ephys sample times and values (same length).
    labels : sequence of str
        Legend names: ``labels[0]`` for relacs, ``labels[1]`` for open-ephys.
    x_lim : float
        Samples with time >= ``x_lim`` are dropped before plotting. The default
        of 1.0 matches the previously hard-coded value.
    x_range : tuple of float
        Initial visible x-axis window. The default of (0, 0.01) matches the
        previously hard-coded value. The plotted data still extends to
        ``x_lim``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # asarray lets plain sequences work with the boolean masks below.
    time_relacs = np.asarray(time_relacs)
    time_oephys = np.asarray(time_oephys)
    data_relacs = np.asarray(data_relacs)
    data_oephys = np.asarray(data_oephys)

    mask_r = time_relacs < x_lim
    mask_oe = time_oephys < x_lim

    fig = go.Figure()
    fig.add_trace(
        go.Scattergl(
            x=time_relacs[mask_r],
            y=data_relacs[mask_r],
            name=labels[0],
            line_color="blue",
            mode="lines+markers",
        )
    )
    fig.add_trace(
        go.Scattergl(
            x=time_oephys[mask_oe],
            y=data_oephys[mask_oe],
            name=labels[1],
            line_color="red",
            mode="lines+markers",
        )
    )
    fig.update_layout(
        template="plotly_dark",
        height=500,  # figure height in pixels
        legend=dict(
            bgcolor="rgba(0,0,0,0)",
            bordercolor="#444",
            borderwidth=0,
            font_color="#e5ecf6",
            orientation="h",
            yanchor="bottom",
            y=1.05,
            xanchor="right",
            x=0.72,
        ),
    )
    fig.update_xaxes(title_text="Time (s)", range=list(x_range))
    return fig
def _lane_lag(oephys_lane, relacs_lane):
    """Return the lag, in relacs samples, that maximizes the cross-correlation of two lanes."""
    relacs_lane = np.asarray(relacs_lane, dtype=float)
    # NOTE(review): resampling to len(relacs_lane) assumes both lanes span the
    # same time interval; confirm for recordings of unequal duration.
    oephys_resampled = signal.resample(np.asarray(oephys_lane, dtype=float), relacs_lane.size)
    # Remove DC offsets. A constant offset in both lanes adds a triangular term
    # to the full correlation that peaks at zero lag and biases the argmax.
    oephys_resampled = oephys_resampled - oephys_resampled.mean()
    relacs_lane = relacs_lane - relacs_lane.mean()
    correlation = signal.correlate(oephys_resampled, relacs_lane, mode="full")
    lags = signal.correlation_lags(oephys_resampled.size, relacs_lane.size, mode="full")
    return int(lags[np.argmax(correlation)])


def calc_lag(repro_d, repro_r):
    """Estimate the sample lag between matching open-ephys and relacs channels.

    For each channel pair, the open-ephys trace is FFT-resampled to the relacs
    trace's length. Both traces are then mean-subtracted, and the lag is the
    argmax of their full cross-correlation.

    Parameters
    ----------
    repro_d : object
        Open-ephys side. Must provide ``trace_data(name) -> (data, time)`` for
        "sinus", "local-eod", "global-eod" and "stimulus".
    repro_r : object
        Relacs side. Must provide ``trace_data(name) -> (data, time)`` for
        "V-1", "LocalEOD-1", "EOD" and "GlobalEFieldStimulus".

    Returns
    -------
    list of int
        One lag per pair, in the order [sinus, local EOD, global EOD, stimulus].
        Lags are in relacs samples and follow scipy's ``correlation_lags``
        convention: negative when the relacs trace is delayed relative to the
        open-ephys trace.
    """
    channel_pairs = (
        ("sinus", "V-1"),
        ("local-eod", "LocalEOD-1"),
        ("global-eod", "EOD"),
        ("stimulus", "GlobalEFieldStimulus"),
    )
    lags_lanes = []
    for oephys_name, relacs_name in channel_pairs:
        oephys_lane, _ = repro_d.trace_data(oephys_name)
        relacs_lane, _ = repro_r.trace_data(relacs_name)
        lags_lanes.append(_lane_lag(oephys_lane, relacs_lane))
    return lags_lanes