# gpgrewe2024/code/test.py
#
# 253 lines
# 7.3 KiB
# Python

import glob
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import rlxnix as rlx
from IPython import embed
from scipy.signal import welch
import useful_functions as f
def binary_spikes(spike_times, duration, dt):
    """
    Converts the spike times to a binary representation.

    Parameters
    ----------
    spike_times : np.array
        The spike times.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    binary : np.array
        The binary representation of the spike train, with
        ``round(duration / dt)`` samples. A sample is 1 if a spike time
        rounds to its index, 0 otherwise. Spike times whose rounded index
        falls outside ``[0, len(binary))`` are ignored.
    """
    # one sample per time step of the trial
    binary = np.zeros(int(np.round(duration / dt)))
    # index of the sample closest to each spike time
    spike_indices = np.asarray(np.round(np.asarray(spike_times) / dt), dtype=int)
    # a spike right at the end of the trial rounds to len(binary) and would
    # raise an IndexError; negative indices would silently wrap around to
    # the end of the array, so both are dropped
    in_range = (spike_indices >= 0) & (spike_indices < len(binary))
    binary[spike_indices[in_range]] = 1
    return binary
def firing_rate(binary_spikes, dt = 0.000025, box_width = 0.01):
    """
    Estimates the firing rate by convolving a binary spike train with a
    box kernel.

    Parameters
    ----------
    binary_spikes : np.array
        Binary spike train (1 at samples containing a spike, 0 elsewhere).
    dt : float
        Temporal resolution (sampling interval) of the spike train.
    box_width : float
        Width of the box kernel, in the same time unit as ``dt``.

    Returns
    -------
    rate : np.array
        The firing rate, same length as ``binary_spikes``. The kernel
        integrates to one, so the rate is in spikes per time unit of ``dt``.
    """
    # Number of samples covered by the box. Rounding (instead of floor
    # division) avoids losing a sample to floating-point error, e.g.
    # 0.3 // 0.1 == 2.0. At least one sample, so the kernel is never empty.
    n_box = max(1, int(np.round(box_width / dt)))
    box = np.ones(n_box)
    box /= np.sum(box) * dt  # normalise the box kernel to an integral of one
    rate = np.convolve(binary_spikes, box, mode = 'same')
    return rate
def power_spectrum(rate, dt, nperseg = 2**16, noverlap = 2**15):
    """
    Computes the power spectrum of a signal with Welch's method.

    Parameters
    ----------
    rate : np.array
        The signal, e.g. a firing rate, sampled every ``dt``.
    dt : float
        The sampling interval; the sampling rate passed to welch is 1/dt.
    nperseg : int, optional
        Segment length for Welch's method. It is shortened to the signal
        length for short signals.
    noverlap : int, optional
        Overlap between segments. It is halved to ``nperseg // 2`` when it
        is not smaller than the (possibly shortened) segment length.

    Returns
    -------
    freq : np.array
        The frequencies of the spectrum.
    power : np.array
        The power at each frequency (scipy.signal.welch output).
    """
    n_samples = np.shape(rate)[-1]
    # Welch cannot use segments longer than the signal. scipy would shrink
    # nperseg itself (with a warning) but keep noverlap, which then raises
    # a ValueError once noverlap >= nperseg, i.e. for short signals.
    nperseg = max(1, min(nperseg, n_samples))
    if noverlap >= nperseg:
        noverlap = nperseg // 2
    freq, power = welch(rate, fs = 1/dt, nperseg = nperseg, noverlap = noverlap)
    return freq, power
