from typing import Optional, Tuple

import numpy as np
from IPython import embed
from pyqtgraph.Qt.QtCore import QMutex


class CircBuffer:
    """Fixed-size, multi-channel ring buffer of samples with timestamps.

    Every channel keeps its own write index, a "full" flag (set once the
    write index has wrapped around at least once) and a total sample counter.
    Timestamps start at 0.0 and advance by ``1 / samplerate`` for every
    sample appended to that channel.

    Only :meth:`append` takes ``self.mutex``; the reading methods do not lock.
    """

    def __init__(
        self,
        size: int,
        channels: int = 1,
        samplerate: float = 40_000.0,
        mutex: Optional[QMutex] = None,
    ):
        """Allocate zero-filled value and timestamp storage.

        Parameters
        ----------
        size : int
            Number of slots per channel.
        channels : int
            Number of independent channels.
        samplerate : float
            Rate used to space consecutive timestamps (``1 / samplerate``).
        mutex : QMutex, optional
            Lock held while appending. If omitted, a new ``QMutex`` is created
            for this instance (a ``QMutex()`` default argument would be built
            once at import time and shared by every buffer).
        """
        self.__size = size
        self.__channels = channels
        self.__samplerate = samplerate
        self.__buffer = np.zeros(
            (channels, size), dtype=np.double
        )  # or dtype of your choice
        # __time[c, i] is the timestamp of the sample stored (or to be stored
        # next) in slot i of channel c; append() fills it one slot ahead.
        self.__time = np.zeros((channels, size), dtype=np.double)
        self.__index = [0] * channels
        self.__is_full = [False] * channels
        self.__totalcount = [0] * channels
        self.__overflows = [0] * channels  # wrap-arounds per channel
        self.mutex = mutex if mutex is not None else QMutex()

    @property
    def size(self):
        """Number of slots per channel."""
        return self.__size

    @property
    def samplerate(self):
        """Sampling rate used to space timestamps."""
        return self.__samplerate

    @property
    def channel_count(self):
        """Number of channels."""
        return self.__channels

    def totalcount(self, channel: int = 0):
        """Total number of samples ever appended to ``channel``."""
        return self.__totalcount[channel]

    def is_full(self, channel: int = 0):
        """Whether ``channel``'s write index has wrapped around at least once."""
        return self.__is_full[channel]

    def write_index(self, channel: int = 0):
        """Slot that the next :meth:`append` on ``channel`` will write to."""
        return self.__index[channel]

    def append(self, item, channel: int = 0):
        """Store ``item`` in ``channel`` and advance that channel's write index.

        The mutex is released in ``finally`` so that an exception (e.g. an
        out-of-range ``channel``) cannot leave it locked.
        """
        self.mutex.lock()
        try:
            idx = self.__index[channel]
            self.__buffer[channel, idx] = item
            nxt = (idx + 1) % self.__size
            self.__index[channel] = nxt
            self.__totalcount[channel] += 1
            # Timestamp for the next slot, one sample period after this one.
            self.__time[channel, nxt] = (
                self.__time[channel, idx] + 1 / self.__samplerate
            )
            if nxt == 0:
                self.__is_full[channel] = True
                self.__overflows[channel] += 1
        finally:
            self.mutex.unlock()

    def get_all(self, channel: int = 0):
        """
        Return all valid values from the specified channel

        Once the channel has wrapped, values are returned oldest to newest
        (starting at the write index); before that, slots ``0 .. write_index-1``.
        """
        if self.__is_full[channel]:
            return np.concatenate(
                (
                    self.__buffer[channel, self.__index[channel] :],
                    self.__buffer[channel, : self.__index[channel]],
                )
            )
        return self.__buffer[channel, : self.__index[channel]]

    def has_value(self, index, channel):
        """Return whether physical slot ``index`` of ``channel`` holds data.

        Negative indices are accepted only once the channel has wrapped.

        Raises
        ------
        IndexError
            If the channel has not wrapped yet and ``0 <= index < size`` is at
            or beyond the write index (that slot has not been written).
        """
        full = self.is_full(channel)
        if index < 0:
            return full
        if index == 0 and full:
            return True
        if index >= self.size:
            return False
        if not full and index >= self.write_index(channel):
            raise IndexError("Index has no value, not written")
        return True

    def valid_range(self, channel: int = 0) -> Tuple[int, int]:
        """
        Return the start index and the extent that are valid within the buffer

        Parameters
        ----------
        channel : int
            channel of the buffer

        Returns
        -------
        Tuple[int, int]
            start, extent of the valid range (start is always 0)
        """
        start = 0
        extend = 0
        if self.__totalcount[channel] == 0:
            return start, extend
        if not self.is_full(channel):
            extend = self.__totalcount[channel]
        else:
            extend = self.size
        return start, extend

    def get(self, index: int = -1, channel: int = 0) -> Tuple[np.double, float]:
        """Return ``(value, timestamp)`` for one sample of ``channel``.

        A non-negative ``index`` addresses a physical slot. A negative
        ``index`` counts back from the most recent sample of ``channel``:
        ``-1`` is the latest, ``-2`` the one before it, down to ``-size``.

        Raises
        ------
        IndexError
            If ``index`` does not refer to a written slot.
        """
        if index < 0:
            if index < -self.size:
                raise IndexError(
                    f"Invalid index {index} on ring buffer for channel{channel}"
                )
            # Resolve relative to this channel's write index; a negative result
            # is only accepted by has_value() once the channel has wrapped, and
            # then numpy's negative indexing lands on the right slot.
            index = self.write_index(channel) + index
        if not self.has_value(index, channel):
            raise IndexError(
                f"Invalid index {index} on ring buffer for channel{channel}"
            )
        return (self.__buffer[channel, index], self.__time[channel, index])

    def read(self, start, extend=1, channel=0):
        """Reads a numpy array from buffer

        Parameters
        ----------
        start : int
            Physical start slot. Negative values are allowed once the channel
            has wrapped and are taken relative to the end of the array.
        extend : int
            Number of slots to read; clamped to the valid range.
        channel : int
            Channel to read from.

        Returns
        -------
        For ``extend == 1``: ``np.array([value, timestamp])`` (from :meth:`get`).
        Otherwise: ``(timestamps, values)``, read in physical slot order and
        wrapping past the end of the array (not re-sorted chronologically).
        NOTE(review): the two return shapes/orders differ; kept for callers.

        Raises
        ------
        IndexError
            On negative ``extend``, an invalid ``start``, or a range that
            exceeds the number of samples written.
        """
        if extend < 0:
            raise IndexError(f"Invalid extend ({extend}) for channel {channel}")
        if not self.is_full(channel):
            if start < 0:
                raise IndexError(f"Invalid start ({start}) for channel {channel}")
        elif start < 0:
            if start < -self.size:
                # Would still be negative after wrapping and yield bogus slices.
                raise IndexError(f"Invalid start ({start}) for channel {channel}")
            start = start + self.size

        if extend == 1:
            return np.array(self.get(start, channel))

        _, vc = self.valid_range(channel)
        if start > self.__totalcount[channel]:
            raise IndexError(
                f"Invalid start index {start} is invalid with totalcount {self.__totalcount[channel]} for channel{channel}"
            )
        if start > self.size:
            raise IndexError(
                f"Invalid start index {start} for buffer with size {self.size}"
            )
        if extend > vc:
            extend = vc
        if (start + extend) > self.__totalcount[channel]:
            raise IndexError(
                f" Invalid range, extended over the totalcount of the buffer {self.__totalcount[channel]}"
            )
        if (start + extend) < self.size:
            return (
                self.__time[channel, start : start + extend],
                self.__buffer[channel, start : start + extend],
            )
        # Range runs past the end of the array: take the tail, then the
        # remaining (extend - (size - start)) slots from the front.
        wrapped = extend - self.size + start
        return (
            np.concatenate(
                (
                    self.__time[channel, start:],
                    self.__time[channel, :wrapped],
                )
            ),
            np.concatenate(
                (
                    self.__buffer[channel, start:],
                    self.__buffer[channel, :wrapped],
                )
            ),
        )