from typing import Tuple
import numpy as np
from IPython import embed


class CircBuffer:
    """
    Fixed-size, multi-channel ring buffer backed by a NumPy array.

    Every channel keeps its own write index, fill flag and sample counter, so
    channels can be written independently. Once a channel has wrapped around,
    each new sample overwrites the oldest one of that channel.

    Parameters
    ----------
    size : int
        Number of samples each channel can hold.
    channels : int
        Number of independent channels.
    samplerate : float
        Stored unchanged and exposed through the ``samplerate`` property; the
        buffer logic itself does not use it.
    """

    def __init__(self, size: int, channels: int = 1, samplerate: float = 40_000):
        self.__size = size
        self.__channels = channels
        self.__samplerate = samplerate
        # One row per channel, one column per slot.
        self.__buffer = np.zeros((channels, size), dtype=np.double)
        # Per-channel bookkeeping.
        self.__index = [0] * channels  # next slot to be written
        self.__is_full = [False] * channels  # True once the channel wrapped
        self.__totalcount = [0] * channels  # samples appended so far
        self.__overflows = [0] * channels  # number of wrap-arounds

    @property
    def size(self):
        """Number of slots per channel."""
        return self.__size

    @property
    def samplerate(self):
        """Sample rate that was passed to the constructor."""
        return self.__samplerate

    @property
    def channel_count(self):
        """Number of channels of the buffer."""
        return self.__channels

    def totocount(self, channel: int = 0):
        """
        Return how many samples have been appended to ``channel`` in total.

        The (misspelled) method name is kept for backward compatibility.
        """
        return self.__totalcount[channel]

    def is_full(self, channel: int = 0):
        """Return True once ``channel`` has wrapped around at least once."""
        return self.__is_full[channel]

    def write_index(self, channel: int = 0):
        """Return the slot of ``channel`` that the next append will write."""
        return self.__index[channel]

    def append(self, item, channel: int = 0):
        """
        Store ``item`` at the write position of ``channel`` and advance it.

        When the write position wraps back to slot 0 the channel is marked as
        full and its wrap-around counter is incremented.
        """
        position = self.__index[channel]
        self.__buffer[channel, position] = item
        self.__index[channel] = (position + 1) % self.__size
        self.__totalcount[channel] += 1
        if self.__index[channel] == 0:
            self.__is_full[channel] = True
            self.__overflows[channel] += 1

    def get_all(self, channel: int = 0):
        """
        Return all valid values from the specified channel.

        For a full channel the values are returned oldest first (the slots
        from the write index to the end, followed by the slots before the
        write index). For a channel that has not wrapped yet, the written
        slots are returned as a slice of the internal array.
        """
        split = self.__index[channel]
        if not self.__is_full[channel]:
            return self.__buffer[channel, :split]
        return np.concatenate(
            (self.__buffer[channel, split:], self.__buffer[channel, :split])
        )

    def has_value(self, index, channel: int = 0):
        """
        Tell whether slot ``index`` of ``channel`` holds a written value.

        Parameters
        ----------
        index : int
            Slot index into the channel.
        channel : int
            Channel of the buffer.

        Returns
        -------
        bool
            True if the slot holds a value. False for ``index >= size`` and
            for a negative index on a channel that is not full. On a full
            channel every ``index <= 0`` is reported as valid.

        Raises
        ------
        IndexError
            If the channel is not full and ``index`` lies at or beyond the
            write index, i.e. the slot has not been written yet.
        """
        full = self.is_full(channel)
        if index <= 0 and full:
            return True
        if index < 0 and not full:
            return False

        if index >= self.size:
            return False

        # In a full channel every slot inside the buffer has been written.
        if full:
            return True

        # Channel not full: only the slots below the write index are written.
        if index >= self.write_index(channel):
            raise IndexError("Index has no value, not written")
        return True

    def valid_range(self, channel: int = 0) -> Tuple[int, int]:
        """
        Return the start index and the extent that are valid within the buffer

        Parameters
        ----------
        channel : int
            channel of the buffer

        Returns
        -------
        Tuple[int, int]
            start, extent of the valid range. The start is always 0; the
            extent is 0 for an empty channel, the number of appended samples
            for a channel that is not full, and ``size`` for a full channel.
        """
        written = self.__totalcount[channel]
        if written == 0:
            return 0, 0
        if self.is_full(channel):
            return 0, self.size
        return 0, written

    def get(self, index: int = -1, channel: int = 0):
        """
        Return a single value of ``channel``.

        Parameters
        ----------
        index : int
            Slot to read. Any negative value selects the most recently
            written sample of ``channel``.
        channel : int
            Channel of the buffer.

        Returns
        -------
        The value stored in the requested slot.

        Raises
        ------
        IndexError
            If the requested slot holds no valid value.
        """
        if index < 0:
            # The newest sample sits one slot before the write position of
            # *this* channel. If the write position is 0 this yields -1: on a
            # full channel NumPy's negative indexing then selects the last
            # slot (the newest sample); on an empty channel has_value()
            # rejects it.
            index = self.write_index(channel) - 1

        if self.has_value(index, channel):
            return self.__buffer[channel, index]
        raise IndexError(
            f"Invalid index {index} on ring buffer for channel{channel}"
        )

    def read(self, start, count=1, channel=0):
        """
        Read ``count`` consecutive values of ``channel`` beginning at ``start``.

        Parameters
        ----------
        start : int
            Slot index of the first value.
        count : int
            Number of values requested. It is clamped to the number of valid
            values; on a channel that has not wrapped yet it is additionally
            clamped so that only written slots are returned.
        channel : int
            Channel of the buffer.

        Returns
        -------
        numpy.ndarray
            For ``count == 1`` a zero-dimensional array holding the single
            value, otherwise a one-dimensional array. On a full channel the
            read wraps around the end of the buffer when necessary.

        Raises
        ------
        IndexError
            If ``start`` or ``count`` is negative, if ``start`` exceeds the
            number of appended samples or the buffer size, or (for
            ``count == 1``) if the slot holds no valid value.
        """
        if start < 0 or count < 0:
            raise IndexError(
                f"Invalid start ({start}) or count ({count}) for channel{channel}"
            )

        if count == 1:
            return np.array(self.get(start, channel))

        _, valid = self.valid_range(channel)
        total = self.__totalcount[channel]
        if start > total:
            raise IndexError(
                f"Invalid start index {start} is invalid with totalcount {total} for channel{channel}"
            )
        # NOTE(review): start == size is accepted here although get() rejects
        # it; on a full channel it reads from slot 0. Kept as-is because
        # callers may rely on it -- confirm whether it should raise instead.
        if start > self.size:
            raise IndexError(
                f"Invalid start index {start} for buffer with size {self.size}"
            )
        count = min(count, self.size, valid)

        if not self.is_full(channel):
            # Nothing has wrapped yet: valid data is the contiguous range
            # [0, valid), so never hand out unwritten slots and never wrap.
            count = min(count, valid - start)
            return self.__buffer[channel, start : start + count]

        stop = start + count
        if stop <= self.size:
            return self.__buffer[channel, start:stop]
        # The requested range runs past the end: continue from slot 0.
        return np.concatenate(
            (
                self.__buffer[channel, start:],
                self.__buffer[channel, : stop - self.size],
            )
        )