# -*- coding: utf-8 -*-
"""
Created on Tue Oct 22 11:43:41 2024

@author: diana

Load one SAM RePro run from a NIX file via rlxnix, compute the power
spectrum of the binary spike train and highlight the harmonics of the AM,
the EOD frequency and the stimulus frequency whose integrated power exceeds
`threshold` times the power of the neighbouring frequency bands.
"""
import glob
import os
import rlxnix as rlx
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sig
from scipy.integrate import quad

# NumPy 2.0 renamed np.trapz to np.trapezoid and deprecated the old name;
# use whichever this NumPy provides so the script runs on 1.x and 2.x.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


### FUNCTIONS ###

def binary_spikes(spike_times, duration, dt):
    """
    Convert spike times to a binary representation.

    Each bin is 0 when there is no spike and 1 when there is one.

    Parameters
    ----------
    spike_times : np.array
        The spike times in seconds.
    duration : float
        The trial duration in seconds.
    dt : float
        The temporal resolution in seconds.

    Returns
    -------
    binary : np.array
        Array of length round(duration / dt) with 1 at every spike bin.
        Spikes whose bin lies outside [0, duration) are ignored; a spike
        exactly at `duration` would otherwise index past the end.
    """
    n_bins = int(np.round(duration / dt))  # as many bins as the stimulus lasts
    binary = np.zeros(n_bins)
    spike_indices = np.asarray(np.round(np.asarray(spike_times) / dt), dtype=int)
    # Drop out-of-range indices: too-large ones raise IndexError and
    # negative ones would silently wrap around to the end of the array.
    in_range = (spike_indices >= 0) & (spike_indices < n_bins)
    binary[spike_indices[in_range]] = 1
    return binary


def firing_rate(binary_spikes, box_width, dt=0.000025):
    """Calculate the firing rate from binary spike data.

    The rate is computed by convolving the spike train with a boxcar
    (moving average) kernel normalized to an integral of 1.

    Parameters
    ----------
    binary_spikes : np.array
        A binary array representing spike occurrences.
    box_width : float
        The width of the box filter in seconds. It is rounded to the
        nearest whole number of samples (at least one sample).
    dt : float, optional
        The temporal resolution (time step) in seconds. Default is
        0.000025 seconds.

    Returns
    -------
    rate : np.array
        The firing rate (spikes per second) at each time step, same length
        as `binary_spikes`.
    """
    # Round rather than floor-divide: float `//` can land one sample short
    # (e.g. 1999 instead of 2000), and a zero-length kernel would divide by
    # zero and make np.convolve fail.
    n_box = max(1, int(np.round(box_width / dt)))
    box = np.ones(n_box) / (n_box * dt)  # kernel integrates to 1
    rate = np.convolve(binary_spikes, box, mode="same")
    return rate


def powerspectrum(rate, dt, nperseg=2**15):
    """Compute the power spectrum of a signal using Welch's method.

    Parameters
    ----------
    rate : np.array
        The signal (firing rate or binary spike train).
    dt : float
        The temporal resolution (time step) in seconds.
    nperseg : int, optional
        Maximum segment length for Welch's method. Default is 2**15. It is
        clipped to the signal length, and segments overlap by half.

    Returns
    -------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    """
    # For signals shorter than `nperseg`, scipy shrinks the segment length
    # but a fixed noverlap of 2**14 would then exceed it (ValueError), so
    # derive the overlap from the effective segment length.
    nperseg = min(nperseg, len(rate))
    noverlap = nperseg // 2
    frequency, power = sig.welch(rate, fs=1 / dt, nperseg=nperseg,
                                 noverlap=noverlap)
    return frequency, power


