This commit is contained in:
Diana 2024-10-25 17:09:46 +02:00
commit 136e8a380c
33 changed files with 48761 additions and 547 deletions

View File

@ -1,329 +0,0 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 22 15:21:41 2024
@author: diana
"""
import glob
import os
import rlxnix as rlx
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sig
from scipy.integrate import quad
### FUNCTIONS ###
def binary_spikes(spike_times, duration, dt):
"""Converts the spike times to a binary representation.
Zeros when there is no spike, one when there is.
Parameters
----------
spike_times : np.array
The spike times.
duration : float
The trial duration.
dt : float
The temporal resolution.
Returns
-------
binary : np.array
The binary representation of the spike times.
"""
binary = np.zeros(int(np.round(duration / dt))) #Vektor, der genauso lang ist wie die stim time
spike_indices = np.asarray(np.round(spike_times / dt), dtype=int)
binary[spike_indices] = 1
return binary
def firing_rate(binary_spikes, box_width, dt=0.000025):
"""Calculate the firing rate from binary spike data.
Parameters
----------
binary_spikes : np.array
A binary array representing spike occurrences.
box_width : float
The width of the box filter in seconds.
dt : float, optional
The temporal resolution (time step) in seconds. Default is 0.000025 seconds.
Returns
-------
rate : np.array
An array representing the firing rate at each time step.
"""
box = np.ones(int(box_width // dt))
box /= np.sum(box) * dt # Normalization of box kernel to an integral of 1
rate = np.convolve(binary_spikes, box, mode="same")
return rate
def powerspectrum(rate, dt):
"""Compute the power spectrum of a given firing rate.
This function calculates the power spectrum using the Welch method.
Parameters
----------
rate : np.array
An array of firing rates.
dt : float
The temporal resolution (time step) in seconds.
Returns
-------
frequency : np.array
An array of frequencies corresponding to the power values.
power : np.array
An array of power spectral density values.
"""
frequency, power = sig.welch(rate, fs=1/dt, nperseg=2**15, noverlap=2**14)
return frequency, power
def calculate_integral(frequency, power, point, delta):
"""
Calculate the integral around a single specified point.
Parameters
----------
frequency : np.array
An array of frequencies corresponding to the power values.
power : np.array
An array of power spectral density values.
point : float
The harmonic frequency at which to calculate the integral.
delta : float
Half-width of the range for integration around the point.
Returns
-------
integral : float
The calculated integral around the point.
local_mean : float
The local mean value (adjacent integrals).
"""
indices = (frequency >= point - delta) & (frequency <= point + delta)
integral = np.trapz(power[indices], frequency[indices])
left_indices = (frequency >= point - 5 * delta) & (frequency < point - delta)
right_indices = (frequency > point + delta) & (frequency <= point + 5 * delta)
l_integral = np.trapz(power[left_indices], frequency[left_indices])
r_integral = np.trapz(power[right_indices], frequency[right_indices])
local_mean = np.mean([l_integral, r_integral])
return integral, local_mean
def valid_integrals(integral, local_mean, threshold, point):
"""
Check if the integral exceeds the threshold compared to the local mean and
provide feedback on whether the given point is valid or not.
Parameters
----------
integral : float
The calculated integral around the point.
local_mean : float
The local mean value (adjacent integrals).
threshold : float
Threshold value to compare integrals with local mean.
point : float
The harmonic frequency point being evaluated.
Returns
-------
valid : bool
True if the integral exceeds the local mean by the threshold, otherwise False.
message : str
A message stating whether the point is valid or not.
"""
valid = integral > (local_mean * threshold)
if valid:
message = f"The point {point} is valid, as its integral exceeds the threshold."
else:
message = f"The point {point} is not valid, as its integral does not exceed the threshold."
return valid, message
def prepare_harmonics(frequencies, categories, num_harmonics, colors):
"""
Prepare harmonic frequencies and assign colors based on categories.
Parameters
----------
frequencies : list
Base frequencies to generate harmonics.
categories : list
Corresponding categories for the base frequencies.
num_harmonics : list
Number of harmonics for each base frequency.
colors : list
List of colors corresponding to the categories.
Returns
-------
points : list
A flat list of harmonic frequencies.
color_mapping : dict
A dictionary mapping each category to its corresponding color.
points_categories : dict
A mapping of categories to their harmonic frequencies.
"""
points_categories = {}
for idx, (freq, category) in enumerate(zip(frequencies, categories)):
points_categories[category] = [freq * (i + 1) for i in range(num_harmonics[idx])]
points = [p for harmonics in points_categories.values() for p in harmonics]
color_mapping = {category: colors[idx] for idx, category in enumerate(categories)}
return points, color_mapping, points_categories
def find_exceeding_points(frequency, power, points, delta, threshold):
"""
Find the points where the integral exceeds the local mean by a given threshold.
Parameters
----------
frequency : np.array
An array of frequencies corresponding to the power values.
power : np.array
An array of power spectral density values.
points : list
A list of harmonic frequencies to evaluate.
delta : float
Half-width of the range for integration around the point.
threshold : float
Threshold value to compare integrals with local mean.
Returns
-------
exceeding_points : list
A list of points where the integral exceeds the local mean by the threshold.
"""
exceeding_points = []
for point in points:
# Calculate the integral and local mean for the current point
integral, local_mean = calculate_integral(frequency, power, point, delta)
# Check if the integral exceeds the threshold
valid, message = valid_integrals(integral, local_mean, threshold, point)
if valid:
exceeding_points.append(point)
return exceeding_points
def plot_highlighted_integrals(frequency, power, exceeding_points, delta, threshold, color_mapping, points_categories):
"""
Plot the power spectrum and highlight integrals that exceed the threshold.
Parameters
----------
frequency : np.array
An array of frequencies corresponding to the power values.
power : np.array
An array of power spectral density values.
exceeding_points : list
A list of harmonic frequencies that exceed the threshold.
delta : float
Half-width of the range for integration around each point.
threshold : float
Threshold value to compare integrals with local mean.
color_mapping : dict
A dictionary mapping each category to its color.
points_categories : dict
A mapping of categories to lists of points.
Returns
-------
fig : matplotlib.figure.Figure
The created figure object with highlighted integrals.
"""
fig, ax = plt.subplots()
ax.plot(frequency, power) # Plot power spectrum
for point in exceeding_points:
integral, local_mean = calculate_integral(frequency, power, point, delta)
valid, _ = valid_integrals(integral, local_mean, threshold, point)
if valid:
# Define color based on the category of the point
color = next((c for cat, c in color_mapping.items() if point in points_categories[cat]), 'gray')
# Shade the region around the point where the integral was calculated
ax.axvspan(point - delta, point + delta, color=color, alpha=0.3, label=f'{point:.2f} Hz')
print(f"Integral around {point:.2f} Hz: {integral:.5e}")
# Define left and right boundaries of adjacent regions
left_boundary = frequency[np.where((frequency >= point - 5 * delta) & (frequency < point - delta))[0][0]]
right_boundary = frequency[np.where((frequency > point + delta) & (frequency <= point + 5 * delta))[0][-1]]
# Add vertical dashed lines at the boundaries of the adjacent regions
ax.axvline(x=left_boundary, color="k", linestyle="--")
ax.axvline(x=right_boundary, color="k", linestyle="--")
ax.set_xlim([0, 1200])
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Power')
ax.set_title('Power Spectrum with Highlighted Integrals')
ax.legend()
return fig
### Data retrieval ###
datafolder = "../data"
example_file = os.path.join("..", "data", "2024-10-16-ad-invivo-1.nix")
dataset = rlx.Dataset(example_file)
sams = dataset.repro_runs("SAM")
sam = sams[2]
## Data for functions
df = sam.metadata["RePro-Info"]["settings"]["deltaf"][0][0]
stim = sam.stimuli[1]
potential, time = stim.trace_data("V-1")
spikes, _ = stim.trace_data("Spikes-1")
duration = stim.duration
dt = stim.trace_info("V-1").sampling_interval
### Apply Functions to calculate data ###
b = binary_spikes(spikes, duration, dt)
rate = firing_rate(b, box_width=0.05, dt=dt)
frequency, power = powerspectrum(b, dt)
### Important stuff ###
## Frequencies
eodf = stim.metadata[stim.name]["EODf"][0][0]
stimulus_frequency = eodf + df
AM = 50 # Hz
frequencies = [AM, eodf, stimulus_frequency]
categories = ["AM", "EODf", "Stimulus frequency"]
num_harmonics = [4, 2, 2]
colors = ["green", "orange", "red"]
delta = 2.5
threshold = 10
### Apply functions to make powerspectrum ###
integral, local = calculate_integral(frequency, power, eodf, delta)
valid = valid_integrals(integral, local, threshold, eodf)
points, color, categories = prepare_harmonics(frequencies, categories, num_harmonics, colors)
print(len(points))
exceeding = find_exceeding_points(frequency, power, points, delta, threshold)
print(len(exceeding))
## Plot power spectrum and highlight integrals
fig = plot_highlighted_integrals(frequency, power, points, delta, threshold, color, categories)
plt.show()

View File

@ -0,0 +1,162 @@
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
import rlxnix as rlx
from useful_functions import sam_data, sam_spectrum, calculate_integral, contrast_sorting, remove_poor
from tqdm import tqdm # Import tqdm for the progress bar
def load_files(file_path_pattern):
"""Load all files matching the pattern and remove poor quality files."""
all_files = glob.glob(file_path_pattern)
good_files = remove_poor(all_files)
return good_files
def process_sam_data(sam):
"""Process data for a single SAM and return necessary frequencies and powers."""
_, _, _, _, eodf, nyquist, stim_freq = sam_data(sam)
# Skip if stim_freq is NaN
if np.isnan(stim_freq):
return None
# Get power spectrum and frequency index for 1/2 EODf
freq, power = sam_spectrum(sam)
nyquist_idx = np.searchsorted(freq, nyquist)
# Get frequencies and powers before 1/2 EODf
freqs_before_half_eodf = freq[:nyquist_idx]
powers_before_half_eodf = power[:nyquist_idx]
# Get peak frequency and power
am_peak_f = freqs_before_half_eodf[np.argmax(powers_before_half_eodf)]
_, _, peak_power = calculate_integral(freq, power, am_peak_f)
return stim_freq, am_peak_f, peak_power
def plot_contrast_data(contrast_dict, file_tag, axs1, axs2):
"""Loop over all contrasts and plot AM Frequency and AM Power."""
for idx, contrast in enumerate(contrast_dict): # contrasts = keys of dict
ax1 = axs1[idx] # First figure (AM Frequency vs Stimulus Frequency)
ax2 = axs2[idx] # Second figure (AM Power vs Stimulus Frequency)
contrast_sams = contrast_dict[contrast]
# store all stim_freq and peak_power/nyquist_freq for this contrast
stim_freqs = []
am_freqs = []
peak_powers = []
# loop over all sams of one contrast
for sam in contrast_sams:
processed_data = process_sam_data(sam)
if processed_data is None:
continue
stim_freq, am_peak_f, peak_power = processed_data
stim_freqs.append(stim_freq)
am_freqs.append(am_peak_f)
peak_powers.append(peak_power)
# Plot in the first figure (AM Frequency vs Stimulus Frequency)
ax1.plot(stim_freqs, am_freqs, '-', label=file_tag)
ax1.set_title(f'Contrast {contrast}%')
ax1.grid(True)
ax1.legend(loc='upper right')
# Plot in the second figure (AM Power vs Stimulus Frequency)
ax2.plot(stim_freqs, peak_powers, '-', label=file_tag)
ax2.set_title(f'Contrast {contrast}%')
ax2.grid(True)
ax2.legend(loc='upper right')
def process_file(file, axs1, axs2):
"""Process a single file: extract SAMs and plot data for each contrast."""
dataset = rlx.Dataset(file)
sam_list = dataset.repro_runs('SAM')
# Extract the file tag (first part of the filename) for the legend
file_tag = '-'.join(os.path.basename(file).split('-')[0:4])
# Sort SAMs by contrast
contrast_dict = contrast_sorting(sam_list)
# Plot the data for each contrast
plot_contrast_data(contrast_dict, file_tag, axs1, axs2)
def loop_over_files(files, axs1, axs2):
"""Loop over all good files, process each file, and plot the data."""
for file in tqdm(files, desc="Processing files"):
process_file(file, axs1, axs2)
def main():
# Load files
file_path_pattern = '../data/16-10-24/*.nix'
good_files = load_files(file_path_pattern)
# Initialize figures
fig1, axs1 = plt.subplots(3, 1, constrained_layout=True, sharex=True) # For AM Frequency vs Stimulus Frequency
fig2, axs2 = plt.subplots(3, 1, constrained_layout=True, sharex=True) # For AM Power vs Stimulus Frequency
# Loop over files and process data
loop_over_files(good_files, axs1, axs2)
# Add labels to figures
fig1.supxlabel('Stimulus Frequency (df + EODf) [Hz]')
fig1.supylabel('AM Frequency [Hz]')
fig2.supxlabel('Stimulus Frequency (df + EODf) [Hz]')
fig2.supylabel('AM Power')
# Show plots
plt.show()
# Run the main function
if __name__ == '__main__':
main()
'''
Function that gets eodf and 1/2 eodf per contrast:
def calculate_mean_eodf(sams):
"""
Calculate mean EODf and mean 1/2 EODf for the given SAM data.
Args:
sams (list): List of SAM objects.
Returns:
mean_eodf (float): Mean EODf across all SAMs.
mean_half_eodf (float): Mean 1/2 EODf (Nyquist frequency) across all SAMs.
"""
eodfs = []
nyquists = []
for sam in sams:
_, _, _, _, eodf, nyquist, _ = sam_data(sam)
# Add to list only if valid
if not np.isnan(eodf):
eodfs.append(eodf)
nyquists.append(nyquist)
# Calculate mean EODf and 1/2 EODf
mean_eodf = np.mean(eodfs)
mean_half_eodf = np.mean(nyquists)
return mean_eodf, mean_half_eodf
'''
# TODO:
# display eodf values in plot for one cell, one intensity - integrate function for this
# lowpass with gaussian kernel for amplitude plot(0.5 sigma in frequency spectrum (dont filter too narrowly))
# fix legends (only for the cells that are being displayed)
# save figures
# plot remaining 3 plots, make 1 function for every option and put that in main code
# push files to git

View File

@ -0,0 +1,96 @@
import matplotlib.pyplot as plt
import numpy as np
import os
import rlxnix as rlx
from useful_functions import sam_data, sam_spectrum, calculate_integral, contrast_sorting
# close all open plots
plt.close('all')
def plot_am_vs_frequency_single_intensity(file, contrast=20):
"""
Plots AM Power vs Stimulus Frequency and Nyquist Frequency vs Stimulus Frequency for
one intensity and one cell (file).
Parameters:
file (str): Path to the file (one cell).
intensity (int): The intensity level (contrast) to filter by.
"""
# Load the dataset for the given file
dataset = rlx.Dataset(file)
# Get SAMs for the whole recording
sam_list = dataset.repro_runs('SAM')
# Extract the file tag (first part of the filename) for the legend
file_tag = '-'.join(os.path.basename(file).split('-')[0:4])
# Sort SAMs by contrast
contrast_dict = contrast_sorting(sam_list)
# Get the SAMs for 20% contrast
sams = contrast_dict[contrast]
# Create a figure with 1 row and 2 columns
fig, axs = plt.subplots(2, 1, layout='constrained')
# Store all stim_freq, peak_power, and am_freq for the given contrast
stim_freqs = []
peak_powers = []
am_freqs = []
# Loop over all SAMs of the specified contrast
for sam in sams:
# Get stim_freq for each SAM
_, _, _, _, eodf, nyquist, stim_freq = sam_data(sam)
# Skip over empty SAMs
if np.isnan(stim_freq):
continue
# Get power spectrum from one SAM
freq, power = sam_spectrum(sam)
# get index of 1/2 eodf frequency
nyquist_idx = np.searchsorted(freq, nyquist)
# get frequencies until 1/2 eodf and powers for those frequencies
freqs_before_half_eodf = freq[:nyquist_idx]
powers_before_half_eodf = power[:nyquist_idx]
# Get the frequency of the highest peak before 1/2 EODf
am_peak_f = freqs_before_half_eodf[np.argmax(powers_before_half_eodf)]
# Get the power of the highest peak before 1/2 EODf
_, _, peak_power = calculate_integral(freq, power, am_peak_f)
# Collect data for plotting
stim_freqs.append(stim_freq)
peak_powers.append(peak_power)
am_freqs.append(am_peak_f)
# Plot AM Power vs Stimulus Frequency (first column)
ax = axs[0]
ax.plot(stim_freqs, am_freqs, '-')
ax.set_ylabel('AM Frequency [Hz]')
ax.grid(True)
# Plot AM Frequency vs Stimulus Frequency (second column)
ax = axs[1]
ax.plot(stim_freqs, peak_powers, '-')
ax.set_ylabel('AM Power')
ax.grid(True)
# Figure settings
fig.suptitle(f"Cell: {file_tag}, Contrast: {contrast}%")
fig.supxlabel("Stimulus Frequency (df + EODf) [Hz]")
plt.show()
# Call function
file = '../data/16-10-24/2024-10-16-ad-invivo-1.nix'
# Call the function to plot the data for one intensity and one cell
plot_am_vs_frequency_single_intensity(file)

View File

@ -1,154 +0,0 @@
import rlxnix as rlx
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.signal import welch
# close all currently open figures
plt.close('all')
'''FUNCTIONS'''
def plot_vt_spikes(t, v, spike_t):
fig = plt.figure(figsize=(5, 2.5))
# alternative to ax = axs[0]
ax = fig.add_subplot()
# plot vt diagram
ax.plot(t[t<0.1], v[t<0.1])
# plot spikes into vt diagram, at max V
ax.scatter(spike_t[spike_t<0.1], np.ones_like(spike_t[spike_t<0.1]) * np.max(v))
plt.show()
def scatter_plot(colormap, stimuli_list, stimulus_count):
'''plot scatter plot for one sam with all 3 stims'''
fig = plt.figure()
ax = fig.add_subplot()
ax.eventplot(stimuli_list, colors=colormap)
ax.set_xlabel('Spike Times [ms]')
ax.set_ylabel('Loop #')
ax.set_yticks(range(stimulus_count))
ax.set_title('Spikes of SAM 3')
plt.show()
# create binary array with ones for spike times
def binary_spikes(spike_times, duration , dt):
'''Converts spike times to binary representation
Params
------
spike_times: np.array
spike times
duration: float
trial duration
dt: float
temporal resolution
Returns
--------
binary: np.array
The binary representation of the spike times
'''
binary = np.zeros(int(duration//dt)) # // is truncated division, returns number w/o decimals, same as np.round
spike_indices = np.asarray(np.round(spike_times//dt), dtype=int)
binary[spike_indices] = 1
return binary
# function to plot psth
def firing_rates(binary_spikes, box_width=0.01, dt=0.000025):
box = np.ones(int(box_width // dt))
box /= np.sum(box * dt) # normalize box kernel w interal of 1
rate = np.convolve(binary_spikes, box, mode='same')
return rate
def power_spectrum(rate, dt):
f, p = welch(rate, fs = 1./dt, nperseg=2**16, noverlap=2**15)
# algorithm makes rounding mistakes, we want to calc many spectra and take mean of those
# nperseg: length of segments in # datapoints
# noverlap: # datapoints that overlap in segments
return f, p
def power_spectrum_plot(f, p):
# plot power spectrum
fig = plt.figure()
ax = fig.add_subplot()
ax.plot(freq, power)
ax.set_xlabel('Frequency [Hz]')
ax.set_ylabel('Power [1/Hz]')
ax.set_xlim(0, 1000)
plt.show()
'''IMPORT DATA'''
datafolder = '../data' #./ wo ich gerade bin; ../ eine ebene höher; ../../ zwei ebenen höher
example_file = os.path.join('..', 'data', '2024-10-16-ac-invivo-1.nix')
'''EXTRACT DATA'''
dataset = rlx.Dataset(example_file)
# get sams
sams = dataset.repro_runs('SAM')
sam = sams[2]
# get potetial over time (vt curve)
potential, time = sam.trace_data('V-1')
# get spike times
spike_times, _ = sam.trace_data('Spikes-1')
# get stim count
stim_count = sam.stimulus_count
# extract spike times of all 3 loops of current sam
stimuli = []
for i in range(stim_count):
# get stim i from sam
stim = sam.stimuli[i]
potential_stim, time_stim = stim.trace_data('V-1')
# get spike_times
spike_times_stim, _ = stim.trace_data('Spikes-1')
stimuli.append(spike_times_stim)
eodf = stim.metadata[stim.name]['EODF'][0][0]
df = stim.metadata['RePro-Info']['settings']['deltaf'][0][0]
stimulus_freq = df + eodf
'''PLOT'''
# create colormap
colors = plt.cm.prism(np.linspace(0, 1, stim_count))
# timeline of whole rec
dataset.plot_timeline()
# voltage and spikes of current sam
plot_vt_spikes(time, potential, spike_times)
# spike times of all loops
scatter_plot(colors, stimuli, stim_count)
'''POWER SPECTRUM'''
# define variables for binary spikes function
spikes, _ = stim.trace_data('Spikes-1')
ti = stim.trace_info('V-1')
dt = ti.sampling_interval
duration = stim.duration
### spectrum
# vector with binary values for wholes length of stim
binary = binary_spikes(spikes, duration, dt)
# calculate firing rate
rate = firing_rates(binary, 0.01, dt) # box width of 10 ms
# plot psth or whatever
# plt.plot(time_stim, rate)
# plt.show()
freq, power = power_spectrum(binary, dt)
power_spectrum_plot(freq, power)
### TODO:
# then loop over sams/dfs, all stims, intensities
# when does stim start in eodf/ at which phase and how does that influence our signal --> alignment problem: egal wenn wir spectren haben
# we want to see peaks at phase locking to own and stim frequency, and at amp modulation frequency

View File

@ -1,26 +1,45 @@
import glob import glob
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import os
import rlxnix as rlx import rlxnix as rlx
import scipy as sp
import time
import useful_functions as f import useful_functions as f
from matplotlib.lines import Line2D
from tqdm import tqdm
# tatsächliche Power der peaks benutzen # plot the tuning curves for all cells y/n
single_plots = True
# all files we want to use # all files we want to use
files = glob.glob("../data/2024-10-*.nix") files = glob.glob("../data/2024-10-*.nix")
#EODf file for either day
eodf_file_w = glob.glob('../data/EOD_only/*-16*.nix')[0]
eodf_file_m = glob.glob('../data/EOD_only/*-21*.nix')[0]
# get only the good and fair filepaths # get only the good and fair filepaths
new_files = f.remove_poor(files) new_files = f.remove_poor(files)
#get the filenames as labels for plotting
labels = [os.path.splitext(os.path.basename(file))[0] for file in new_files]
# loop over all the good files # dict for all the different contrasts
for file in new_files: contrast_files = {20 : {'power' :[], 'freq' : []},
10 : {'power' :[], 'freq' : []},
5 : {'power' :[], 'freq' : []}}
norm_contrast_files = {20 : {'power' :[], 'freq' : []},
10 : {'power' :[], 'freq' : []},
5 : {'power' :[], 'freq' : []}}
# loop over all the good files
for u, file in tqdm(enumerate(new_files), total = len(new_files)):
#use correct eodf file
if "-16" in file:
orig_eodf = f.true_eodf(eodf_file_w)
else:
orig_eodf = f.true_eodf(eodf_file_m)
#define lists
contrast_frequencies = [] contrast_frequencies = []
contrast_powers = [] contrast_powers = []
# load a file # load a file
@ -30,78 +49,145 @@ for file in new_files:
# get arrays for frequnecies and power # get arrays for frequnecies and power
stim_frequencies = np.zeros(len(sams)) stim_frequencies = np.zeros(len(sams))
peak_powers = np.zeros_like(stim_frequencies) peak_powers = np.zeros_like(stim_frequencies)
# loop over all sams contrast_sams = f.contrast_sorting(sams)
# dictionary for the contrasts
contrast_sams = {20 : [], eodfs = []
10 : [],
5 : []}
# loop over all sams
for sam in sams:
# get the contrast
avg_dur, contrast, _, _, _, _, _ = f.sam_data(sam)
# check for valid trails
if np.isnan(contrast):
continue
elif sam.stimulus_count < 3: #aborted trials
continue
elif avg_dur < 1.7:
continue
else:
contrast = int(contrast) # get integer of contrast
# sort them accordingly
if contrast == 20:
contrast_sams[20].append(sam)
if contrast == 10:
contrast_sams[10].append(sam)
if contrast == 5:
contrast_sams[5].append(sam)
else:
continue
# loop over the contrasts # loop over the contrasts
for key in contrast_sams: for key in contrast_sams:
stim_frequencies = np.zeros(len(contrast_sams[key])) stim_frequencies = np.zeros(len(contrast_sams[key]))
norm_stim_frequencies = np.zeros_like(stim_frequencies)
peak_powers = np.zeros_like(stim_frequencies) peak_powers = np.zeros_like(stim_frequencies)
for i, sam in enumerate(contrast_sams[key]): for i, sam in enumerate(contrast_sams[key]):
# get stimulus frequency and stimuli # get stimulus frequency and stimuli
_, _, _, _, _, _, stim_frequency = f.sam_data(sam) _, _, _, _, eodf, _, stim_frequency = f.sam_data(sam)
stimuli = sam.stimuli sam_frequency, sam_power = f.sam_spectrum(sam)
# lists for the power spectra
frequencies = []
powers = []
# loop over the stimuli
for stimulus in stimuli:
# get the powerspectrum for each stimuli
frequency, power = f.power_spectrum(stimulus)
# append the power spectrum data
frequencies.append(frequency)
powers.append(power)
#average over the stimuli
sam_frequency = np.mean(frequencies, axis = 0)
sam_power = np.mean(powers, axis = 0)
# detect peaks # detect peaks
integral, surroundings, peak_power = f.calculate_integral(sam_frequency, _, _, peak_powers[i] = f.calculate_integral(sam_frequency,
sam_power, stim_frequency) sam_power, stim_frequency)
peak_powers[i] = peak_power
# add the current stimulus frequency # add the current stimulus frequency
stim_frequencies[i] = stim_frequency stim_frequencies[i] = stim_frequency
norm_stim_frequencies[i] = stim_frequency - orig_eodf
eodfs.append(eodf)
# replae zeros with NaN # replae zeros with NaN
peak_powers = np.where(peak_powers == 0, np.nan, peak_powers) peak_powers = np.where(peak_powers == 0, np.nan, peak_powers)
contrast_frequencies.append(stim_frequencies) contrast_frequencies.append(stim_frequencies)
contrast_powers.append(peak_powers) contrast_powers.append(peak_powers)
if key == 20:
fig, ax = plt.subplots(layout = 'constrained') contrast_files[20]['freq'].append(stim_frequencies)
ax.plot(contrast_frequencies[0], contrast_powers[0]) contrast_files[20]['power'].append(peak_powers)
ax.plot(contrast_frequencies[1], contrast_powers[1]) norm_contrast_files[20]['freq'].append(norm_stim_frequencies)
ax.plot(contrast_frequencies[2], contrast_powers[2]) norm_contrast_files[20]['power'].append(peak_powers)
ax.set_xlabel('stimulus frequency [Hz]') elif key == 10:
ax.set_ylabel(r' power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]') contrast_files[10]['freq'].append(stim_frequencies)
ax.set_title(f"{file}") contrast_files[10]['power'].append(peak_powers)
norm_contrast_files[10]['freq'].append(norm_stim_frequencies)
norm_contrast_files[10]['power'].append(peak_powers)
else:
contrast_files[5]['freq'].append(stim_frequencies)
contrast_files[5]['power'].append(peak_powers)
norm_contrast_files[5]['freq'].append(norm_stim_frequencies)
norm_contrast_files[5]['power'].append(peak_powers)
curr_eodf = np.mean(eodfs)
if single_plots == True:
# one cell with all contrasts in one subplot
fig, ax = plt.subplots()
ax.plot(contrast_frequencies[0], contrast_powers[0])
ax.plot(contrast_frequencies[1], contrast_powers[1])
if contrast_frequencies and contrast_frequencies[-1].size == 0:
if contrast_frequencies and contrast_frequencies[-2].size == 0:
ax.set_xlim(0,2000)
else:
ax.set_xlim(0,np.max(contrast_frequencies[-2]))
else:
ax.plot(contrast_frequencies[2], contrast_powers[2])
ax.set_xlim(0,np.max(contrast_frequencies[-1]))
ax.axvline(orig_eodf, color = 'black',linestyle = 'dashed', alpha = 0.8)
ax.axvline(2*curr_eodf, color = 'black', linestyle = 'dotted', alpha = 0.8)
ax.set_ylim(0, 0.00014)
ax.set_xlabel('stimulus frequency [Hz]')
ax.set_ylabel(r' power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]')
ax.set_title(f"{file}")
fig.legend(labels = ['20 % contrast', '10 % contrast','5 % contrast','EODf of awake fish', '1st harmonic of current EODf' ], loc = 'lower center', ncol = 3)
plt.tight_layout(rect=[0, 0.06, 1, 1])
plt.savefig(f'../results/tuning_curve{labels[u]}.svg')
#one cell with the contrasts in different subplots
fig, axs = plt.subplots(1, 3, figsize = [10,6], sharex = True, sharey = True)
for p, key in enumerate(contrast_files):
ax = axs[p]
ax.plot(contrast_files[key]['freq'][-1],contrast_files[key]['power'][-1])
ax.set_title(f"{key}")
ax.axvline(orig_eodf, color = 'black',linestyle = 'dashed')
ax.axvline(2*curr_eodf, color = 'darkblue', linestyle = 'dotted', alpha = 0.8)
if p == 0:
ax.set_ylabel(r'power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]', fontsize=12)
fig.supxlabel('stimulus frequency [Hz]', fontsize=12)
fig.suptitle(f'{labels[u]}')
fig.legend(labels = ['power of stimulus peak', 'EODf of awake fish','1st harmonic of current EODf'], loc = 'lower center', bbox_to_anchor=(0.5, 0.05), ncol = 3)
plt.tight_layout(rect=[0, 0.06, 1, 1])
plt.savefig(f'../results/contrast_tuning{labels[u]}.svg')
cmap = plt.get_cmap('viridis')
colors = cmap(np.linspace(0, 1, len(new_files)))
plt.close('all')
if len(new_files) < 10:
lines = []
labels_legend = []
fig, axs = plt.subplots(1, 3, figsize = [10,6], sharex = True, sharey = True)
for p, key in enumerate(contrast_files):
ax = axs[p]
for i in range(len(contrast_files[key]['power'])):
line, = ax.plot(contrast_files[key]['freq'][i],contrast_files[key]['power'][i], label = labels[i], color = colors[i])
ax.set_title(f"{key}")
ax.axvline(orig_eodf, color = 'black',linestyle = 'dashed')
if p == 0:
lines.append(line)
labels_legend.append(labels[i])
fig.supxlabel('stimulus frequency [Hz]', fontsize=12)
fig.supylabel(r'power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]', fontsize=12)
# Create a single legend beneath the plots with 3 columns
lines.append(Line2D([0], [0], color='black', linestyle='--')) # Custom line for the legend
labels_legend.append("Awake fish EODf") # Custom label
fig.legend(lines, labels_legend, loc='upper center', ncol=3, fontsize=10)
plt.tight_layout(rect=[0, 0, 1, 0.85]) # Adjust layout to make space for the legend
if "-16" in new_files[-1]:
plt.savefig('../results/tuning_curves_10_16.svg')
elif "-21" in new_files[0]:
plt.savefig('../results/tuning_curves_10_21.svg')
else:
for o in range(2):
lines = []
labels_legend = []
fig, axs = plt.subplots(1, 3, figsize = [10,6], sharex = True, sharey = True)
for p, key in enumerate(norm_contrast_files):
ax = axs[p]
for i in range(len(norm_contrast_files[key]['power'])):
line, = ax.plot(norm_contrast_files[key]['freq'][i],norm_contrast_files[key]['power'][i], label = labels[i], color = colors[i])
ax.set_title(f"{key}")
ax.axvline(0, color = 'black',linestyle = 'dashed')
if p == 0:
lines.append(line)
labels_legend.append(labels[i])
fig.supylabel(r'power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]', fontsize=12)
# Create a single legend beneath the plots with 3 columns
lines.append(Line2D([0], [0], color='black', linestyle='--')) # Custom line for the legend
labels_legend.append("Awake fish EODf") # Custom label
fig.legend(lines, labels_legend, loc='upper center', ncol=3, fontsize=10)
plt.tight_layout(rect=[0, 0, 1, 0.82]) # Adjust layout to make space for the legend
if o == 0:
ax.set_xlim(-600, 2100)
fig.supxlabel('stimulus frequency [Hz]', fontsize=12)
plt.savefig('../results/tuning_curves_norm.svg')
else:
ax.set_xlim(-600, 600)
fig.supxlabel(' relative stimulus frequency [Hz]', fontsize=12)
plt.savefig('../results/tuning_curves_norm_zoom.svg')
#plt.close('all')

View File

@ -275,7 +275,7 @@ def extract_stim_data(stimulus):
stim_dur = stimulus.duration stim_dur = stimulus.duration
# calculates the amplitude modulation # calculates the amplitude modulation
amp_mod, ny_freq = AM(eodf, stim_freq) amp_mod, ny_freq = AM(eodf, stim_freq)
return amplitude, df, eodf, stim_freq,stim_dur, amp_mod, ny_freq return amplitude, df, eodf, stim_freq, stim_dur, amp_mod, ny_freq
def find_exceeding_points(frequency, power, points, delta, threshold): def find_exceeding_points(frequency, power, points, delta, threshold):
""" """

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 42 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 43 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 41 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 46 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 47 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 41 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 41 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 48 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 50 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 48 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 48 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 52 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 38 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 39 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 37 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 42 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 43 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 38 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 38 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 42 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 44 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 38 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 40 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 44 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 71 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 99 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 99 KiB