from typing import Tuple
import numpy as np
from IPython import embed


class CircBuffer:
    """Fixed-size, multi-channel ring buffer that timestamps every sample.

    Samples and timestamps live in two ``(channels, size)`` float64 arrays.
    Each channel has its own write index, which advances by one per
    :meth:`append` and wraps back to 0 after ``size`` writes; from the first
    wrap on, the channel is *full* and every slot holds valid data.

    Timestamps are accumulated in steps of ``1 / samplerate`` starting at
    0.0, so the ``k``-th sample written to a channel is stamped at roughly
    ``k / samplerate`` seconds.

    Parameters
    ----------
    size : int
        Number of slots per channel (must be >= 1).
    channels : int
        Number of independent channels (must be >= 1).
    samplerate : int
        Samples per second used to derive timestamps (must be > 0).

    Raises
    ------
    ValueError
        If ``size``, ``channels`` or ``samplerate`` is out of range.
    """

    def __init__(self, size: int, channels: int = 1, samplerate: int = 40_000):
        # Validate early: size == 0 would make ``append`` take a modulo by
        # zero and samplerate == 0 would make the timestamp step 1 / 0.
        if size < 1:
            raise ValueError(f"size must be >= 1, got {size}")
        if channels < 1:
            raise ValueError(f"channels must be >= 1, got {channels}")
        if samplerate <= 0:
            raise ValueError(f"samplerate must be > 0, got {samplerate}")
        self.__size = size
        self.__channels = channels
        self.__samplerate = samplerate
        self.__buffer = np.zeros((channels, size), dtype=np.double)
        self.__time = np.zeros((channels, size), dtype=np.double)
        # Per-channel bookkeeping.
        self.__index = [0] * channels  # next slot to write
        self.__is_full = [False] * channels  # True once the index has wrapped
        self.__totalcount = [0] * channels  # samples ever appended
        self.__overflows = [0] * channels  # number of wraps of the index
        # Not referenced by any method in this class.
        self.__read_increment = samplerate * 0.1

    @property
    def size(self):
        """Number of slots per channel."""
        return self.__size

    @property
    def samplerate(self):
        """Samples per second used to derive timestamps."""
        return self.__samplerate

    @property
    def channel_count(self):
        """Number of channels."""
        return self.__channels

    def totalcount(self, channel: int = 0):
        """Return how many samples have ever been appended to ``channel``."""
        return self.__totalcount[channel]

    def is_full(self, channel: int = 0):
        """Return whether ``channel`` has wrapped, i.e. every slot is valid."""
        return self.__is_full[channel]

    def write_index(self, channel: int = 0):
        """Return the slot the next :meth:`append` on ``channel`` writes to."""
        return self.__index[channel]

    def append(self, item, channel: int = 0):
        """Write ``item`` to the next slot of ``channel`` and timestamp it.

        When the write index wraps around, the channel is marked full and its
        overflow counter is incremented; later appends overwrite the oldest
        sample.
        """
        slot = self.__index[channel]
        if self.__totalcount[channel] == 0:
            stamp = 0.0
        else:
            # Right after a wrap slot - 1 is -1, which numpy maps to the last
            # slot, i.e. the previously written sample.
            stamp = self.__time[channel, slot - 1] + 1 / self.__samplerate
        self.__buffer[channel, slot] = item
        # Stamp only the slot just written, so the timestamp of the oldest
        # still-valid sample (in the next slot) is never clobbered.
        self.__time[channel, slot] = stamp
        self.__index[channel] = (slot + 1) % self.__size
        self.__totalcount[channel] += 1
        if self.__index[channel] == 0:
            self.__is_full[channel] = True
            self.__overflows[channel] += 1

    def get_all(self, channel: int = 0):
        """Return a copy of all valid values of ``channel``, oldest first."""
        idx = self.__index[channel]
        row = self.__buffer[channel]
        if self.__is_full[channel]:
            # Once full, the write index points at the oldest sample.
            return np.concatenate((row[idx:], row[:idx]))
        return row[:idx].copy()

    def has_value(self, index: int, channel: int = 0) -> bool:
        """Return whether ``index`` addresses a written sample of ``channel``.

        A non-negative ``index`` is a physical slot. A negative ``index``
        counts back from the newest sample (``-1`` is the newest). Positions
        that are out of range or not yet written yield ``False``.
        """
        # While not full, write_index == totalcount == number of valid slots.
        count = self.size if self.is_full(channel) else self.write_index(channel)
        if index < 0:
            return index >= -count
        return index < count

    def valid_range(self, channel: int = 0) -> Tuple[int, int]:
        """
        Return the start slot and the extent of the valid data in the buffer

        Parameters
        ----------
        channel : int
            channel of the buffer

        Returns
        -------
        Tuple[int, int]
            start (always 0) and extent (number of valid slots)
        """
        # A channel becomes full exactly when its totalcount reaches size, so
        # the number of valid slots is the smaller of the two.
        return 0, min(self.__totalcount[channel], self.size)

    def get(self, index: int = -1, channel: int = 0) -> Tuple[np.double, float]:
        """Return ``(value, timestamp)`` of one sample of ``channel``.

        Parameters
        ----------
        index : int
            Physical slot in ``[0, size)``, or a negative offset from the
            newest sample (``-1``, the default, is the newest; ``-2`` is the
            one before it, and so on).
        channel : int
            Channel to read.

        Raises
        ------
        IndexError
            If ``index`` does not address a written sample.
        """
        if not self.has_value(index, channel):
            raise IndexError(
                f"Invalid index {index} on ring buffer for channel{channel}"
            )
        if index < 0:
            # Offset relative to the write position; modulo handles the wrap
            # back past slot 0.
            index = (self.write_index(channel) + index) % self.size
        return (self.__buffer[channel, index], self.__time[channel, index])

    def read(self, start, extend=1, channel=0):
        """Read ``extend`` consecutive slots of ``channel`` from slot ``start``.

        Slots are addressed physically. On a full channel a read may wrap
        past the end of the storage back to slot 0. A negative ``start`` is
        only accepted on a full channel and is taken relative to the end of
        the storage (``-1`` is slot ``size - 1``). ``extend`` is clamped to
        the number of valid samples.

        Returns
        -------
        numpy.ndarray
            For ``extend == 1``: ``[value, timestamp]`` as given by :meth:`get`.
        Tuple[numpy.ndarray, numpy.ndarray]
            Otherwise ``(timestamps, values)``. Both are copies, so later
            appends do not alter them.

        Raises
        ------
        IndexError
            If ``extend`` is negative, ``start`` is out of range, or the span
            runs past the written samples of a channel that is not full.
        """
        if extend < 0:
            raise IndexError(f"Invalid  extend ({extend}) for channel {channel}")
        full = self.is_full(channel)
        if start < 0:
            if not full or start < -self.size:
                raise IndexError(f"Invalid  start ({start}) for channel {channel}")
            start += self.size

        if extend == 1:
            return np.array(self.get(start, channel))

        total = self.__totalcount[channel]
        if start > total:
            raise IndexError(
                f"Invalid start index {start} is invalid with totalcount {total} for channel{channel}"
            )
        if start >= self.size:
            raise IndexError(
                f"Invalid start index {start} for buffer with size {self.size}"
            )

        _, valid = self.valid_range(channel)
        extend = min(extend, valid)
        # Only a channel that is not full has unwritten slots a read could
        # reach; on a full channel every slot is valid and the read may wrap.
        if not full and start + extend > total:
            raise IndexError(
                f" Invalid range, extended over the totalcount of the buffer {total}"
            )

        stop = start + extend
        if stop <= self.size:
            return (
                self.__time[channel, start:stop].copy(),
                self.__buffer[channel, start:stop].copy(),
            )
        wrapped = stop - self.size
        return (
            np.concatenate(
                (self.__time[channel, start:], self.__time[channel, :wrapped])
            ),
            np.concatenate(
                (self.__buffer[channel, start:], self.__buffer[channel, :wrapped])
            ),
        )