def extract_stim_data(stimulus):
    '''
    Reads the metadata describing a single stimulus.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus whose metadata is read. Its settings are taken from
        the metadata section named after the stimulus itself.

    Returns
    -------
    amplitude : float
        The 'Contrast' entry of the stimulus metadata (relative signal
        amplitude). NOTE(review): elsewhere in this file the RePro
        'contrast' setting is multiplied by 100 to get percent; confirm
        the unit of this value.
    df : float
        The 'DeltaF' entry: distance of the stimulus to the current EODf.
    eodf : int
        The 'EODf' entry, rounded to the nearest integer.
    stim_freq : int
        The 'Frequency' entry (total stimulus frequency, EODf + df),
        rounded to the nearest integer.
    amp_mod : float
        The amplitude modulation, as computed by AM(eodf, stim_freq).
    ny_freq : float
        The Nyquist frequency, as computed by AM(eodf, stim_freq).
    '''
    # the section key changes with every stimulus, so resolve it once
    settings = stimulus.metadata[stimulus.name]

    def first(key):
        # every metadata entry is read at position [0][0]
        return settings[key][0][0]

    amplitude = first('Contrast')
    df = first('DeltaF')
    eodf = round(first('EODf'))
    stim_freq = round(first('Frequency'))
    # derive amplitude modulation and Nyquist frequency from the two rates
    amp_mod, ny_freq = AM(eodf, stim_freq)
    return amplitude, df, eodf, stim_freq, amp_mod, ny_freq
def AM(EODf, stimulus):
    """
    Calculates the amplitude modulation and Nyquist frequency.

    The stimulus frequency is aliased into the band [0, EODf / 2]. The
    resulting modulation frequency is the distance between the stimulus
    frequency and the nearest integer multiple of the EODf. For example,
    EODf = 800 Hz gives 100 Hz for both a 700 Hz and a 900 Hz stimulus.

    Parameters
    ----------
    EODf : float or int
        The current EODf.
    stimulus : float, int or np.ndarray
        The absolute frequency of the stimulus.

    Returns
    -------
    AM : float or np.ndarray
        The amplitude modulation resulting from the stimulus, in
        [0, nyquist].
    nyquist : float
        The maximum frequency possible to resolve with the EODf.
    """
    nyquist = EODf * 0.5
    # Fold the stimulus frequency into [0, nyquist]: reduce it modulo the
    # EODf, then mirror everything above the Nyquist frequency back down.
    # Reducing modulo the Nyquist frequency instead would be wrong for
    # every negative df (e.g. 300 Hz instead of 100 Hz for 800/700 Hz).
    am = nyquist - np.abs(np.mod(stimulus, EODf) - nyquist)
    return am, nyquist
def remove_poor(files):
    """
    Drops the recordings whose quality label is "poor".

    Each file is opened with rlx.Dataset and its
    metadata["Recording"]["Recording quality"] entry (position [0][0]) is
    compared case-insensitively with "poor".

    Parameters
    ----------
    files : list
        Paths of the recordings to check.

    Returns
    -------
    good_files : list
        The entries of ``files``, in their original order, whose quality
        is anything other than "poor" (e.g. good or fair).
    """
    kept = []
    for path in files:
        # loading a dataset takes some time; each file is opened once
        quality_entry = rlx.Dataset(path).metadata["Recording"]["Recording quality"]
        if str.lower(quality_entry[0][0]) == "poor":
            continue
        kept.append(path)
    return kept
def sam_data(sam):
    '''
    Averages the stimulus metadata over all stimuli of one SAM run.

    Every stimulus in ``sam.stimuli`` is passed through extract_stim_data.
    Each returned quantity is then averaged with np.mean. Without any
    stimuli every mean is NaN (np.mean of an empty sequence).

    Parameters
    ----------
    sam : ReproRun object
        The SAM run whose metadata is extracted.

    Returns
    -------
    sam_amp : float
        Mean amplitude (extract_stim_data's ``amplitude``).
    sam_am : float
        Mean amplitude modulation frequency.
    sam_df : float
        Mean difference of the stimulus to the current fish EODf.
    sam_eodf : float
        Mean EODf.
    sam_nyquist : float
        Mean Nyquist frequency of the EODf.
    sam_stim : float
        Mean stimulus frequency.
    '''
    # one (amplitude, df, eodf, stim_freq, amp_mod, ny_freq) tuple per stimulus
    per_stimulus = [extract_stim_data(stim) for stim in sam.stimuli]
    # regroup into one sequence per quantity
    if per_stimulus:
        (amplitudes, dfs, eodfs, stim_freqs,
         amp_mods, ny_freqs) = zip(*per_stimulus)
    else:
        amplitudes = dfs = eodfs = stim_freqs = amp_mods = ny_freqs = ()
    return (np.mean(amplitudes), np.mean(amp_mods), np.mean(dfs),
            np.mean(eodfs), np.mean(ny_freqs), np.mean(stim_freqs))
