# -*- coding: utf-8 -*-
"""
Power-spectrum analysis of a spike response recorded during a SAM stimulus.

Created on Tue Oct 22 15:21:41 2024

@author: diana
"""

import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import rlxnix as rlx
import scipy.signal as sig
from scipy.integrate import quad


### FUNCTIONS ###
def binary_spikes(spike_times, duration, dt):
    """Convert spike times to a binary representation.

    The result is zero where there is no spike and one where there is.

    Parameters
    ----------
    spike_times : np.array
        The spike times.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    binary : np.array
        The binary representation of the spike times. It has
        ``round(duration / dt)`` samples. Spike times outside
        ``[0, duration]`` are ignored.
    """
    n_samples = int(np.round(duration / dt))
    # Vector with the same length as the stimulus time.
    binary = np.zeros(n_samples)

    times = np.asarray(spike_times, dtype=float)
    # Negative times would otherwise wrap around to the end of the array and
    # times beyond the trial would index past it.
    times = times[(times >= 0) & (times <= duration)]
    if n_samples == 0 or times.size == 0:
        return binary

    spike_indices = np.round(times / dt).astype(int)
    # A spike right at the end of the trial rounds to index ``n_samples``;
    # keep it in the last bin instead of raising an IndexError.
    binary[np.minimum(spike_indices, n_samples - 1)] = 1
    return binary


def firing_rate(binary_spikes, box_width, dt=0.000025):
    """Calculate the firing rate from binary spike data.

    The binary spike train is convolved with a box kernel whose integral
    is one.

    Parameters
    ----------
    binary_spikes : np.array
        A binary array representing spike occurrences.
    box_width : float
        The width of the box filter in seconds.
    dt : float, optional
        The temporal resolution (time step) in seconds.
        Default is 0.000025 seconds.

    Returns
    -------
    rate : np.array
        An array representing the firing rate at each time step.

    Raises
    ------
    ValueError
        If ``box_width`` is shorter than one time step.
    """
    # Rounding instead of float floor division: ``box_width // dt`` can come
    # out one sample short because of floating point representation.
    n_box = int(np.round(box_width / dt))
    if n_box < 1:
        raise ValueError("box_width must span at least one time step (dt)")

    box = np.ones(n_box)
    box /= np.sum(box) * dt  # Normalization of box kernel to an integral of 1
    rate = np.convolve(binary_spikes, box, mode="same")
    return rate


def powerspectrum(rate, dt, nperseg=2**15, noverlap=None):
    """Compute the power spectrum of a given firing rate.

    This function calculates the power spectrum using the Welch method.

    Parameters
    ----------
    rate : np.array
        An array of firing rates.
    dt : float
        The temporal resolution (time step) in seconds.
    nperseg : int, optional
        Segment length passed to ``scipy.signal.welch``. It is reduced to the
        signal length for shorter signals. Default is 2**15.
    noverlap : int, optional
        Overlap between segments. Defaults to half of the (possibly reduced)
        segment length, i.e. 2**14 for the default ``nperseg``.

    Returns
    -------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    """
    rate = np.asarray(rate)
    # welch() rejects ``noverlap >= nperseg``; with a fixed overlap of 2**14
    # every signal of at most 2**14 samples would fail, so both values are
    # derived from the actual signal length.
    nperseg = min(nperseg, rate.shape[-1])
    if noverlap is None:
        noverlap = nperseg // 2
    frequency, power = sig.welch(
        rate, fs=1 / dt, nperseg=nperseg, noverlap=noverlap
    )
    return frequency, power


def _trapezoid(y, x):
    """Integrate ``y`` over ``x`` with the trapezoidal rule."""
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of
    # ``np.trapezoid``; use whichever the installed NumPy provides.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return integrate(y, x)


