# fishbook/fishbook/frontend/util.py
#
# 227 lines
# 7.3 KiB
# Python

import numpy as np
import os
import subprocess
from scipy.optimize import curve_fit
from IPython import embed
def spike_times_to_rate(spike_times, time_axis, kernel_width=0.005):
    """Convert spike times to a firing rate by means of Gaussian kernel convolution.

    The spikes are binned onto ``time_axis``. Spikes outside the axis are
    dropped, and several spikes falling into the same bin are all counted. The
    resulting spike train is then convolved with the kernel returned by
    ``gaussian_kernel``.

    Args:
        spike_times (array_like): the spike times in seconds.
        time_axis (np.ndarray): the time axis with a proper resolution and extent (in seconds).
            The bin width is taken as the mean sample spacing, so the axis is
            assumed to be evenly sampled.
        kernel_width (float, optional): the standard deviation of the Gaussian kernel in seconds.
            Defaults to 0.005.

    Returns:
        np.ndarray: the firing rate in Hz, with the same length as ``time_axis``.
    """
    time_axis = np.asarray(time_axis)
    spike_times = np.asarray(spike_times, dtype=float)
    dt = np.mean(np.diff(time_axis))
    spike_train = np.zeros(time_axis.shape)
    # floor (not truncation toward zero) so that spikes slightly before
    # time_axis[0] get a negative index and are discarded instead of landing in bin 0
    indices = np.floor((spike_times - time_axis[0]) / dt).astype(int)
    indices = indices[(indices >= 0) & (indices < len(spike_train))]
    # np.add.at accumulates repeated indices; plain fancy assignment would
    # count several spikes in the same bin only once
    np.add.at(spike_train, indices, 1)
    kernel = gaussian_kernel(kernel_width, dt)
    # Crop a full convolution to the centre. For len(kernel) <= len(spike_train)
    # this equals mode='same'; unlike mode='same' it also keeps the output
    # length equal to len(time_axis) when the kernel is the longer array.
    offset = (len(kernel) - 1) // 2
    rate = np.convolve(spike_train, kernel, mode='full')[offset:offset + len(spike_train)]
    return rate
def safe_get_val(dictionary: dict, key, default=None):
    """Look up ``key`` in ``dictionary``, falling back to ``default``.

    Args:
        dictionary (dict): the mapping to read from.
        key: the key to look up.
        default (optional): value returned when ``key`` is absent. Defaults to None.

    Returns:
        ``dictionary[key]`` if the key is present, otherwise ``default``.
    """
    if key not in dictionary.keys():
        return default
    return dictionary[key]
def results_check(results, id, text="ID"):
    """Make sure that a query produced exactly one result.

    Args:
        results (Sized): the query results.
        id: the queried identifier; only used in the error message.
        text (str, optional): label for the identifier in the error message. Defaults to "ID".

    Raises:
        ValueError: if ``results`` is empty or holds more than one entry.
    """
    n_found = len(results)
    if n_found == 1:
        return
    if n_found == 0:
        raise ValueError("%s %s does not exist!" % (text, id))
    raise ValueError("%s %s is not unique!" % (text, id))
def zero_crossings(x, t, interpolate=False):
    """Get the times at which the signal ``x`` crosses zero from below.

    A crossing is registered at every sample index ``i >= 1`` with
    ``x[i] >= 0`` and ``x[i - 1] < 0``. Only upward (negative-to-non-negative)
    crossings are detected.

    Args:
        x (array_like): the signal samples.
        t (array_like): the time of each sample, same length as ``x``.
        interpolate (bool, optional): if True, the crossing time is estimated by
            linear interpolation between the two samples that bracket the
            crossing. Otherwise the time of the first non-negative sample is
            returned. Defaults to False.

    Returns:
        np.ndarray: the crossing times (empty if there are none).
    """
    x = np.asarray(x)
    t = np.asarray(t)
    # comparing each sample with its predecessor via slicing never treats
    # index 0 as a crossing and works for empty input
    idx = np.where((x[1:] >= 0) & (x[:-1] < 0))[0] + 1
    if not interpolate:
        return t[idx]
    x_prev, x_cur = x[idx - 1], x[idx]
    t_prev, t_cur = t[idx - 1], t[idx]
    # x_prev < 0 <= x_cur, so the denominator is strictly positive; the local
    # sample spacing is used so non-uniform time axes are handled as well
    return t_cur - (t_cur - t_prev) * x_cur / (x_cur - x_prev)