def prepare_harmonics(frequencies, categories, num_harmonics, colors):
    """Build the harmonic frequencies to inspect, grouped by category.

    Parameters
    ----------
    frequencies : list of float
        Fundamental frequency (Hz) of each category.
    categories : list of str
        One label per fundamental frequency.
    num_harmonics : list of int
        How many multiples (1x, 2x, ...) of each fundamental to include.
    colors : list
        One matplotlib color per category.

    Returns
    -------
    points : list of float
        All harmonic frequencies, flattened in category order.
    color_mapping : dict
        Category -> color.
    points_categories : dict
        Category -> list of its harmonic frequencies. If a category label
        repeats, the later entry overwrites the earlier one.
    """
    points_categories = {
        category: [freq * (i + 1) for i in range(n_harm)]
        for freq, category, n_harm in zip(frequencies, categories, num_harmonics)
    }
    points = [p for harmonics in points_categories.values() for p in harmonics]
    color_mapping = {category: colors[idx] for idx, category in enumerate(categories)}
    return points, color_mapping, points_categories


def plot_power_spectrum_with_integrals(frequency, power, points, delta):
    """
    Plot the power spectrum and integrate the power around given points.

    For each point, this function integrates the power over the central
    window [point - delta, point + delta]. It also computes a local
    reference value: the mean of the integrals over the flanking windows
    [point - 5*delta, point - delta) and (point + delta, point + 5*delta].
    No regions are colored and no vertical lines are added.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    points : list
        A list of harmonic frequencies to evaluate.
    delta : float
        Half-width of the range for integration around each point.

    Returns
    -------
    integrals : list
        Integral of the central window for each point.
    local_means : list
        Mean of the two flanking-window integrals for each point.
    fig : matplotlib.figure.Figure
        The created figure object with the power plot.
    ax : matplotlib.axes.Axes
        The axes object for further modifications.
    """
    fig, ax = plt.subplots()
    ax.plot(frequency, power)  # Plot power spectrum

    integrals = []
    local_means = []
    for point in points:
        # Central integration window (width 2*delta).
        center = (frequency >= point - delta) & (frequency <= point + delta)
        integrals.append(_trapezoid(power[center], frequency[center]))

        # Flanking windows, each 4*delta wide.
        # NOTE(review): each flank spans twice the bandwidth of the central
        # window, so the threshold compares unequal bandwidths - confirm
        # this is intended.
        left = (frequency >= point - 5 * delta) & (frequency < point - delta)
        right = (frequency > point + delta) & (frequency <= point + 5 * delta)
        l_integral = _trapezoid(power[left], frequency[left])
        r_integral = _trapezoid(power[right], frequency[right])
        local_means.append(np.mean([l_integral, r_integral]))

    ax.set_xlim([0, 1200])  # Set x-axis limit
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power')
    ax.set_title('Power Spectrum with Integrals (Uncolored)')

    return integrals, local_means, fig, ax