def calculate_integral(frequency, power, point, delta):
    """Calculate the integral around a single specified point.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    point : float
        The harmonic frequency at which to calculate the integral.
    delta : float
        Half-width of the range for integration around the point.

    Returns
    -------
    integral : float
        The integral over ``[point - delta, point + delta]``.
    local_mean : float
        The mean of the integrals over the two adjacent ranges
        ``[point - 5 * delta, point - delta)`` and
        ``(point + delta, point + 5 * delta]``. Each adjacent range is
        ``4 * delta`` wide, i.e. twice as wide as the central range.
    """
    frequency = np.asarray(frequency)
    power = np.asarray(power)

    center = (frequency >= point - delta) & (frequency <= point + delta)
    left = (frequency >= point - 5 * delta) & (frequency < point - delta)
    right = (frequency > point + delta) & (frequency <= point + 5 * delta)

    integral = _trapezoid(power[center], frequency[center])
    l_integral = _trapezoid(power[left], frequency[left])
    r_integral = _trapezoid(power[right], frequency[right])
    local_mean = np.mean([l_integral, r_integral])
    return integral, local_mean


def valid_integrals(integral, local_mean, threshold, point):
    """Check if the integral exceeds the local mean by the threshold.

    Parameters
    ----------
    integral : float
        The calculated integral around the point.
    local_mean : float
        The local mean value (adjacent integrals).
    threshold : float
        Factor by which the integral has to exceed the local mean.
    point : float
        The harmonic frequency point being evaluated.

    Returns
    -------
    valid : bool
        True if the integral is larger than ``local_mean * threshold``,
        otherwise False.
    message : str
        A message stating whether the point is valid or not.
    """
    valid = bool(integral > (local_mean * threshold))
    if valid:
        message = (
            f"The point {point} is valid, as its integral exceeds the threshold."
        )
    else:
        message = (
            f"The point {point} is not valid, as its integral does not exceed "
            "the threshold."
        )
    return valid, message


def prepare_harmonics(frequencies, categories, num_harmonics, colors):
    """Prepare harmonic frequencies and assign colors based on categories.

    Parameters
    ----------
    frequencies : list
        Base frequencies to generate harmonics.
    categories : list
        Corresponding categories for the base frequencies.
    num_harmonics : list
        Number of harmonics for each base frequency.
    colors : list
        List of colors corresponding to the categories.

    Returns
    -------
    points : list
        A flat list of harmonic frequencies.
    color_mapping : dict
        A dictionary mapping each category to its corresponding color.
    points_categories : dict
        A mapping of categories to their harmonic frequencies.
    """
    # The k-th entry of each list is k times the base frequency, starting with
    # the base frequency itself (k = 1).
    points_categories = {
        category: [freq * k for k in range(1, num_harmonics[idx] + 1)]
        for idx, (freq, category) in enumerate(zip(frequencies, categories))
    }
    points = [p for harmonics in points_categories.values() for p in harmonics]
    color_mapping = {
        category: colors[idx] for idx, category in enumerate(categories)
    }
    return points, color_mapping, points_categories


def find_exceeding_points(frequency, power, points, delta, threshold):
    """Find the points whose integral exceeds the local mean by a threshold.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    points : list
        A list of harmonic frequencies to evaluate.
    delta : float
        Half-width of the range for integration around the point.
    threshold : float
        Threshold value to compare integrals with local mean.

    Returns
    -------
    exceeding_points : list
        The points, in input order, where the integral exceeds the local mean
        by the threshold.
    """
    exceeding_points = []
    for point in points:
        integral, local_mean = calculate_integral(frequency, power, point, delta)
        valid, _ = valid_integrals(integral, local_mean, threshold, point)
        if valid:
            exceeding_points.append(point)
    return exceeding_points


def _category_color(point, color_mapping, points_categories, default="gray"):
    """Return the color of the category that contains ``point``."""
    for category, color in color_mapping.items():
        if point in points_categories.get(category, ()):
            return color
    return default