def unzip_if_needed(dataset, tracename='trace-1.raw'):
    """Make sure an uncompressed trace file is present in a dataset folder.

    Nothing happens if ``<dataset>/<tracename>`` already exists. Otherwise, if a
    ``<dataset>/<tracename>.gz`` file exists, the external ``gunzip`` command is
    run on it. If neither file exists, the function returns silently.

    Args:
        dataset (str): path of the dataset directory.
        tracename (str, optional): name of the trace file. Defaults to 'trace-1.raw'.

    Raises:
        subprocess.CalledProcessError: if ``gunzip`` exits with a non-zero status.
    """
    trace_path = os.path.join(dataset, tracename)
    already_unpacked = os.path.exists(trace_path)
    if already_unpacked or not os.path.exists(trace_path + '.gz'):
        return
    print("\tunzip: %s" % tracename)
    subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])
def gaussian_kernel(sigma, dt):
    """Create a sampled Gaussian kernel whose integral is (approximately) one.

    The kernel is sampled with step ``dt`` on a grid that is symmetric around
    zero and reaches at least +-4 standard deviations. The grid always has an
    odd number of samples with the peak in the middle. Convolving with it
    (e.g. ``np.convolve(..., mode='same')``) therefore does not shift the
    signal in time.

    Args:
        sigma (float): standard deviation of the Gaussian (same unit as ``dt``).
        dt (float): sampling interval.

    Returns:
        np.ndarray: the kernel values, normalized with the analytic Gaussian
        normalization. ``sum(kernel) * dt`` is therefore close to, but not
        exactly, 1, because the tails beyond the grid are cut off.

    Raises:
        ValueError: if ``sigma`` or ``dt`` is not positive.
    """
    if sigma <= 0 or dt <= 0:
        raise ValueError("gaussian_kernel: sigma and dt must be positive (got sigma=%s, dt=%s)" % (sigma, dt))
    # integer sample offsets keep the grid exactly symmetric; np.arange on the
    # float range (-4*sigma, 4*sigma) would omit +4*sigma and may not hit zero
    half_width = int(np.ceil(4. * sigma / dt))
    x = np.arange(-half_width, half_width + 1) * dt
    y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return y