# --- find the example data --------------------------------------------------
datafolder = "../../data"
example_file = f"{datafolder}/2024-10-16-ad-invivo-1.nix"
# all .nix recordings in the data folder
data_files = glob.glob(f"{datafolder}/*.nix")

# --- load the example recording and pick one SAM run ------------------------
dataset = rlx.Dataset(example_file)
sams = dataset.repro_runs('SAM')  # all SAM runs of the recording
sam = sams[2]  # our example sam
potential, time = sam.trace_data("V-1")  # membrane potential and its time axis
spike_times, _ = sam.trace_data('Spikes-1')  # spike times
repro_settings = sam.metadata['RePro-Info']['settings']
df = repro_settings['deltaf'][0][0]  # df of the SAM run
amp = repro_settings['contrast'][0][0] * 100  # contrast scaled by 100 (percent)

# --- quick plot of the first 0.1 s ------------------------------------------
fig = plt.figure(figsize = (5, 2.5))
ax = fig.add_subplot()
first_100ms = time < 0.1
ax.plot(time[first_100ms], potential[first_100ms])
# mark the spike times at the height of the potential's overall maximum
early_spikes = spike_times[spike_times < 0.1]
ax.scatter(early_spikes, np.ones_like(early_spikes) * np.max(potential))
plt.show()
plt.close()

# NOTE(review): this uses sam_data from useful_functions (imported as f),
# not the sam_data defined in this file.
sam_amp, sam_am, sam_df, sam_eodf, sam_nyquist, sam_stim = f.sam_data(sam)

# Exploratory per-stimulus analysis, currently disabled:
# # get all the stimuli
# stims = sam.stimuli
# # empty list for the spike times
# spikes = []
# #spikes2 = np.array(range(len(stims)))
# # loop over the stimuli
# for stim in stims:
# # get the spike times
# spike, _ = stim.trace_data('Spikes-1')
# # append the first 100ms to spikes
# spikes.append(spike[spike < 0.1])
# # get stimulus duration
# duration = stim.duration
# ti = stim.trace_info("V-1")
# dt = ti.sampling_interval # get the stimulus interval
# bin_spikes = binary_spikes(spike, duration, dt) #binarize the spike_times
# print(len(bin_spikes))
# pot,tim= stim.trace_data("V-1") #membrane potential
# rate = firing_rate(bin_spikes, dt = dt)
# print(np.mean(rate))
# fig, [ax1, ax2] = plt.subplots(1, 2,layout = 'constrained')
# ax1.plot(tim,rate)
# ax1.set_ylim(0,600)
# ax1.set_xlim(0, 0.04)
# freq, power = power_spectrum(rate, dt)
# ax2.plot(freq,power)
# ax2.set_xlim(0,1000)
# plt.close()
# if stim == stims[-1]:
# # NOTE(review): extract_stim_data returns six values (amplitude, df,
# # eodf, stim_freq, amp_mod, ny_freq); unpacking into four raises a
# # ValueError if this code is re-enabled.
# amplitude, df, eodf, stim_freq = extract_stim_data(stim)
# print(amplitude, df, eodf, stim_freq)
# # make an eventplot
# fig = plt.figure(figsize = (5, 3), layout = 'constrained')
# ax = fig.add_subplot()
# ax.eventplot(spikes, linelength = 0.8)
# # NOTE(review): the spikes are selected with `< 0.1` as "the first 100ms",
# # so they are in seconds, not milliseconds.
# ax.set_xlabel('time [ms]')
# ax.set_ylabel('loop no.')