def plot_highlighted_integrals(frequency, power, exceeding_points, delta,
                               threshold, color_mapping, points_categories,
                               xlim=(0, 1200)):
    """Plot the power spectrum and highlight integrals exceeding the threshold.

    Every point is checked against the threshold again, so points that do not
    exceed it are skipped even if they are passed in ``exceeding_points``.
    The integral of each highlighted point is printed.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    exceeding_points : list
        A list of harmonic frequencies that exceed the threshold.
    delta : float
        Half-width of the range for integration around each point.
    threshold : float
        Threshold value to compare integrals with local mean.
    color_mapping : dict
        A dictionary mapping each category to its color.
    points_categories : dict
        A mapping of categories to lists of points.
    xlim : tuple of float, optional
        Limits of the frequency axis. Default is (0, 1200).

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure object with highlighted integrals.
    """
    frequency = np.asarray(frequency)
    power = np.asarray(power)

    fig, ax = plt.subplots()
    ax.plot(frequency, power)  # Plot power spectrum

    for point in exceeding_points:
        integral, local_mean = calculate_integral(frequency, power, point, delta)
        valid, _ = valid_integrals(integral, local_mean, threshold, point)
        if not valid:
            continue

        # Points that belong to no category are drawn in gray.
        color = _category_color(point, color_mapping, points_categories)

        # Shade the region around the point where the integral was calculated
        ax.axvspan(point - delta, point + delta, color=color, alpha=0.3,
                   label=f'{point:.2f} Hz')
        print(f"Integral around {point:.2f} Hz: {integral:.5e}")

        # Outer boundaries of the adjacent regions used for the local mean.
        left = frequency[
            (frequency >= point - 5 * delta) & (frequency < point - delta)
        ]
        right = frequency[
            (frequency > point + delta) & (frequency <= point + 5 * delta)
        ]

        # Add vertical dashed lines at the boundaries of the adjacent regions.
        # A region without any frequency sample (e.g. at the edge of the
        # spectrum) has no boundary to draw.
        if left.size:
            ax.axvline(x=left[0], color="k", linestyle="--")
        if right.size:
            ax.axvline(x=right[-1], color="k", linestyle="--")

    ax.set_xlim(xlim)
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power')
    ax.set_title('Power Spectrum with Highlighted Integrals')
    # Without any highlighted point there is nothing to put into a legend.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    return fig


# The analysis below only runs when the file is executed as a script, so the
# functions above can be imported without loading data or opening a figure.
# It is kept at module level (not inside a function) so that the variables
# stay available for interactive inspection after the script has run.
if __name__ == "__main__":
    ### Data retrieval ###
    datafolder = "../data"
    example_file = os.path.join("..", "data", "2024-10-16-ad-invivo-1.nix")
    dataset = rlx.Dataset(example_file)
    sams = dataset.repro_runs("SAM")
    sam = sams[2]

    ## Data for functions
    df = sam.metadata["RePro-Info"]["settings"]["deltaf"][0][0]
    stim = sam.stimuli[1]
    potential, time = stim.trace_data("V-1")
    spikes, _ = stim.trace_data("Spikes-1")
    duration = stim.duration
    dt = stim.trace_info("V-1").sampling_interval

    ### Apply Functions to calculate data ###
    b = binary_spikes(spikes, duration, dt)
    rate = firing_rate(b, box_width=0.05, dt=dt)
    # The spectrum is computed from the binary spike train; ``rate`` is not
    # used further in this script.
    frequency, power = powerspectrum(b, dt)

    ### Important stuff ###
    ## Frequencies
    eodf = stim.metadata[stim.name]["EODf"][0][0]
    stimulus_frequency = eodf + df
    AM = 50  # Hz

    frequencies = [AM, eodf, stimulus_frequency]
    categories = ["AM", "EODf", "Stimulus frequency"]
    num_harmonics = [4, 2, 2]
    colors = ["green", "orange", "red"]

    delta = 2.5
    threshold = 10

    ### Apply functions to make powerspectrum ###
    integral, local = calculate_integral(frequency, power, eodf, delta)
    # valid_integrals returns a (valid, message) pair.
    valid, message = valid_integrals(integral, local, threshold, eodf)

    points, color_mapping, points_categories = prepare_harmonics(
        frequencies, categories, num_harmonics, colors
    )
    print(len(points))

    exceeding = find_exceeding_points(frequency, power, points, delta, threshold)
    print(len(exceeding))

    ## Plot power spectrum and highlight integrals
    fig = plot_highlighted_integrals(
        frequency, power, exceeding, delta, threshold,
        color_mapping, points_categories,
    )
    plt.show()