From 914a861b26ad44db7bb14f2048ca9cb44f1af5a0 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Wed, 22 Jul 2020 22:43:11 +0200 Subject: [PATCH] restructuring --- fishbook/__init__.py | 4 +- fishbook/frontend/relacs_classes.py | 145 ++-------------------------- fishbook/frontend/util.py | 141 +++++++++++++++++++++++++++ 3 files changed, 148 insertions(+), 142 deletions(-) diff --git a/fishbook/__init__.py b/fishbook/__init__.py index 8f487ec..863e6ce 100644 --- a/fishbook/__init__.py +++ b/fishbook/__init__.py @@ -1,4 +1,2 @@ from fishbook.frontend.frontend_classes import Cell, Subject, Stimulus, Dataset, RePro -#import fishbook.reproclasses as repros -#import fishbook.database as database -__all__ = ['', ''] +from fishbook.frontend.relacs_classes import BaselineData, FIData, FileStimulusData diff --git a/fishbook/frontend/relacs_classes.py b/fishbook/frontend/relacs_classes.py index 5786bb4..559942c 100644 --- a/fishbook/frontend/relacs_classes.py +++ b/fishbook/frontend/relacs_classes.py @@ -1,8 +1,9 @@ -from fishbook.fishbook import Dataset, RePro, Stimulus +from .frontend_classes import Dataset, RePro, Stimulus +from .util import BoltzmannFit, unzip_if_needed, gaussian_kernel, zero_crossings import numpy as np import nixio as nix from scipy.stats import circstd -from scipy.optimize import curve_fit +# from scipy.optimize import curve_fit import os import subprocess from tqdm import tqdm @@ -10,52 +11,6 @@ from tqdm import tqdm from IPython import embed -def _zero_crossings(x, t, interpolate=False): - """get the times at which a signal x - - Args: - x ([type]): [description] - t ([type]): [description] - interpolate (bool, optional): [description]. Defaults to False. - - Returns: - [type]: [description] - """ - dt = t[1] - t[0] - x_shift = np.roll(x, 1) - x_shift[0] = 0.0 - xings = np.where((x >= 0 ) & (x_shift < 0))[0] - crossings = np.zeros(len(xings)) - if interpolate: - for i, tf in enumerate(xings): - if x[tf] > 0.001: - m = (x[tf] - x[tf-1])/dt - crossings[i] = t[tf] - x[tf]/m - elif x[tf] < -0.001: - m = (x[tf + 1] - x[tf]) / dt - crossings[i] = t[tf] - x[tf]/m - else: - crossings[i] = t[tf] - else: - crossings = t[xings] - return crossings - - -def _unzip_if_needed(dataset, tracename='trace-1.raw'): - file_name = os.path.join(dataset, tracename) - if os.path.exists(file_name): - return - if os.path.exists(file_name + '.gz'): - print("\tunzip: %s" % tracename) - subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")]) - - -def gaussian_kernel(sigma, dt): - x = np.arange(-4. * sigma, 4. * sigma, dt) - y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma - return y - - class BaselineData: """ Class representing the Baseline data that has been recorded within a given Dataset. @@ -134,7 +89,7 @@ class BaselineData: eod_frequencies = [] for i in range(self.size): eod, time = self.eod(i) - xings = _zero_crossings(eod, time, interpolate=False) + xings = zero_crossings(eod, time, interpolate=False) eod_frequencies.append(len(xings)/xings[-1]) return 0.0 if len(eod_frequencies) < 1 else np.mean(eod_frequencies) @@ -154,7 +109,7 @@ class BaselineData: return None if len(self.__eod_times) < len(self.__eod_data): eod, time = self.eod(index) - times = _zero_crossings(eod, time, interpolate=interpolate) + times = zero_crossings(eod, time, interpolate=interpolate) else: times = self.__eod_times[index] return times @@ -272,7 +227,7 @@ class BaselineData: def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray: sr = self.__dataset.samplerate - _unzip_if_needed(self.__dataset.data_source, "trace-2.raw") + unzip_if_needed(self.__dataset.data_source, "trace-2.raw") eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32) eod = eod[:int(duration * sr)] return eod @@ -329,94 +284,6 @@ class BaselineData: return np.asarray(data) -class BoltzmannFit: - """ - Class representing a fit of a Boltzmann function to some data. - """ - - def __init__(self, xvalues: np.ndarray, yvalues: np.ndarray, initial_params=None): - """ - Constructor. Takes the x and the y data and tries to fit a Boltzmann to it. - - :param xvalues: numpy array of x (e.g. contrast) values - :param yvalues: numpy array of y (e.g. firing rate) values - :param initial_params: list of initial parameters, default None to autogenerate - """ - assert(len(xvalues) == len(yvalues)) - self.__xvals = xvalues - self.__yvals = yvalues - self.__fit_params = None - self.__initial_params = initial_params - self.__x_sorted = np.unique(self.__xvals) - self.__y_avg = None - self.__y_err = None - self.__do_fit() - - @staticmethod - def boltzmann(x, y_max, slope, inflection): - """ - The underlying Boltzmann function. - .. math:: - f(x) = y_max / \exp{-slope*(x-inflection} - - :param x: The x values. - :param y_max: The maximum value. - :param slope: The slope parameter k - :param inflection: the position of the inflection point. - :return: the y values. - """ - y = y_max / (1 + np.exp(-slope * (x - inflection))) - return y - - def __do_fit(self): - self.__y_avg = np.zeros(self.__x_sorted.shape) - self.__y_err = np.zeros(self.__x_sorted.shape) - for i, c in enumerate(self.__x_sorted): - self.__y_avg[i] = np.mean(self.__yvals[self.__xvals == c]) - self.__y_err[i] = np.std(self.__yvals[self.__xvals == c]) - if self.__initial_params: - p = self.__initial_params - else: - p = [np.max(self.__y_avg), 0, 0] - self.__fit_params, _ = curve_fit(self.boltzmann, self.__x_sorted, self.__y_avg, p) - - @property - def slope(self) -> float: - r""" - The slope of the linear part of the Boltzmann, i.e. - .. math:: - s = f_max $\cdot$ k / 4 - :return: the slope. - """ - return self.__fit_params[0] * self.__fit_params[1] / 4 - - @property - def parameters(self): - """ fit parameters - :return: The fit parameters. - """ - return self.__fit_params - - @property - def x_data(self): - """ The x data sorted and unique used for fitting. - :return: the x data - """ - return self.__x_sorted - - @property - def y_data(self): - """ - the Y data used for fitting, i.e. the average rate in the specified time window sorted by the x data. - :return: the average and the standard deviation of the y data - """ - return self.__y_avg, self.__y_err - - def solve(self, xvalues=None): - if not xvalues: - xvalues = self.__x_sorted - return self.boltzmann(xvalues, *self.__fit_params) - class FIData: """ diff --git a/fishbook/frontend/util.py b/fishbook/frontend/util.py index 0b0f477..1a00d8a 100644 --- a/fishbook/frontend/util.py +++ b/fishbook/frontend/util.py @@ -1,3 +1,7 @@ +import numpy as np +from scipy.optimize import curve_fit + + def safe_get_val(dictionary:dict, key, default=None): return dictionary[key] if key in dictionary.keys() else default @@ -7,3 +11,140 @@ def results_check(results, id, text="ID"): raise ValueError("%s %s does not exist!" % (text, id)) elif len(results) > 1: raise ValueError("%s %s is not unique!" % (text, id)) + + +def zero_crossings(x, t, interpolate=False): + """get the times at which a signal x + + Args: + x ([type]): [description] + t ([type]): [description] + interpolate (bool, optional): [description]. Defaults to False. + + Returns: + [type]: [description] + """ + dt = t[1] - t[0] + x_shift = np.roll(x, 1) + x_shift[0] = 0.0 + xings = np.where((x >= 0 ) & (x_shift < 0))[0] + crossings = np.zeros(len(xings)) + if interpolate: + for i, tf in enumerate(xings): + if x[tf] > 0.001: + m = (x[tf] - x[tf-1])/dt + crossings[i] = t[tf] - x[tf]/m + elif x[tf] < -0.001: + m = (x[tf + 1] - x[tf]) / dt + crossings[i] = t[tf] - x[tf]/m + else: + crossings[i] = t[tf] + else: + crossings = t[xings] + return crossings + + +def unzip_if_needed(dataset, tracename='trace-1.raw'): + file_name = os.path.join(dataset, tracename) + if os.path.exists(file_name): + return + if os.path.exists(file_name + '.gz'): + print("\tunzip: %s" % tracename) + subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")]) + + +def gaussian_kernel(sigma, dt): + x = np.arange(-4. * sigma, 4. * sigma, dt) + y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma + return y + + + +class BoltzmannFit: + """ + Class representing a fit of a Boltzmann function to some data. + """ + + def __init__(self, xvalues: np.ndarray, yvalues: np.ndarray, initial_params=None): + """ + Constructor. Takes the x and the y data and tries to fit a Boltzmann to it. + + :param xvalues: numpy array of x (e.g. contrast) values + :param yvalues: numpy array of y (e.g. firing rate) values + :param initial_params: list of initial parameters, default None to autogenerate + """ + assert(len(xvalues) == len(yvalues)) + self.__xvals = xvalues + self.__yvals = yvalues + self.__fit_params = None + self.__initial_params = initial_params + self.__x_sorted = np.unique(self.__xvals) + self.__y_avg = None + self.__y_err = None + self.__do_fit() + + @staticmethod + def boltzmann(x, y_max, slope, inflection): + """ + The underlying Boltzmann function. + .. math:: + f(x) = y_max / \exp{-slope*(x-inflection} + + :param x: The x values. + :param y_max: The maximum value. + :param slope: The slope parameter k + :param inflection: the position of the inflection point. + :return: the y values. + """ + y = y_max / (1 + np.exp(-slope * (x - inflection))) + return y + + def __do_fit(self): + self.__y_avg = np.zeros(self.__x_sorted.shape) + self.__y_err = np.zeros(self.__x_sorted.shape) + for i, c in enumerate(self.__x_sorted): + self.__y_avg[i] = np.mean(self.__yvals[self.__xvals == c]) + self.__y_err[i] = np.std(self.__yvals[self.__xvals == c]) + if self.__initial_params: + p = self.__initial_params + else: + p = [np.max(self.__y_avg), 0, 0] + self.__fit_params, _ = curve_fit(self.boltzmann, self.__x_sorted, self.__y_avg, p) + + @property + def slope(self) -> float: + r""" + The slope of the linear part of the Boltzmann, i.e. + .. math:: + s = f_max $\cdot$ k / 4 + :return: the slope. + """ + return self.__fit_params[0] * self.__fit_params[1] / 4 + + @property + def parameters(self): + """ fit parameters + :return: The fit parameters. + """ + return self.__fit_params + + @property + def x_data(self): + """ The x data sorted and unique used for fitting. + :return: the x data + """ + return self.__x_sorted + + @property + def y_data(self): + """ + the Y data used for fitting, i.e. the average rate in the specified time window sorted by the x data. + :return: the average and the standard deviation of the y data + """ + return self.__y_avg, self.__y_err + + def solve(self, xvalues=None): + if not xvalues: + xvalues = self.__x_sorted + return self.boltzmann(xvalues, *self.__fit_params) +