diff --git a/fishbook/__init__.py b/fishbook/__init__.py index ee7afae..237b5a0 100644 --- a/fishbook/__init__.py +++ b/fishbook/__init__.py @@ -1,4 +1,4 @@ from .fishbook import * -from .reproclasses import BaselineData +from .reproclasses import BaselineData, FIData import fishbook.database as database __all__ = ['fishbook', 'database'] \ No newline at end of file diff --git a/fishbook/reproclasses.py b/fishbook/reproclasses.py index 83f0cb4..5b0f934 100644 --- a/fishbook/reproclasses.py +++ b/fishbook/reproclasses.py @@ -2,6 +2,7 @@ from fishbook.fishbook import Dataset, RePro, Stimulus import numpy as np import nixio as nix from scipy.stats import circstd +from scipy.optimize import curve_fit import os import subprocess from tqdm import tqdm @@ -283,6 +284,52 @@ class BaselineData: return np.asarray(data) +class BoltzmannFit(): + def __init__(self, xvalues: np.ndarray, yvalues: np.ndarray): + assert(len(xvalues) == len(yvalues)) + self.__xvals = xvalues + self.__yvals = yvalues + self.__fit_params = None + self.__x_sorted = np.unique(self.__xvals) + self.__y_avg = None + self.__y_err = None + self.__do_fit() + + @staticmethod + def boltzmann(x, f_max, slope, inflection): + y = f_max / (1 + np.exp(-slope * (x - inflection))) + return y + + def __do_fit(self): + self.__y_avg = np.zeros(self.__x_sorted.shape) + self.__y_err = np.zeros(self.__x_sorted.shape) + for i, c in enumerate(self.__x_sorted): + self.__y_avg[i] = np.mean(self.__yvals[self.__xvals == c]) + self.__y_err[i] = np.std(self.__yvals[self.__xvals == c]) + self.__fit_params, _ = curve_fit(self.boltzmann, self.__x_sorted, self.__y_avg, [np.max(self.__y_avg), 0, 0]) + + @property + def slope(self): + return self.__fit_params[0] * self.__fit_params[1] / 4 + + @property + def parameters(self): + return self.__fit_params + + @property + def x_data(self): + return self.__x_sorted + + @property + def y_data(self): + return self.__y_avg, self.__y_err + + def solve(self, xvalues=None): + if not xvalues: + xvalues = self.__x_sorted + return self.boltzmann(xvalues, *self.__fit_params) + + class FIData: def __init__(self, dataset: Dataset): self.__spike_data = []