import numpy as np
import plotly.graph_objects as go
import scipy.signal as signal
from plotly.subplots import make_subplots


def _dark_legend() -> dict:
    """Return the horizontal, transparent legend layout shared by the figures."""
    return dict(
        bgcolor="rgba(0,0,0,0)",  # transparent, so the dark template shows through
        bordercolor="#444",
        borderwidth=0,
        font=dict(color="#e5ecf6"),  # light text for the plotly_dark template
        orientation="h",
        yanchor="bottom",
        y=1.05,
        xanchor="right",
        x=0.72,
    )


def trial_plot(repro_d, repro_r, x_lim: float = 1.0):
    """Build a five-row figure overlaying the traces of two recordings.

    Rows, top to bottom: TTL line, stimulus, local EOD, global EOD, sinus.
    Traces read from ``repro_d`` are drawn in red (magenta for the TTL line),
    traces read from ``repro_r`` in blue.

    Parameters
    ----------
    repro_d
        Object exposing ``trace_data(name) -> (data, time)`` for the names
        "sinus", "stimulus", "local-eod", "global-eod" and "ttl-line"
        (presumably the open-ephys recording, going by the legend label).
    repro_r
        Object exposing ``trace_data(name) -> (data, time)`` for the names
        "V-1", "GlobalEFieldStimulus", "LocalEOD-1" and "EOD"
        (presumably the relacs recording -- verify against the caller).
    x_lim
        Only samples whose time is below this value are plotted; it is also
        the upper bound of the x-axis range (the axis is labelled in seconds).

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure.
    """
    sinus, _ = repro_d.trace_data("sinus")
    sinus_r, _ = repro_r.trace_data("V-1")
    stimulus_oe, _ = repro_d.trace_data("stimulus")
    stimulus_re, _ = repro_r.trace_data("GlobalEFieldStimulus")
    local_eod_oe, _ = repro_d.trace_data("local-eod")
    local_eod_re, _ = repro_r.trace_data("LocalEOD-1")
    global_eod_oe, _ = repro_d.trace_data("global-eod")
    # One time axis per recording is used for all of its traces.
    # NOTE(review): assumes every trace of a recording shares the same time
    # axis (same length and sampling) -- confirm against trace_data.
    global_eod_re, t_r = repro_r.trace_data("EOD")
    ttl, t = repro_d.trace_data("ttl-line")

    # Restrict both recordings to the window below x_lim.
    mask = t < x_lim
    mask_r = t_r < x_lim

    t = t[mask]
    t_r = t_r[mask_r]
    sinus = sinus[mask]
    sinus_r = sinus_r[mask_r]
    stimulus_oe = stimulus_oe[mask]
    stimulus_re = stimulus_re[mask_r]
    local_eod_oe = local_eod_oe[mask]
    local_eod_re = local_eod_re[mask_r]
    global_eod_oe = global_eod_oe[mask]
    global_eod_re = global_eod_re[mask_r]
    ttl = ttl[mask]

    fig = make_subplots(
        rows=5,
        cols=1,
        shared_xaxes=True,
        subplot_titles=(
            "TTL-Line",
            "Stimulus",
            "Local EOD",
            "Global EOD",
            "Sinus",
        ),
    )

    # Row 1: TTL line.
    fig.add_trace(
        go.Scattergl(x=t, y=ttl, name="ttl-line", line_color="magenta"),
        row=1,
        col=1,
    )

    # Row 2: stimulus. These two traces carry the legend entries for the
    # blue/red colour coding; the lower rows reuse the colours and hide
    # their own legend entries.
    fig.add_trace(
        go.Scattergl(
            x=t_r,
            y=stimulus_re,
            name="stimulus (relacs)",
            line_color="blue",
        ),
        row=2,
        col=1,
    )
    fig.add_trace(
        go.Scattergl(
            x=t,
            # Subtract the mean of the plotted window (presumably to remove a
            # DC offset so the trace overlays the other recording).
            y=stimulus_oe - np.mean(stimulus_oe),
            name="stimulus (open-ephys)",
            line_color="red",
        ),
        row=2,
        col=1,
    )

    # Row 3: local EOD.
    fig.add_trace(
        go.Scattergl(x=t_r, y=local_eod_re, line_color="blue", showlegend=False),
        row=3,
        col=1,
    )
    fig.add_trace(
        go.Scattergl(x=t, y=local_eod_oe, showlegend=False, line_color="red"),
        row=3,
        col=1,
    )

    # Row 4: global EOD.
    fig.add_trace(
        go.Scattergl(x=t_r, y=global_eod_re, showlegend=False, line_color="blue"),
        row=4,
        col=1,
    )
    fig.add_trace(
        go.Scattergl(x=t, y=global_eod_oe, showlegend=False, line_color="red"),
        row=4,
        col=1,
    )

    # Row 5: sinus.
    fig.add_trace(
        go.Scattergl(x=t_r, y=sinus_r, showlegend=False, line_color="blue"),
        row=5,
        col=1,
    )
    fig.add_trace(
        go.Scattergl(x=t, y=sinus, showlegend=False, line_color="red"),
        row=5,
        col=1,
    )

    fig.update_layout(
        template="plotly_dark",
        height=800,  # figure height in pixels
        legend=_dark_legend(),
    )

    # Label the shared x-axis on the bottom subplot (row 5 of 5).
    fig.update_xaxes(title_text="Time (s)", row=5, col=1)
    fig.update_xaxes(range=[0, x_lim])

    return fig


