from FiCurve import FICurve
from CellData import CellData
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import os
import numpy as np
import functions as fu


class Adaption:
    """Fit exponentials to a cell's mean ISI frequency responses and derive ``tau_real``.

    For every contrast of ``cell_data`` an exponential (``fu.exponential_function``) is
    fitted to a window of the mean ISI frequency that starts at the response onset.
    The median of the fitted effective time constants (``tau_eff``) is then scaled by
    the ratio of two slopes provided by the f-I curve to give ``tau_real``.
    Both the fitting and the tau calculation run inside ``__init__``.
    """

    def __init__(self, cell_data: CellData, fi_curve: FICurve = None):
        """
        :param cell_data: recording whose mean ISI frequencies are fitted.
        :param fi_curve: f-I curve providing f_zeros, f_infinities, f_baselines and the
            slopes used for the tau conversion; built from ``cell_data`` when None.
        """
        self.cell_data = cell_data
        self.fi_curve = FICurve(cell_data) if fi_curve is None else fi_curve

        # One entry per contrast: the fitted [a, tau_eff, c] or [] if no usable fit.
        self.exponential_fit_vars = []
        self.tau_real = []

        self.fit_exponential()
        self.calculate_tau_from_tau_eff()

    def fit_exponential(self, length_of_fit=0.1):
        """Fit an exponential to the mean ISI frequency of every contrast.

        Appends exactly one entry per contrast to ``self.exponential_fit_vars``: the
        fitted parameters ``[a, tau_eff, c]`` (initial guess ``(f_zero, tau estimate,
        f_infinity)``), or ``[]`` when no start index was found, too few samples remain
        in the window, the fit failed, or the fitted time constant is implausible.

        :param length_of_fit: length of the fitted window in seconds; shortened when it
            would extend beyond the stimulus end.
        """
        mean_frequencies = self.cell_data.get_mean_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_mean_frequencies()
        sampling_interval = self.cell_data.get_sampling_interval()

        for i, frequency in enumerate(mean_frequencies):
            start_idx = self.__find_start_idx_for_exponential_fit(i)
            if start_idx == -1:
                print("start index negative")
                self.exponential_fit_vars.append([])
                continue

            # Shorten the fit window so that it stays inside the stimulus region.
            # NOTE(review): the check subtracts get_delay() from the start time but the
            # shortened length does not, so a shortened window ends `delay` before the
            # checked stimulus end - confirm which time reference is intended.
            stimulus_end = self.cell_data.get_stimulus_end()
            used_length_of_fit = length_of_fit
            if (start_idx * sampling_interval) - self.cell_data.get_delay() + length_of_fit > stimulus_end:
                print(start_idx * sampling_interval, "start - end", start_idx * sampling_interval + length_of_fit)
                print("Shortened length of fit to keep it in the stimulus region!")
                used_length_of_fit = stimulus_end - (start_idx * sampling_interval)

            end_idx = start_idx + int(used_length_of_fit / sampling_interval)
            y_values = frequency[start_idx:end_idx + 1]
            x_values = time_axes[i][start_idx:end_idx + 1]

            if len(x_values) < 2:
                # The tau estimate needs at least two samples. An empty window occurs
                # e.g. when the start already lies past the stimulus end (negative length).
                print("too few data points for the exponential fit")
                self.exponential_fit_vars.append([])
                continue

            tau = self.__approximate_tau_for_exponential_fit(x_values, y_values, i)

            p0 = (self.fi_curve.f_zeros[i], tau, self.fi_curve.f_infinities[i])
            try:
                popt, _ = curve_fit(fu.exponential_function, x_values, y_values, p0=p0, maxfev=10000,
                                    bounds=([-np.inf, 0, -np.inf], [np.inf, np.inf, np.inf]))
            except RuntimeError:
                # The optimizer did not converge within maxfev evaluations.
                print("RuntimeError happened in fit_exponential.")
                self.exponential_fit_vars.append([])
                continue
            except ValueError:
                # curve_fit raises ValueError e.g. when the data contain NaNs.
                print("ValueError happened in fit_exponential.")
                self.exponential_fit_vars.append([])
                continue

            # Obviously a bad fit: the time constant (expected around 3-10 ms) is above
            # 1 second or negative.
            if popt[1] > 1 or popt[1] < 0:
                print("detected an obviously bad fit")
                self.exponential_fit_vars.append([])
            else:
                self.exponential_fit_vars.append(popt)

    def __approximate_tau_for_exponential_fit(self, x_values, y_values, mean_freq_idx):
        """Return a rough time-constant estimate used as the initial guess of the fit.

        The estimate is the time from the first sample to the first sample that rises
        above 65 % of f_infinity (when f_infinity is below 95 % of the baseline) or
        falls below 65 % of f_zero (otherwise). When no sample crosses the threshold,
        the full window length is returned. Requires at least two samples.
        """
        f_infinity = self.fi_curve.f_infinities[mean_freq_idx]
        if f_infinity < self.fi_curve.f_baselines[mean_freq_idx] * 0.95:
            crossed = [y > 0.65 * f_infinity for y in y_values]
        else:
            f_zero = self.fi_curve.f_zeros[mean_freq_idx]
            crossed = [y < 0.65 * f_zero for y in y_values]

        try:
            idx = crossed.index(True)
        except ValueError:
            # Threshold never crossed: use the whole window.
            return x_values[-1] - x_values[0]

        # Never use the first sample itself, which would give tau == 0.
        idx = max(idx, 1)
        return x_values[idx] - x_values[0]

    def __find_start_idx_for_exponential_fit(self, mean_freq_idx):
        """Return the sample index at which the exponential fit starts, or -1.

        * f_infinity > 1.1 * baseline: the first sample at or after the stimulus-start
          index whose frequency equals f_zero.
        * f_infinity < 0.9 * baseline: starting 0.05 s after the stimulus-start index,
          find a sample equal to f_zero, then return the first later index whose next
          sample rises above f_zero. Gives up once the search passes 0.1 s.
        * otherwise there is nothing to fit.

        -1 is also returned when a search runs past the end of the frequency trace.
        """
        time_axis = self.cell_data.get_time_axes_mean_frequencies()[mean_freq_idx]
        frequency = self.cell_data.get_mean_isi_frequencies()[mean_freq_idx]
        sampling_interval = self.cell_data.get_sampling_interval()

        # NOTE(review): this adds time_axis[0] to the stimulus start; if the time axis
        # starts at a negative time the index can become negative and then wraps around
        # when indexing - confirm the intended sign.
        stimulus_start_idx = int((self.cell_data.get_stimulus_start() + time_axis[0]) / sampling_interval)

        f_zero = self.fi_curve.f_zeros[mean_freq_idx]
        f_infinity = self.fi_curve.f_infinities[mean_freq_idx]
        f_baseline = self.fi_curve.f_baselines[mean_freq_idx]

        # Comparisons with f_zero use exact equality, presumably because f_zero was read
        # from this same trace - confirm against FICurve.
        if f_infinity > f_baseline * 1.1:
            # Rising response: start at the first sample equal to f_zero
            # (presumably the onset peak).
            j = 0
            while True:
                try:
                    if frequency[stimulus_start_idx + j] == f_zero:
                        return stimulus_start_idx + j
                except IndexError:
                    return -1
                j += 1

        if f_infinity < f_baseline * 0.9:
            # Falling response: find the sample equal to f_zero, then the end of that
            # minimum (the frequency starts to rise above f_zero again).
            found_min = False
            j = int(0.05 / sampling_interval)
            last_j = 0.1 / sampling_interval
            while True:
                try:
                    if not found_min:
                        if frequency[stimulus_start_idx + j] == f_zero:
                            found_min = True
                    elif frequency[stimulus_start_idx + j + 1] > f_zero:
                        return stimulus_start_idx + j
                except IndexError:
                    return -1
                if j > last_j:
                    # No rise in frequency within the search window: too little room
                    # is left before the end of the stimulus to fit.
                    return -1
                j += 1

        # f_infinity lies between 0.9 and 1.1 times the baseline: nothing to fit.
        return -1

    def calculate_tau_from_tau_eff(self):
        """Compute ``self.tau_real`` from the fitted effective time constants.

        ``tau_real`` is the median of all valid ``tau_eff`` values multiplied by
        ``get_fi_curve_slope_at_f_zero_intersection() / get_f_infinity_slope()`` of the
        f-I curve. When no fit succeeded the median is NaN (numpy warns).
        """
        tau_effs = self.get_tau_effs()

        f_infinity_slope = self.fi_curve.get_f_infinity_slope()
        fi_curve_slope = self.fi_curve.get_fi_curve_slope_at_f_zero_intersection()
        # An earlier version used fi_curve.get_fi_curve_slope_of_straight() (the slope in
        # the middle of the f-I curve) instead of the slope at the intersection.
        factor = fi_curve_slope / f_infinity_slope
        self.tau_real = np.median(tau_effs) * factor
        print("###### tau: {:.1f}ms".format(self.tau_real*1000), "other f_0 slope:", fi_curve_slope)

    def get_tau_real(self):
        """Return the median of ``tau_real`` (a scalar once the tau has been calculated)."""
        return np.median(self.tau_real)

    def get_tau_effs(self):
        """Return the fitted ``tau_eff`` of every contrast that has a valid fit."""
        # len() instead of `!= []`: comparing a fitted ndarray to [] is an elementwise
        # comparison with mismatched shapes, which newer numpy versions reject.
        return [ex_vars[1] for ex_vars in self.exponential_fit_vars if len(ex_vars) > 0]

    @staticmethod
    def _fit_plot_file(save_path, contrast):
        """Return the png path under which the fit plot of ``contrast`` is saved."""
        return save_path + "mean_freq_exp_fit_contrast:" + str(round(contrast, 3)) + ".png"

    def plot_exponential_fits(self, save_path: str = None, indices: list = None, delete_previous: bool = False):
        """Plot each contrast's mean ISI frequency together with its fitted exponential.

        :param save_path: prefix of the saved png files; plots are shown interactively
            when None.
        :param indices: contrast indices to plot; all contrasts when None.
        :param delete_previous: first remove previously saved fit plots with the same
            prefix (ignored when ``save_path`` is None, as nothing is saved then).
        """
        contrasts = self.cell_data.get_fi_contrasts()

        if delete_previous and save_path is not None:
            for contrast in contrasts:
                prev_path = self._fit_plot_file(save_path, contrast)
                if os.path.exists(prev_path):
                    os.remove(prev_path)

        for i, contrast in enumerate(contrasts):
            if indices is not None and i not in indices:
                continue

            fit_vars = self.exponential_fit_vars[i]
            if len(fit_vars) == 0:
                print("no fit vars for index!")
                continue

            plt.plot(self.cell_data.get_time_axes_mean_frequencies()[i], self.cell_data.get_mean_isi_frequencies()[i])
            fit_x = np.arange(0, 0.4, self.cell_data.get_sampling_interval())
            plt.plot(fit_x, [fu.exponential_function(x, fit_vars[0], fit_vars[1], fit_vars[2]) for x in fit_x])
            plt.ylim([0, max(self.fi_curve.f_zeros[i], self.fi_curve.f_baselines[i])*1.1])
            plt.xlabel("Time [s]")
            plt.ylabel("Frequency [Hz]")

            if save_path is None:
                plt.show()
            else:
                plt.savefig(self._fit_plot_file(save_path, contrast))

            plt.close()