__all__ = ["CommandRecord", "metered_bytes", "append_inputs_gen"]
from asyncio import Queue, Task, create_task, sleep
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import AsyncIterable, Iterable, Self
from streamstore.schemas import ONE_MIB, AppendInput, Record, SequencedRecord
[docs]
class CommandRecord:
"""
Factory class for creating `command records <https://s2.dev/docs/stream#command-records>`_.
"""
FENCE = b"fence"
TRIM = b"trim"
[docs]
@staticmethod
def fence(token: str) -> Record:
"""
Create a fence command record.
Args:
token: Fencing token. Its UTF-8 byte count must not exceed 36 bytes. If empty, clears
the previously set token.
"""
encoded_token = token.encode()
if len(encoded_token) > 36:
raise ValueError("UTF-8 byte count of fencing token exceeds 36 bytes")
return Record(body=encoded_token, headers=[(bytes(), CommandRecord.FENCE)])
[docs]
@staticmethod
def trim(desired_first_seq_num: int) -> Record:
"""
Create a trim command record.
Args:
desired_first_seq_num: Sequence number for the first record to exist after trimming
preceeding records in the stream.
Note:
If ``desired_first_seq_num`` was smaller than the sequence number for the first existing
record in the stream, trimming doesn't happen.
"""
return Record(
body=desired_first_seq_num.to_bytes(8),
headers=[(bytes(), CommandRecord.TRIM)],
)
[docs]
def metered_bytes(records: Iterable[Record | SequencedRecord]) -> int:
"""
Each record is metered using the following formula:
.. code-block:: python
8 + 2 * len(headers)
+ sum((len(name) + len(value)) for (name, value) in headers)
+ len(body)
"""
return sum(
(
8
+ 2 * len(record.headers)
+ sum((len(name) + len(value)) for (name, value) in record.headers)
+ len(record.body)
)
for record in records
)
@dataclass(slots=True)
class _AutoBatcher:
_next_batch_idx: int = field(init=False)
_next_batch: list[Record] = field(init=False)
_next_batch_count: int = field(init=False)
_next_batch_bytes: int = field(init=False)
_linger_queue: Queue[tuple[int, datetime]] | None = field(init=False)
_linger_handler_task: Task | None = field(init=False)
_limits_handler_task: Task | None = field(init=False)
append_input_queue: Queue[AppendInput | None]
match_seq_num: int | None
fencing_token: str | None
max_records_per_batch: int
max_bytes_per_batch: int
max_linger_per_batch: timedelta | None
def __post_init__(self) -> None:
self._next_batch_idx = 0
self._next_batch = []
self._next_batch_count = 0
self._next_batch_bytes = 0
self._linger_queue = Queue() if self.max_linger_per_batch is not None else None
self._linger_handler_task = None
self._limits_handler_task = None
def _accumulate(self, record: Record) -> None:
self._next_batch.append(record)
self._next_batch_count += 1
self._next_batch_bytes += metered_bytes([record])
def _next_append_input(self) -> AppendInput:
append_input = AppendInput(
records=list(self._next_batch),
match_seq_num=self.match_seq_num,
fencing_token=self.fencing_token,
)
self._next_batch.clear()
self._next_batch_count = 0
self._next_batch_bytes = 0
self._next_batch_idx += 1
if self.match_seq_num is not None:
self.match_seq_num = self.match_seq_num + len(append_input.records)
return append_input
async def linger_handler(self) -> None:
if self.max_linger_per_batch is None:
return
if self._linger_queue is None:
return
linger_duration = self.max_linger_per_batch.total_seconds()
prev_linger_start = None
while True:
batch_idx, linger_start = await self._linger_queue.get()
if batch_idx < self._next_batch_idx:
continue
if prev_linger_start is None:
prev_linger_start = linger_start
missed_duration = (linger_start - prev_linger_start).total_seconds()
await sleep(max(linger_duration - missed_duration, 0))
if batch_idx == self._next_batch_idx:
append_input = self._next_append_input()
await self.append_input_queue.put(append_input)
prev_linger_start = linger_start
def _limits_met(self, record: Record) -> bool:
if (
self._next_batch_count + 1 <= self.max_records_per_batch
and self._next_batch_bytes + metered_bytes([record])
<= self.max_bytes_per_batch
):
return False
return True
async def limits_handler(self, records: AsyncIterable[Record]) -> None:
async for record in records:
if self._limits_met(record):
append_input = self._next_append_input()
await self.append_input_queue.put(append_input)
self._accumulate(record)
if self._linger_queue is not None and len(self._next_batch) == 1:
await self._linger_queue.put((self._next_batch_idx, datetime.now()))
if len(self._next_batch) != 0:
append_input = self._next_append_input()
await self.append_input_queue.put(append_input)
await self.append_input_queue.put(None)
def run(self, records: AsyncIterable[Record]) -> None:
if self.max_linger_per_batch is not None:
self._linger_handler_task = create_task(self.linger_handler())
self._limits_handler_task = create_task(self.limits_handler(records))
def cancel(self) -> None:
if self._linger_handler_task is not None:
self._linger_handler_task.cancel()
if self._limits_handler_task is not None:
self._limits_handler_task.cancel()
@dataclass(slots=True)
class _AppendInputAsyncIterator:
append_input_queue: Queue[AppendInput | None]
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> AppendInput:
append_input = await self.append_input_queue.get()
if append_input is None:
raise StopAsyncIteration
return append_input