# File: gpgrewe2024/code/plot_functions.py
# Date: 2024-10-28 17:32:19 +01:00
# Size: 155 lines, 4.8 KiB (Python)

'''Plotting functions for the presentation and lab protocol of the Grewe GP 2024.

Note: importing or running this module also loads an example recording from
../data (relative to the working directory) and plots its timeline.'''
import os
import matplotlib.pyplot as plt
import numpy as np
import rlxnix as rlx
from useful_functions import power_spectrum
import sys
# --- IMPORT DATA ---
# Paths are relative to the current working directory:
# './' is the current directory, '../' one level up, '../../' two levels up.
datafolder = '../data'
# Build the example path from `datafolder` so the data location is defined once.
example_file = os.path.join(datafolder, '2024-10-16-ac-invivo-1.nix')

# --- EXTRACT DATA ---
dataset = rlx.Dataset(example_file)
# All SAM repro runs of the recording; work with the third one (index 2).
sams = dataset.repro_runs('SAM')
sam = sams[2]
# Last stimulus of that run, and the run's stimulus count.
stimulus = sam.stimuli[-1]
stim_count = sam.stimulus_count

# --- PLOTS ---
# One color per stimulus, sampled evenly from the 'prism' colormap.
colors = plt.cm.prism(np.linspace(0, 1, stim_count))
# Plot the timeline of the whole recording.
dataset.plot_timeline()
# plot the beginning of a voltage trace with its spikes marked
def plot_vt_spikes(t, v, spike_t, t_max=0.1):
    """Plot the start of a voltage trace and mark spike times above it.

    Parameters
    ----------
    t : array-like
        Sample times of the voltage trace.
    v : array-like
        Voltage samples, one per entry of ``t``.
    spike_t : array-like
        Spike times, in the same unit as ``t``.
    t_max : float, optional
        Only samples and spikes with time ``< t_max`` are drawn. Defaults
        to 0.1, the window that used to be hard-coded.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    t = np.asarray(t)
    v = np.asarray(v)
    # asarray lets plain lists be passed; a list cannot be compared to a float.
    spike_t = np.asarray(spike_t)

    # Compute the time-window selections once instead of repeating them.
    in_window = t < t_max
    spikes_in_window = spike_t[spike_t < t_max]

    fig = plt.figure(figsize=(5, 2.5))
    ax = fig.add_subplot()
    # Voltage over time within the window.
    ax.plot(t[in_window], v[in_window])
    # Mark each spike at the maximum voltage of the *whole* trace, so all
    # markers sit on one horizontal line above the plotted voltage.
    ax.scatter(spikes_in_window, np.ones_like(spikes_in_window) * np.max(v))
    plt.show()
# plot a spike raster for one SAM, one row per stimulus loop
def scatter_plot(colormap, stimuli_list, stimulus_count, title='Spikes of SAM 3'):
    """Draw a spike raster (event plot) with one row per stimulus loop.

    Parameters
    ----------
    colormap : color or sequence of colors
        Passed to ``Axes.eventplot`` as ``colors``; presumably one color
        per row of ``stimuli_list``.
    stimuli_list : sequence of array-like
        Spike times per loop; each entry becomes one raster row.
    stimulus_count : int
        Number of rows; one y tick is placed per row.
    title : str, optional
        Axes title. Defaults to the previously hard-coded 'Spikes of SAM 3',
        so existing calls are unchanged. Pass a different title when
        plotting another SAM run.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    fig = plt.figure()
    ax = fig.add_subplot()
    ax.eventplot(stimuli_list, colors=colormap)
    ax.set_xlabel('Spike Times [ms]')
    ax.set_ylabel('Loop #')
    ax.set_yticks(range(stimulus_count))
    ax.set_title(title)
    plt.show()
