import numpy as np
from numba import jit


def line(x, m, c):
    """Evaluate the straight line y = m*x + c.

    m is the slope and c the intercept; x may be a scalar or an array.
    """
    scaled = m * x
    return scaled + c


def inverse_line(x, m, c):
    """Solve y = m*t + c for t, treating x as the output value y.

    A zero slope m makes this a division by zero.
    """
    offset_removed = x - c
    return offset_removed / m


def clipped_line(x, m, c):
    """Evaluate m*x + c and clip the result from below at 0 (no upper bound)."""
    unclipped = m * x + c
    return np.clip(unclipped, 0, None)


def inverse_clipped_line(x, a, b):
    """Return the input t for which clipped_line(t, a, b) == x.

    clipped_line computes max(a*t + b, 0), so ``a`` is the slope and ``b``
    the intercept.  Only strictly positive outputs have a unique preimage:
    an output of 0 corresponds to a whole half-line of inputs, and negative
    outputs are never produced.

    Args:
        x: output value(s) to invert (scalar or array).
        a: slope of the line.
        b: intercept of the line.

    Returns:
        (x - b) / a

    Raises:
        ValueError: if any x <= 0, or if the slope a is 0 (constant function).
    """
    # The inverse is defined on the *output* value, so validate x itself
    # rather than evaluating the forward function at x.
    if np.any(np.asarray(x) <= 0) or np.any(np.asarray(a) == 0):
        raise ValueError("Value undefined in inverse_clipped_line.")

    return (x - b) / a


def exponential_function(x, a, b, c):
    """Exponential relaxation from a (at x = 0) towards c with scale b.

    Computes (a - c) * exp(-x / b) + c.
    """
    decay = np.exp(-x / b)
    return (a - c) * decay + c


def upper_boltzmann(x, f_max, k, x_zero):
    """Upper half of a symmetric Boltzmann (logistic) curve, scaled by f_max.

    The inner term 2 / (1 + e^(-k*(x - x_zero))) - 1 lies in (-1, 1) and is 0
    at x = x_zero; negative values are clipped to 0 before scaling.
    """
    growth = np.power(np.e, -k * (x - x_zero))
    centered = 2 / (1 + growth) - 1
    return f_max * np.clip(centered, 0, None)


def full_boltzmann(x, f_max, f_min, k, x_zero):
    """Boltzmann (logistic) curve running between f_min and f_max.

    Equals f_min + (f_max - f_min) / 2 at x = x_zero; k sets the steepness
    (for k > 0 the curve tends to f_min as x -> -inf and f_max as x -> +inf).
    """
    logistic = 1 / (1 + np.power(np.e, -k * (x - x_zero)))
    return (f_max - f_min) * logistic + f_min


def full_boltzmann_straight_slope(f_max, f_min, k, x_zero=0):
    return (f_max-f_min)*k*1/2


def derivative_full_boltzmann(x, f_max, f_min, k, x_zero):
    """First derivative of full_boltzmann with respect to x.

    With u = k * (x - x_zero):
        d/dx = (f_max - f_min) * k * e^(-u) / (1 + e^(-u))^2
    The denominator is the *square of* (1 + e^(-u)); the previous code
    computed 1 + e^(-2u) because ``**`` binds tighter than ``+``.

    Returns a scalar or array matching the broadcast of the inputs.
    """
    # e^(-u) / (1 + e^(-u))^2 is even in u, so evaluate it with -|u|: the
    # exponential then stays in (0, 1] and large |u| can't produce inf/inf = nan.
    decay = np.exp(-np.abs(k * (x - x_zero)))
    res = (f_max - f_min) * k * decay / (1 + decay) ** 2
    return res


def inverse_full_boltzmann(x, f_max, f_min, k, x_zero):
    """Return the input at which full_boltzmann(., f_max, f_min, k, x_zero) equals x.

    Solving x = (f_max - f_min) / (1 + e^(-k(t - x_zero))) + f_min for t gives
        t = x_zero - ln((f_max - f_min) / (x - f_min) - 1) / k

    The curve only takes values strictly between its two asymptotes, so x
    must lie in the open interval between f_min and f_max (either order is
    accepted, so decreasing curves with f_max < f_min also invert).

    Raises:
        ValueError: if any x lies on or outside the asymptotes.
    """
    lower, upper = min(f_min, f_max), max(f_min, f_max)
    x_arr = np.asarray(x)
    # The endpoints are asymptotes: x == f_min divides by zero and
    # x == f_max takes log(0), so they are excluded along with out-of-range x.
    if np.any((x_arr <= lower) | (x_arr >= upper)):
        raise ValueError("Value undefined in inverse_full_boltzmann")

    return -(np.log((f_max - f_min) / (x - f_min) - 1) / k) + x_zero


def gauss(x, a, x0, sigma):
    """Gaussian bump of height a centred at x0 with standard deviation sigma."""
    exponent = -(x - x0) ** 2 / (2 * sigma ** 2)
    return a * np.e ** exponent


def two_gauss(x, a_1, x0_1, sigma_1, a_2, x0_2, sigma_2):
    """Sum of two Gaussian bumps, each given as (height, centre, std-dev)."""
    def _bump(amplitude, centre, width):
        # amplitude * e^(-(x - centre)^2 / (2 * width^2))
        return amplitude * np.e ** (-(x - centre) ** 2 / (2 * width ** 2))

    first = _bump(a_1, x0_1, sigma_1)
    second = _bump(a_2, x0_2, sigma_2)
    return first + second


# NOTE(review): original comment read "useful in less that 1000x10000 calls
# (1000 tests with 10k data points)" -- the intended threshold is unclear.
@jit(nopython=True)
def rectify(x):
    """Return x if it is not negative, else 0 (scalar rectifier).

    NaN is passed through unchanged, since ``x < 0`` is False for NaN.
    """
    return 0 if x < 0 else x