def plot_line_comparision(
    time_relacs,
    time_oephys,
    data_relacs,
    data_oephys,
    labels,
    x_lim: float = 1.0,
    x_range=(0, 0.01),
):
    """Overlay one relacs trace and one open-ephys trace in a single figure.

    Parameters
    ----------
    time_relacs, data_relacs
        Time axis and samples of the relacs trace (drawn in blue). They are
        indexed with a boolean mask, so array-likes supporting that are needed.
    time_oephys, data_oephys
        Time axis and samples of the open-ephys trace (drawn in red).
    labels
        Sequence whose first entry names the relacs trace and whose second
        entry names the open-ephys trace in the legend.
    x_lim
        Only samples whose time is below this value are added to the figure.
    x_range
        ``(start, stop)`` of the initially displayed x-axis range.

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure.
    """
    mask = time_oephys < x_lim
    mask_r = time_relacs < x_lim

    time_oephys = time_oephys[mask]
    time_relacs = time_relacs[mask_r]
    data_oephys = data_oephys[mask]
    data_relacs = data_relacs[mask_r]

    fig = go.Figure()
    fig.add_trace(
        go.Scattergl(
            x=time_relacs,
            y=data_relacs,
            name=labels[0],
            line_color="blue",
            mode="lines+markers",
        )
    )
    fig.add_trace(
        go.Scattergl(
            x=time_oephys,
            y=data_oephys,
            name=labels[1],
            line_color="red",
            mode="lines+markers",
        )
    )

    fig.update_layout(
        template="plotly_dark",
        height=500,  # figure height in pixels
        legend=_dark_legend(),
    )
    # The data is cut at x_lim, but the initial view is set by x_range.
    fig.update_xaxes(title_text="Time (s)", range=list(x_range))
    return fig


def calc_lag(repro_d, repro_r):
    """Estimate the lag between matching traces of two recordings.

    Each ``repro_d`` trace is resampled to the number of samples of its
    ``repro_r`` counterpart and cross-correlated with it; the lag at the
    correlation maximum is reported.

    Parameters
    ----------
    repro_d
        Object exposing ``trace_data(name) -> (data, time)`` for the names
        "sinus", "stimulus", "local-eod" and "global-eod".
    repro_r
        Object exposing ``trace_data(name) -> (data, time)`` for the names
        "V-1", "GlobalEFieldStimulus", "LocalEOD-1" and "EOD".

    Returns
    -------
    list
        Lags in samples (on the sample grid of the ``repro_r`` traces), in
        the order [sinus, local EOD, global EOD, stimulus]. Values follow
        the convention of ``scipy.signal.correlation_lags`` with the
        resampled ``repro_d`` trace as first input.
    """
    # Only the sample values are needed here; the time axes are discarded.
    sinus, _ = repro_d.trace_data("sinus")
    sinus_r, _ = repro_r.trace_data("V-1")
    stimulus_oe, _ = repro_d.trace_data("stimulus")
    stimulus_re, _ = repro_r.trace_data("GlobalEFieldStimulus")
    local_eod_oe, _ = repro_d.trace_data("local-eod")
    local_eod_re, _ = repro_r.trace_data("LocalEOD-1")
    global_eod_oe, _ = repro_d.trace_data("global-eod")
    global_eod_re, _ = repro_r.trace_data("EOD")

    oephys_lanes = [sinus, local_eod_oe, global_eod_oe, stimulus_oe]
    relacs_lanes = [sinus_r, local_eod_re, global_eod_re, stimulus_re]

    lags_lanes = []
    for oephys_lane, relacs_lane in zip(oephys_lanes, relacs_lanes, strict=True):
        # Bring both traces to the same number of samples.
        # NOTE(review): this assumes both recordings span the same duration;
        # confirm, otherwise the sample grids do not line up.
        oephys_lane_resampled = signal.resample(oephys_lane, len(relacs_lane))
        # NOTE(review): the traces are correlated without removing their mean;
        # a DC offset may bias the peak -- consider subtracting it first.
        correlation = signal.correlate(
            oephys_lane_resampled, relacs_lane, mode="full"
        )
        lags = signal.correlation_lags(
            oephys_lane_resampled.size, relacs_lane.size, mode="full"
        )
        lags_lanes.append(lags[np.argmax(correlation)])
    return lags_lanes