# Power spectrum of the selected stimulus, computed with the
# power_spectrum helper from useful_functions.
(freq, power) = power_spectrum(stimulus)
# plot a power spectrum
def power_spectrum_plot(f, p):
    """Plot a power spectrum in the 0-1000 Hz range.

    Parameters
    ----------
    f : array-like
        Frequencies of the spectrum (the x axis is labelled in Hz).
    p : array-like
        Power value for each entry of ``f``.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    fig = plt.figure()
    ax = fig.add_subplot()
    # Plot the arguments, not the module-level `freq`/`power`, so any
    # spectrum passed in is the one that gets drawn.
    ax.plot(f, p)
    ax.set_xlabel('Frequency [Hz]')
    ax.set_ylabel('Power [1/Hz]')
    ax.set_xlim(0, 1000)
    plt.show()
# Diana's power spectrum plot
# Make this file's own directory (gpgrewe2024/code, where useful_functions
# lives) importable. The directory is derived from __file__ rather than
# hard-coded as an absolute path on one user's machine, so this works in
# any checkout.
functions_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(functions_path)
import useful_functions as u
import matplotlib.ticker as ticker
import matplotlib.patches as mpatches
import matplotlib.cm as cm
def float_formatter(x, _, precision=5):
    """Format a tick value as a fixed-point float string.

    Meant for ``matplotlib.ticker.FuncFormatter``, which calls the function
    as ``func(x, pos)``; the tick position is not used.

    Parameters
    ----------
    x : float
        Tick value to format.
    _ : int or None
        Tick position supplied by FuncFormatter (ignored).
    precision : int, optional
        Number of digits after the decimal point (default 5).

    Returns
    -------
    str
        ``x`` rendered with ``precision`` decimals, e.g. ``'0.00003'``.
    """
    return f'{x:.{precision}f}'
def plot_highlighted_integrals(ax, frequency, power, points, nyquist, true_eodf, color_mapping, points_categories, delta=2.5):
"""
Highlights integrals on the existing axes of the power spectrum for a given dataset.
Parameters
----------
ax : matplotlib.axes.Axes
The axes on which to plot the highlighted integrals.
frequency : np.array
An array of frequencies corresponding to the power values.
power : np.array
An array of power spectral density values.
points : list
A list of harmonic frequencies to check and highlight.
delta : float
Half-width of the range for integration around each point.
color_mapping : dict
A dictionary mapping each category to its color.
points_categories : dict
A mapping of categories to lists of points.
Returns
-------
None
"""
# Define color mappings for specific categories
category_colors = {
"AM": "#ff7f0e",
"Nyquist": "#2ca02c",
"EODf": "#d62728",
"Stimulus": "#9467bd",
"EODf (awake fish)": "#8c564b"
}
# Plot the power spectrum on the provided axes
for point in points:
# Identify the category for the current point
point_category = next((cat for cat, pts in points_categories.items() if point in pts), "Unknown")
# Assign color based on category, or default to grey if unknown
color = color_mapping.get(point_category, 'gray')
# Calculate the integral and check validity
integral, local_mean = u.calculate_integral_2(frequency, power, point)
valid = u.valid_integrals(integral, local_mean, point)
if valid:
# Highlight valid points with a shaded region
ax.axvspan(point - delta, point + delta, color=color, alpha=0.35, label=f'{point_category}')
ax.plot(frequency, power, color="#1f77b4", linewidth=1.5)
# Use the category colors for 'Nyquist' and 'EODf' lines
ax.axvline(nyquist, color=category_colors.get("Nyquist", "#2ca02c"), linestyle="--")
ax.axvline(true_eodf, color=category_colors.get("EODf (awake fish)", "#8c564b"), linestyle="--")
# Set plot limits and labels
ax.set_xlim([0, 1200])
ax.set_ylim([0, 6e-5])
ax.set_xlabel('Frequency (Hz)', fontsize=12)
ax.set_ylabel(r'Power [$\frac{\mathrm{Hz^2}}{\mathrm{Hz}}$]', fontsize=12)
#ax.set_title('Power Spectrum with highlighted Integrals', fontsize=14)
# Apply float formatting to the y-axis
ax.yaxis.set_major_formatter(ticker.FuncFormatter(float_formatter))