[buffer] adding edge cases
This commit is contained in:
parent
19c6b90d5c
commit
ff84d63fe1
@ -1,11 +1,18 @@
|
||||
from typing import Tuple
|
||||
import numpy as np
|
||||
|
||||
from pyrelacs.util.logging import config_logging
|
||||
|
||||
log = config_logging()
|
||||
|
||||
|
||||
class CircBuffer:
|
||||
def __init__(self, size: int, channels: int = 1):
|
||||
self._size = size
|
||||
self._channels = channels
|
||||
self._buffer = np.zeros((channels, size), dtype=np.double) # or dtype of your choice
|
||||
self._buffer = np.zeros(
|
||||
(channels, size), dtype=np.double
|
||||
) # or dtype of your choice
|
||||
self._index = [0 for i in range(channels)]
|
||||
self._is_full = [False for i in range(channels)]
|
||||
self._totalcount = [0 for i in range(channels)]
|
||||
@ -38,10 +45,14 @@ class CircBuffer:
|
||||
Return all valid values from the specified channel
|
||||
"""
|
||||
if self._is_full[channel]:
|
||||
return np.concatenate((self._buffer[channel, self._index[channel]:],
|
||||
self._buffer[channel, :self._index[channel]]))
|
||||
return np.concatenate(
|
||||
(
|
||||
self._buffer[channel, self._index[channel] :],
|
||||
self._buffer[channel, : self._index[channel]],
|
||||
)
|
||||
)
|
||||
else:
|
||||
return self._buffer[channel, :self._index[channel]]
|
||||
return self._buffer[channel, : self._index[channel]]
|
||||
|
||||
def has_value(self, index, channel):
|
||||
if index < 0 and self.is_full(channel):
|
||||
@ -52,59 +63,93 @@ class CircBuffer:
|
||||
if index >= self.size:
|
||||
return False
|
||||
|
||||
# test if the ring buffer is at the start but
|
||||
# and the index is greater than the write index
|
||||
if index > self.write_index(channel) and self.is_full(channel):
|
||||
return True
|
||||
elif index > self.write_index(channel) and not self.is_full(channel):
|
||||
raise IndexError("Index has no value, not written")
|
||||
|
||||
if index == self.write_index(channel) and self._totalcount[channel] == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
def valid_range(self, channel: int = 0) -> Tuple[int, int]:
|
||||
"""
|
||||
Return the start index that are valid within the buffer
|
||||
|
||||
def valid_range(self, channel: int = 0):
|
||||
""" Return the start index that are valid within the buffer
|
||||
Parameters
|
||||
----------
|
||||
channel : int
|
||||
channel of the buffer
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]
|
||||
start, extend of the valid range
|
||||
"""
|
||||
start = 0
|
||||
count = 0
|
||||
extend = 0
|
||||
if self._totalcount[channel] == 0:
|
||||
return start, count
|
||||
return start, extend
|
||||
|
||||
if not self.is_full(channel):
|
||||
count = self._totalcount[channel]
|
||||
extend = self._totalcount[channel]
|
||||
else:
|
||||
count = self.size
|
||||
return start, count
|
||||
extend = self.size
|
||||
return start, extend
|
||||
|
||||
def get(self, index: int = -1, channel: int = 0):
|
||||
# easy case first, we can spare the effort of further checking
|
||||
if index >= 0 and index < self.write_index(channel):
|
||||
return self._buffer[channel, index]
|
||||
if self.has_value(index, channel):
|
||||
return self._buffer[channel, index]
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Invalid index {index} on ring buffer for channel{channel}"
|
||||
)
|
||||
|
||||
if index < 0:
|
||||
log.debug("index is smaller than 0")
|
||||
log.debug(f"{self.write_index()}")
|
||||
index = self.write_index() - 1
|
||||
if self.has_value(index, channel):
|
||||
log.debug("index has a value")
|
||||
return self._buffer[channel, index]
|
||||
else:
|
||||
raise IndexError(f"Invalid index {index} on ring buffer for channel{channel}")
|
||||
raise IndexError(
|
||||
f"Invalid index {index} on ring buffer for channel{channel}"
|
||||
)
|
||||
|
||||
def read(self, start, count=1, channel=0):
|
||||
""" Reads a numpy array from buffer"""
|
||||
"""Reads a numpy array from buffer"""
|
||||
if start < 0 or count < 0:
|
||||
raise IndexError(f"Invalid start ({start}) or count ({count}) for channel{channel}")
|
||||
raise IndexError(
|
||||
f"Invalid start ({start}) or count ({count}) for channel{channel}"
|
||||
)
|
||||
|
||||
if count == 1:
|
||||
return np.array(self.get(start, channel))
|
||||
|
||||
vs, vc = self.valid_range(channel)
|
||||
if start > self._totalcount[channel]:
|
||||
raise IndexError(f"Invalid start index {start} is invalid with totalcount {self._totalcount[channel]} for channel{channel}")
|
||||
raise IndexError(
|
||||
f"Invalid start index {start} is invalid with totalcount {self._totalcount[channel]} for channel{channel}"
|
||||
)
|
||||
if start > self.size:
|
||||
raise IndexError(f"Invalid start index {start} for buffer with size {self.size}")
|
||||
raise IndexError(
|
||||
f"Invalid start index {start} for buffer with size {self.size}"
|
||||
)
|
||||
if count > self.size:
|
||||
count = self.size
|
||||
if count > vc:
|
||||
count = vc
|
||||
|
||||
if (start + count) < self.size:
|
||||
return self._buffer[channel, start:start + count]
|
||||
return self._buffer[channel, start : start + count]
|
||||
else:
|
||||
return np.concatenate((self._buffer[channel, start:],
|
||||
self._buffer[channel, :count - self.size + start]))
|
||||
return np.concatenate(
|
||||
(
|
||||
self._buffer[channel, start:],
|
||||
self._buffer[channel, : count - self.size + start],
|
||||
)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user