class BoltzmannFit:
    """
    Class representing a fit of a Boltzmann function to some data.

    The y values are averaged per unique x value. The Boltzmann
    ``y_max / (1 + exp(-slope * (x - inflection)))`` is then fitted to these
    averages with ``scipy.optimize.curve_fit``.
    """
    def __init__(self, xvalues: np.ndarray, yvalues: np.ndarray, initial_params=None):
        """
        Constructor. Takes the x and the y data and tries to fit a Boltzmann to it.

        :param xvalues: numpy array (or sequence) of x (e.g. contrast) values
        :param yvalues: numpy array (or sequence) of y (e.g. firing rate) values, same length as xvalues
        :param initial_params: initial parameters [y_max, slope, inflection], default None to autogenerate
        :raises ValueError: if xvalues and yvalues differ in length.
        :raises RuntimeError: propagated from curve_fit if the fit does not converge.
        """
        if len(xvalues) != len(yvalues):
            raise ValueError("BoltzmannFit: xvalues and yvalues must have the same length (%i vs %i)!"
                             % (len(xvalues), len(yvalues)))
        # np.asarray so that plain lists also work with the boolean masking in __do_fit
        self.__xvals = np.asarray(xvalues)
        self.__yvals = np.asarray(yvalues)
        self.__fit_params = None
        self.__initial_params = initial_params
        self.__x_sorted = np.unique(self.__xvals)
        self.__y_avg = None
        self.__y_err = None
        self.__do_fit()

    @staticmethod
    def boltzmann(x, y_max, slope, inflection):
        r"""
        The underlying Boltzmann function.

        .. math::
            f(x) = \frac{y_{max}}{1 + e^{-k (x - x_0)}}

        :param x: The x values.
        :param y_max: The maximum value.
        :param slope: The slope parameter k
        :param inflection: the position of the inflection point x_0.
        :return: the y values.
        """
        y = y_max / (1 + np.exp(-slope * (x - inflection)))
        return y

    def __do_fit(self):
        """Average y per unique x value and fit the Boltzmann to the averages."""
        self.__y_avg = np.zeros(self.__x_sorted.shape)
        self.__y_err = np.zeros(self.__x_sorted.shape)
        for i, c in enumerate(self.__x_sorted):
            selected = self.__yvals[self.__xvals == c]
            self.__y_avg[i] = np.mean(selected)
            self.__y_err[i] = np.std(selected)
        # explicit None/length check: a numpy array has no unambiguous truth value
        if self.__initial_params is not None and len(self.__initial_params) > 0:
            p = self.__initial_params
        else:
            p = [np.max(self.__y_avg), 0, 0]
        self.__fit_params, _ = curve_fit(self.boltzmann, self.__x_sorted, self.__y_avg, p0=p)

    @property
    def slope(self) -> float:
        r"""
        The slope of the linear part of the Boltzmann, i.e.

        .. math::
            s = y_{max} \cdot k / 4

        :return: the slope.
        """
        return self.__fit_params[0] * self.__fit_params[1] / 4

    @property
    def parameters(self):
        """ fit parameters [y_max, slope, inflection]
        :return: The fit parameters.
        """
        return self.__fit_params

    @property
    def x_data(self):
        """ The x data sorted and unique used for fitting.
        :return: the x data
        """
        return self.__x_sorted

    @property
    def y_data(self):
        """
        The y data used for fitting, i.e. the mean y value per unique x value, ordered like ``x_data``.
        :return: the average and the standard deviation of the y data
        """
        return self.__y_avg, self.__y_err

    def solve(self, xvalues=None):
        """
        Evaluate the fitted Boltzmann.

        :param xvalues: x values (scalar or array) at which to evaluate the fit;
            default None uses the sorted unique x data (see ``x_data``).
        :return: the fitted y values.
        """
        # 'is None' so that arrays and a scalar 0 are evaluated as given
        if xvalues is None:
            xvalues = self.__x_sorted
        return self.boltzmann(np.asarray(xvalues), *self.__fit_params)
class StimSpikesFile:
    """Reader for a stimspikes file that maps (index, trial) pairs to spike times.

    The parser recognizes comment lines containing ``index:`` or ``trial:``
    followed by an integer, and numeric data lines without ``#``. Each numeric
    value is divided by 1000 (presumably converting ms to s; TODO confirm
    against the file format).
    """

    # Names of the spike file inside a dataset directory. The first existing one
    # is used; both spellings are accepted.
    _FILE_NAMES = ("stimspikes1.dat", "stimspikes-1.dat")

    def __init__(self, filename):
        """
        :param filename: path of the stimspikes file, or of the dataset directory containing it.
        :raises ValueError: if the file cannot be found.
        """
        if not any(name in filename for name in self._FILE_NAMES):
            candidates = [os.path.join(filename, name) for name in self._FILE_NAMES]
            filename = next((c for c in candidates if os.path.exists(c)), candidates[0])
        if not os.path.exists(filename):
            raise ValueError("StimSpikesFile: the given file %s does not exist!" % filename)
        self._filename = filename
        self._data_map = self.__parse_file(filename)

    def __parse_file(self, filename):
        """Parse the file into a dict {(index, trial): [spike times]}.

        Trials without any data lines are not stored.
        """
        index_map = {}
        trial_data = []
        index = 0
        trial = 0
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if "index:" in line:
                    if trial_data:
                        index_map[(index, trial)] = trial_data
                        trial_data = []
                    # skip the leading comment character, take the number after the last ':'
                    index = int(line[1:].strip().split(":")[-1])
                if "trial:" in line:
                    if trial_data:
                        index_map[(index, trial)] = trial_data
                        trial_data = []
                    trial = int(line[1:].strip().split(":")[-1])
                if line and "#" not in line:
                    trial_data.append(float(line) / 1000)
        # store the final trial, which is not followed by another header line
        if trial_data:
            index_map[(index, trial)] = trial_data
        return index_map

    def get(self, run_index, trial_index):
        """Return the spike times of the given run and trial, or None (with a message) if absent."""
        data = self._data_map.get((run_index, trial_index))
        if data is None:
            print("Data not found for run %i and trial %i:" % (run_index, trial_index))
        return data