# gpgrewe2024/code/test.py
# 2024-10-18 14:47:51 +02:00
#
# 101 lines
# 3.2 KiB
# Python

import glob
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import rlxnix as rlx
from IPython import embed
from scipy.signal import welch
def binary_spikes(spike_times, duration, dt):
    """
    Convert spike times to a binary spike-train representation.

    Parameters
    ----------
    spike_times : array_like
        The spike times, measured from the start of the trial, in the same
        time unit as `duration` and `dt`.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution (bin width).

    Returns
    -------
    binary : np.array
        Array of length ``round(duration / dt)``. It is 1 in every bin that
        contains at least one spike and 0 elsewhere. Spike times whose bin
        index falls outside the array are ignored. These are negative times,
        or times that round to ``duration`` or later.
    """
    # one bin per sample of the trial
    binary = np.zeros(int(np.round(duration / dt)))
    spike_indices = np.asarray(np.round(np.asarray(spike_times) / dt), dtype=int)
    # A spike at (or rounding up to) the trial end would index one past the
    # last bin and raise IndexError. A negative index would silently wrap to
    # the end of the array. Keep only indices that fit.
    valid = (spike_indices >= 0) & (spike_indices < len(binary))
    binary[spike_indices[valid]] = 1
    return binary
def firing_rate(binary_spikes, dt = 0.000025, box_width = 0.01):
    """
    Estimate the firing rate by convolving a binary spike train with a box kernel.

    Parameters
    ----------
    binary_spikes : np.array
        Binary spike train (one entry per time bin of width `dt`).
    dt : float
        The temporal resolution of `binary_spikes`.
    box_width : float
        Width of the box kernel, in the same time unit as `dt`.

    Returns
    -------
    rate : np.array
        The firing rate (spikes per time unit). The kernel is normalised to an
        integral of one, so the rate integrates to the spike count. As with
        ``np.convolve(..., mode='same')``, the length is
        ``max(len(binary_spikes), kernel length)``.
    """
    # Round instead of floor-dividing. With floats, `//` can come out one
    # short (e.g. 1.0 // 0.1 == 9.0), which would shorten the kernel. Keep at
    # least one sample so the normalisation never divides by zero and
    # np.convolve never receives an empty kernel.
    n_box = max(1, int(np.round(box_width / dt)))
    box = np.ones(n_box)
    box /= np.sum(box) * dt  # normalise the box kernel to an integral of one
    rate = np.convolve(binary_spikes, box, mode = 'same')
    return rate
def power_spectrum(rate, dt, nperseg = 2**16, noverlap = None):
    """
    Estimate the power spectral density of a signal with Welch's method.

    Parameters
    ----------
    rate : np.array
        The signal (e.g. a firing rate), sampled every `dt`.
    dt : float
        The sampling interval; the sampling rate passed to welch is ``1/dt``.
    nperseg : int
        Maximum segment length. It is reduced to the signal length when the
        signal is shorter.
    noverlap : int or None
        Overlap between segments. Defaults to half the (possibly reduced)
        segment length, and is capped below the segment length.

    Returns
    -------
    freq : np.array
        The frequencies.
    power : np.array
        The power spectral density at `freq`.
    """
    # scipy shrinks nperseg to the signal length (with a warning) when the
    # signal is shorter. It then rejects a fixed noverlap >= nperseg. Clamp
    # both here so short traces work; long traces keep 2**16 / 2**15 as before.
    nperseg = max(1, min(int(nperseg), np.shape(rate)[-1]))
    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = min(int(noverlap), nperseg - 1)
    freq, power = welch(rate, fs = 1/dt, nperseg = nperseg, noverlap = noverlap)
    return freq, power
# find example data
# NOTE(review): "../data" is resolved relative to the current working
# directory, not relative to this script's location.
datafolder = "../data"
# build the path with pathlib rather than string concatenation
example_file = pathlib.Path(datafolder) / "2024-10-16-ad-invivo-1.nix"
# load dataset (passed as str in case rlxnix expects a plain string path)
dataset = rlx.Dataset(str(example_file))
# find all SAM repro runs
sams = dataset.repro_runs('SAM')
sam = sams[2]  # our example SAM (the third run)
potential, time = sam.trace_data("V-1")  # membrane potential and its time axis
spike_times, _ = sam.trace_data('Spikes-1')  # spike times
# stimulus settings of this run, as stored in its metadata
settings = sam.metadata['RePro-Info']['settings']
df = settings['deltaf'][0][0]  # the 'deltaf' setting
amp = settings['contrast'][0][0] * 100  # the 'contrast' setting, scaled by 100 (presumably to percent)
# Quick look at the start of the recording: membrane potential during the
# first 0.1 s, with the spike times drawn as markers at the potential's
# maximum.
window = 0.1
in_window = time < window
early_spikes = spike_times[spike_times < window]
fig = plt.figure(figsize = (5, 2.5))
ax = fig.add_subplot()
ax.plot(time[in_window], potential[in_window])
ax.scatter(early_spikes, np.ones_like(early_spikes) * np.max(potential))
plt.show()
plt.close()
# get all the stimuli
stims = sam.stimuli
# spike times within the first 0.1 s (100 ms) of every stimulus, collected for
# the eventplot
spikes = []
for stim in stims:
    # spike times of this stimulus presentation
    spike, _ = stim.trace_data('Spikes-1')
    spikes.append(spike[spike < 0.1])
    duration = stim.duration
    ti = stim.trace_info("V-1")
    dt = ti.sampling_interval  # sampling interval of the V-1 recording
    bin_spikes = binary_spikes(spike, duration, dt)  # binarize the spike times
    print(len(bin_spikes))
    rate = firing_rate(bin_spikes, dt = dt)
    print(np.mean(rate))
    # Sample i of `rate` corresponds to bin i of `bin_spikes`, i.e. time
    # i * dt. Build the time axis from that rather than from the V-1 trace's
    # time array, whose length need not equal len(rate); a mismatch would
    # make ax1.plot fail.
    rate_time = np.arange(len(rate)) * dt
    fig, [ax1, ax2] = plt.subplots(1, 2, layout = 'constrained')
    ax1.plot(rate_time, rate)
    ax1.set_ylim(0, 600)
    ax1.set_xlim(0, 0.04)
    freq, power = power_spectrum(rate, dt)
    ax2.plot(freq, power)
    ax2.set_xlim(0, 1000)
# make an eventplot: one row of spike markers per stimulus presentation
fig = plt.figure(figsize = (5, 3), layout = 'constrained')
ax = fig.add_subplot()
ax.eventplot(spikes, linelength = 0.8)
# The spike times are in seconds (they were selected with `spike < 0.1`,
# i.e. the first 100 ms), so the axis unit is seconds, not milliseconds.
ax.set_xlabel('time [s]')
ax.set_ylabel('loop no.')
plt.show()