2
0
forked from awendt/pyrelacs
minipyrelacs/pyrelacs/dataio/circbuffer.py

184 lines
6.1 KiB
Python

from typing import Tuple
import numpy as np
from IPython import embed
class CircBuffer:
    """Fixed-size, multi-channel ring buffer of float64 samples with timestamps.

    Each channel has its own write position, fill state and sample counter,
    so channels are filled independently.  Sample number ``n`` of a channel
    (counted from 0) is stamped with time ``n / samplerate`` (seconds if
    ``samplerate`` is given in Hz).

    Parameters
    ----------
    size : int
        Number of slots per channel; must be positive.
    channels : int
        Number of independent channels.
    samplerate : int
        Rate used to derive timestamps from sample numbers; must be positive.

    Raises
    ------
    ValueError
        If ``size`` or ``samplerate`` is not positive.
    """

    def __init__(self, size: int, channels: int = 1, samplerate: int = 40_000):
        # Reject values that would otherwise only fail later inside append()
        # (modulo by zero / division by zero).
        if size <= 0:
            raise ValueError(f"Buffer size must be positive, got {size}")
        if samplerate <= 0:
            raise ValueError(f"Samplerate must be positive, got {samplerate}")
        self.__size = size
        self.__channels = channels
        # NOTE: attribute name keeps its original spelling ("samplereate").
        self.__samplereate = samplerate
        self.__buffer = np.zeros((channels, size), dtype=np.double)
        self.__time = np.zeros((channels, size), dtype=np.double)
        # Per-channel bookkeeping.
        self.__index = [0] * channels  # next slot to be written
        self.__is_full = [False] * channels  # True once the channel wrapped
        self.__totalcount = [0] * channels  # samples ever appended
        self.__overflows = [0] * channels  # number of wrap-arounds
        # NOTE(review): not referenced by the methods of this class shown
        # here; kept for compatibility.
        self.__read_increment = samplerate * 0.1

    @property
    def size(self):
        """Number of slots per channel."""
        return self.__size

    @property
    def samplerate(self):
        """Rate used to convert sample numbers into timestamps."""
        return self.__samplereate

    @property
    def channel_count(self):
        """Number of channels."""
        return self.__channels

    def totalcount(self, channel: int = 0):
        """Return how many samples were ever appended to ``channel``."""
        return self.__totalcount[channel]

    def is_full(self, channel: int = 0):
        """Return True once ``channel`` has wrapped around at least once."""
        return self.__is_full[channel]

    def write_index(self, channel: int = 0):
        """Return the slot the next sample of ``channel`` will be written to."""
        return self.__index[channel]

    def append(self, item, channel: int = 0):
        """Store ``item`` in the next slot of ``channel``.

        Once the channel is full, the oldest sample is overwritten.
        """
        slot = self.__index[channel]
        self.__buffer[channel, slot] = item
        # Stamp the slot that was just written (not the next one), so that
        # after wrapping the oldest sample still carries its own timestamp.
        # Deriving it from the sample count also avoids accumulating
        # floating-point error from repeated additions.
        self.__time[channel, slot] = self.__totalcount[channel] / self.__samplereate
        self.__totalcount[channel] += 1
        self.__index[channel] = (slot + 1) % self.__size
        if self.__index[channel] == 0:
            self.__is_full[channel] = True
            self.__overflows[channel] += 1

    def get_all(self, channel: int = 0):
        """
        Return all valid values from the specified channel, oldest first.
        """
        write = self.__index[channel]
        row = self.__buffer[channel]
        if self.__is_full[channel]:
            # The next write slot holds the oldest sample once full.
            return np.concatenate((row[write:], row[:write]))
        return row[:write]

    def has_value(self, index, channel: int = 0):
        """Return whether storage slot ``index`` of ``channel`` holds a sample.

        Negative indices count back from the end of the storage (numpy
        style).  They are only valid once the channel is full, and never
        further back than ``-size``.

        Raises
        ------
        IndexError
            If the channel is not full yet and ``index`` is at or beyond the
            current write position.
        """
        if index < 0:
            return self.is_full(channel) and index >= -self.size
        if index >= self.size:
            return False
        if self.is_full(channel):
            # Every slot has been written at least once.
            return True
        if index >= self.write_index(channel):
            raise IndexError("Index has no value, not written")
        return True

    def valid_range(self, channel: int = 0) -> Tuple[int, int]:
        """
        Return the start index and the extend that are valid within the buffer

        Parameters
        ----------
        channel : int
            channel of the buffer

        Returns
        -------
        Tuple[int, int]
            start, extend of the valid range: ``(0, totalcount)`` before the
            channel is full, ``(0, size)`` afterwards.
        """
        # The channel becomes full exactly when totalcount reaches size.
        return 0, min(self.__totalcount[channel], self.size)

    def get(self, index: int = -1, channel: int = 0) -> Tuple[np.double, float]:
        """Return ``(value, time)`` for one sample of ``channel``.

        A non-negative ``index`` addresses a storage slot directly.  A
        negative ``index`` counts back from the most recently written sample:
        ``-1`` (the default) is the newest sample, ``-2`` the one before it,
        and so on.

        Raises
        ------
        IndexError
            If the requested position holds no written sample.
        """
        slot = index
        if index < 0:
            write = self.write_index(channel)
            if not self.is_full(channel):
                slot = write + index
            elif index >= -self.size:
                slot = (write + index) % self.size
            # Otherwise slot stays below -size and has_value() rejects it.
        if not self.has_value(slot, channel):
            raise IndexError(
                f"Invalid index {index} on ring buffer for channel{channel}"
            )
        return (self.__buffer[channel, slot], self.__time[channel, slot])

    def read(self, start: int, extend: int = 1, channel: int = 0):
        """Read ``extend`` consecutive slots of ``channel`` starting at ``start``.

        ``start`` is a slot index into the underlying storage, not a sample
        number.  Ranges running past the end of the storage wrap to slot 0.
        A negative ``start`` counts back from the end of the storage and is
        only allowed once the channel is full.

        Returns
        -------
        For ``extend == 1``: ``np.array([value, time])`` of that slot.
        Otherwise: a ``(times, values)`` tuple of 1-D arrays.  ``extend`` is
        first clipped to the number of valid samples.

        Raises
        ------
        IndexError
            On a negative ``extend``, an out-of-range ``start``, or a range
            extending past the number of samples written.
        """
        if extend < 0:
            raise IndexError(f"Invalid extend ({extend}) for channel {channel}")
        if start < 0:
            # Negative starts wrap only when every slot holds data, and never
            # further back than one buffer length.
            if not self.is_full(channel) or start < -self.size:
                raise IndexError(f"Invalid start ({start}) for channel {channel}")
            start += self.size
        if extend == 1:
            # NOTE(review): single-slot reads return [value, time], the
            # reverse order of the (times, values) tuple returned below.
            return np.array(self.get(start, channel))

        _, valid_count = self.valid_range(channel)
        totalcount = self.__totalcount[channel]
        if start > totalcount:
            raise IndexError(
                f"Invalid start index {start} is invalid with totalcount {totalcount} for channel{channel}"
            )
        if start > self.size:
            raise IndexError(
                f"Invalid start index {start} for buffer with size {self.size}"
            )
        extend = min(extend, valid_count)
        # NOTE(review): compares a storage position with a sample count; this
        # may reject some wrapped reads on a just-filled buffer - confirm intent.
        if (start + extend) > totalcount:
            raise IndexError(
                f" Invalid range, extended over the totalcount of the buffer {totalcount}"
            )

        times = self.__time[channel]
        values = self.__buffer[channel]
        if (start + extend) < self.size:
            return (times[start : start + extend], values[start : start + extend])
        # The range wraps: take the tail of the storage, then the head.
        head = extend - self.size + start
        return (
            np.concatenate((times[start:], times[:head])),
            np.concatenate((values[start:], values[:head])),
        )