def highlight_integrals_with_threshold(frequency, power, points, delta, threshold,
                                       integrals, local_means, color_mapping,
                                       points_categories, fig_orig, ax_orig):
    """
    Create a new figure that highlights integrals exceeding the threshold.

    A point is highlighted when its integral is greater than
    `threshold * local_mean`. For each highlighted point, the central window
    is shaded in the color of the point's category. Dashed vertical lines
    mark the outermost frequency bins of the flanking windows. The
    original figure is left unchanged.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    points : list
        A list of harmonic frequencies to highlight.
    delta : float
        Half-width of the range for integration around each point.
    threshold : float
        Factor by which an integral must exceed its local mean.
    integrals : list
        List of calculated integrals for each point.
    local_means : list
        List of local mean values (adjacent integrals).
    color_mapping : dict
        A mapping of point categories to colors.
    points_categories : dict
        A mapping of categories to lists of points.
    fig_orig : matplotlib.figure.Figure
        The original figure object (unused, remains unchanged).
    ax_orig : matplotlib.axes.Axes
        The original axes object (unused, remains unchanged).

    Returns
    -------
    fig_new : matplotlib.figure.Figure
        The new figure object with color highlights and vertical lines.
    """
    fig_new, ax_new = plt.subplots()
    ax_new.plot(frequency, power)  # Plot the same power spectrum

    for i, point in enumerate(points):
        # `not (a > b)` (rather than `a <= b`) keeps NaN values un-highlighted.
        if not integrals[i] > (local_means[i] * threshold):
            continue

        # Color of the first category that contains this point.
        color = next((c for cat, c in color_mapping.items()
                      if point in points_categories[cat]), 'gray')

        # Shade the region where the central integral was calculated.
        ax_new.axvspan(point - delta, point + delta, color=color, alpha=0.3,
                       label=f'{point:.2f} Hz')
        print(f"Integral around {point:.2f} Hz: {integrals[i]:.5e}")

        # Outer boundaries of the flanking windows. A window can be empty
        # near the edges of the spectrum, so only draw a line when it has bins.
        left_bins = frequency[(frequency >= point - 5 * delta) & (frequency < point - delta)]
        right_bins = frequency[(frequency > point + delta) & (frequency <= point + 5 * delta)]
        if left_bins.size:
            ax_new.axvline(x=left_bins[0], color="k", linestyle="--")
        if right_bins.size:
            ax_new.axvline(x=right_bins[-1], color="k", linestyle="--")

    ax_new.set_xlim([0, 1200])
    ax_new.set_xlabel('Frequency (Hz)')
    ax_new.set_ylabel('Power')
    ax_new.set_title('Power Spectrum with Highlighted Integrals')
    # Without highlighted regions there is nothing to list; calling legend()
    # then only produces a matplotlib warning.
    handles, _ = ax_new.get_legend_handles_labels()
    if handles:
        ax_new.legend()

    return fig_new


### Data retrieval ###
# Paths are relative to the working directory: go one level up (..) and
# then into the 'data' folder.
datafolder = os.path.join("..", "data")
example_file = os.path.join(datafolder, "2024-10-16-ad-invivo-1.nix")
dataset = rlx.Dataset(example_file)
sams = dataset.repro_runs("SAM")
sam = sams[2]  # NOTE(review): hard-coded choice of the third SAM run

## Data for the functions
df = sam.metadata["RePro-Info"]["settings"]["deltaf"][0][0]
stim = sam.stimuli[1]
potential, time = stim.trace_data("V-1")
spikes, _ = stim.trace_data("Spikes-1")
duration = stim.duration
dt = stim.trace_info("V-1").sampling_interval

### Apply the functions ###
b = binary_spikes(spikes, duration, dt)
rate = firing_rate(b, box_width=0.05, dt=dt)
# The spectrum is computed from the binary spike train, not from `rate`:
# the 50 ms box filter would strongly attenuate the high-frequency (EODf)
# components examined below. NOTE(review): confirm this is intended.
frequency, power = powerspectrum(b, dt)

## Key frequencies
eodf = stim.metadata[stim.name]["EODf"][0][0]
stimulus_frequency = eodf + df
AM = 50  # Hz; NOTE(review): hard-coded rather than derived from `df` - confirm it matches this run
#print(f"EODf: {eodf}, Stimulus Frequency: {stimulus_frequency}, AM: {AM}")

frequencies = [AM, eodf, stimulus_frequency]
categories = ["AM", "EODf", "Stimulus frequency"]
num_harmonics = [4, 2, 2]
colors = ["green", "orange", "red"]

delta = 2.5      # half-width (Hz) of the central integration window
threshold = 10   # required ratio of central integral to local mean

### Harmonic analysis ###
points, color_mapping, points_categories = prepare_harmonics(frequencies, categories, num_harmonics, colors)

# First, create the power spectrum plot with integrals (without coloring)
integrals, local_means, fig1, ax1 = plot_power_spectrum_with_integrals(frequency, power, points, delta)

# Then, create a new separate figure where integrals exceeding the threshold are highlighted
fig2 = highlight_integrals_with_threshold(frequency, power, points, delta, threshold, integrals, local_means, color_mapping, points_categories, fig1, ax1)