mirror of
https://github.com/smittix/intercept.git
synced 2026-04-30 17:49:58 -07:00
feat: ship platform UX and reliability upgrades
This commit is contained in:
210
utils/sse.py
210
utils/sse.py
@@ -1,48 +1,170 @@
|
||||
"""Server-Sent Events (SSE) utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
import time
|
||||
from typing import Any, Generator
|
||||
|
||||
|
||||
def sse_stream(
|
||||
data_queue: queue.Queue,
|
||||
timeout: float = 1.0,
|
||||
keepalive_interval: float = 30.0,
|
||||
stop_check: callable = None
|
||||
) -> Generator[str, None, None]:
|
||||
"""Server-Sent Events (SSE) utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Generator
|
||||
|
||||
|
||||
@dataclass
|
||||
class _QueueFanoutChannel:
|
||||
"""Internal fanout state for a source queue."""
|
||||
source_queue: queue.Queue
|
||||
source_timeout: float
|
||||
subscribers: set[queue.Queue] = field(default_factory=set)
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
distributor: threading.Thread | None = None
|
||||
|
||||
|
||||
_fanout_channels: dict[str, _QueueFanoutChannel] = {}
|
||||
_fanout_channels_lock = threading.Lock()
|
||||
|
||||
|
||||
def _run_fanout(channel: _QueueFanoutChannel) -> None:
|
||||
"""Drain source queue and fan out each message to all subscribers."""
|
||||
while True:
|
||||
try:
|
||||
msg = channel.source_queue.get(timeout=channel.source_timeout)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
with channel.lock:
|
||||
subscribers = tuple(channel.subscribers)
|
||||
|
||||
for subscriber in subscribers:
|
||||
try:
|
||||
subscriber.put_nowait(msg)
|
||||
except queue.Full:
|
||||
# Drop oldest frame for this subscriber and retry once.
|
||||
try:
|
||||
subscriber.get_nowait()
|
||||
subscriber.put_nowait(msg)
|
||||
except (queue.Empty, queue.Full):
|
||||
continue
|
||||
|
||||
|
||||
def _ensure_fanout_channel(
|
||||
channel_key: str,
|
||||
source_queue: queue.Queue,
|
||||
source_timeout: float,
|
||||
) -> _QueueFanoutChannel:
|
||||
"""Get/create a fanout channel and ensure distributor thread is running."""
|
||||
with _fanout_channels_lock:
|
||||
channel = _fanout_channels.get(channel_key)
|
||||
if channel is None:
|
||||
channel = _QueueFanoutChannel(source_queue=source_queue, source_timeout=source_timeout)
|
||||
_fanout_channels[channel_key] = channel
|
||||
|
||||
if channel.distributor is None or not channel.distributor.is_alive():
|
||||
channel.distributor = threading.Thread(
|
||||
target=_run_fanout,
|
||||
args=(channel,),
|
||||
daemon=True,
|
||||
name=f"sse-fanout-{channel_key}",
|
||||
)
|
||||
channel.distributor.start()
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
def subscribe_fanout_queue(
|
||||
source_queue: queue.Queue,
|
||||
channel_key: str,
|
||||
source_timeout: float = 1.0,
|
||||
subscriber_queue_size: int = 500,
|
||||
) -> tuple[queue.Queue, Callable[[], None]]:
|
||||
"""
|
||||
Subscribe a client queue to a shared source queue fanout channel.
|
||||
|
||||
Returns:
|
||||
tuple: (subscriber_queue, unsubscribe_fn)
|
||||
"""
|
||||
channel = _ensure_fanout_channel(channel_key, source_queue, source_timeout)
|
||||
subscriber = queue.Queue(maxsize=subscriber_queue_size)
|
||||
|
||||
with channel.lock:
|
||||
channel.subscribers.add(subscriber)
|
||||
|
||||
def _unsubscribe() -> None:
|
||||
with channel.lock:
|
||||
channel.subscribers.discard(subscriber)
|
||||
|
||||
return subscriber, _unsubscribe
|
||||
|
||||
|
||||
def sse_stream_fanout(
|
||||
source_queue: queue.Queue,
|
||||
channel_key: str,
|
||||
timeout: float = 1.0,
|
||||
keepalive_interval: float = 30.0,
|
||||
stop_check: Callable[[], bool] | None = None,
|
||||
on_message: Callable[[dict[str, Any]], None] | None = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generate an SSE stream from a fanout channel backed by source_queue.
|
||||
"""
|
||||
subscriber, unsubscribe = subscribe_fanout_queue(
|
||||
source_queue=source_queue,
|
||||
channel_key=channel_key,
|
||||
source_timeout=timeout,
|
||||
)
|
||||
last_keepalive = time.time()
|
||||
|
||||
try:
|
||||
while True:
|
||||
if stop_check and stop_check():
|
||||
break
|
||||
|
||||
try:
|
||||
msg = subscriber.get(timeout=timeout)
|
||||
last_keepalive = time.time()
|
||||
if on_message and isinstance(msg, dict):
|
||||
try:
|
||||
on_message(msg)
|
||||
except Exception:
|
||||
pass
|
||||
yield format_sse(msg)
|
||||
except queue.Empty:
|
||||
now = time.time()
|
||||
if now - last_keepalive >= keepalive_interval:
|
||||
yield format_sse({'type': 'keepalive'})
|
||||
last_keepalive = now
|
||||
finally:
|
||||
unsubscribe()
|
||||
|
||||
|
||||
def sse_stream(
|
||||
data_queue: queue.Queue,
|
||||
timeout: float = 1.0,
|
||||
keepalive_interval: float = 30.0,
|
||||
stop_check: Callable[[], bool] | None = None,
|
||||
channel_key: str | None = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generate SSE stream from a queue.
|
||||
|
||||
Args:
|
||||
data_queue: Queue to read messages from
|
||||
timeout: Queue get timeout in seconds
|
||||
keepalive_interval: Seconds between keepalive messages
|
||||
stop_check: Optional callable that returns True to stop the stream
|
||||
|
||||
Yields:
|
||||
SSE formatted strings
|
||||
"""
|
||||
last_keepalive = time.time()
|
||||
|
||||
while True:
|
||||
# Check if we should stop
|
||||
if stop_check and stop_check():
|
||||
break
|
||||
|
||||
try:
|
||||
msg = data_queue.get(timeout=timeout)
|
||||
last_keepalive = time.time()
|
||||
yield format_sse(msg)
|
||||
except queue.Empty:
|
||||
# Send keepalive if enough time has passed
|
||||
now = time.time()
|
||||
if now - last_keepalive >= keepalive_interval:
|
||||
yield format_sse({'type': 'keepalive'})
|
||||
last_keepalive = now
|
||||
Generate SSE stream from a queue.
|
||||
|
||||
Args:
|
||||
data_queue: Queue to read messages from
|
||||
timeout: Queue get timeout in seconds
|
||||
keepalive_interval: Seconds between keepalive messages
|
||||
stop_check: Optional callable that returns True to stop the stream
|
||||
channel_key: Optional fanout key; defaults to stable queue id
|
||||
|
||||
Yields:
|
||||
SSE formatted strings
|
||||
"""
|
||||
key = channel_key or f"queue:{id(data_queue)}"
|
||||
yield from sse_stream_fanout(
|
||||
source_queue=data_queue,
|
||||
channel_key=key,
|
||||
timeout=timeout,
|
||||
keepalive_interval=keepalive_interval,
|
||||
stop_check=stop_check,
|
||||
)
|
||||
|
||||
|
||||
def format_sse(data: dict[str, Any] | str, event: str | None = None) -> str:
|
||||
|
||||
Reference in New Issue
Block a user