Files
intercept/utils/sse.py
2026-02-19 20:46:28 +00:00

212 lines
5.9 KiB
Python

"""Server-Sent Events (SSE) utilities."""
from __future__ import annotations
import json
import queue
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Generator
@dataclass
class _QueueFanoutChannel:
    """Bookkeeping for one source queue that is fanned out to many subscribers."""

    # Queue the distributor thread reads messages from.
    source_queue: queue.Queue
    # Timeout (seconds) passed to each blocking read of ``source_queue``.
    source_timeout: float
    # Per-client queues; every message read from the source is offered to each.
    subscribers: set[queue.Queue] = field(default_factory=set)
    # Guards reads and writes of ``subscribers``.
    lock: threading.Lock = field(default_factory=threading.Lock)
    # Background distributor thread; None until one has been started.
    distributor: threading.Thread | None = field(default=None)
# Process-wide registry of fanout channels, keyed by channel key.
_fanout_channels: dict[str, _QueueFanoutChannel] = {}
# Held while looking up / creating registry entries and (re)starting
# a channel's distributor thread.
_fanout_channels_lock = threading.Lock()
def _run_fanout(channel: _QueueFanoutChannel) -> None:
    """Forward every message from the channel's source queue to all subscribers.

    Runs forever. Each iteration waits up to ``channel.source_timeout`` for a
    message, snapshots the current subscriber set under the channel lock, and
    offers the message to each subscriber without blocking. A subscriber whose
    queue is full has its oldest entry discarded to make room; if the retry
    still fails, the message is skipped for that subscriber only.
    """
    while True:
        # Attributes are read on every pass (not cached in locals) so the loop
        # always sees the channel's current source queue and timeout.
        try:
            message = channel.source_queue.get(timeout=channel.source_timeout)
        except queue.Empty:
            continue

        # Copy under the lock, deliver outside it, so subscribe/unsubscribe
        # calls are never held up by delivery.
        with channel.lock:
            targets = list(channel.subscribers)

        for target in targets:
            try:
                target.put_nowait(message)
                continue
            except queue.Full:
                pass

            # Queue is full: evict the oldest entry and retry exactly once.
            try:
                target.get_nowait()
                target.put_nowait(message)
            except (queue.Empty, queue.Full):
                pass
def _ensure_fanout_channel(
    channel_key: str,
    source_queue: queue.Queue,
    source_timeout: float,
) -> _QueueFanoutChannel:
    """Get or create the fanout channel for ``channel_key`` and keep it running.

    Args:
        channel_key: Registry key identifying the channel.
        source_queue: Queue the channel's distributor should drain.
        source_timeout: Timeout in seconds for each read of ``source_queue``.

    Returns:
        The channel registered under ``channel_key``, with a live distributor
        thread.

    If a channel already exists for the key but was created for a different
    queue object, it is re-pointed at ``source_queue``. Without this, a caller
    that reuses a key after replacing its queue would keep receiving from the
    stale queue. The distributor re-reads the channel's queue on every loop
    iteration, so the switch takes effect once its current read returns.
    """
    with _fanout_channels_lock:
        channel = _fanout_channels.get(channel_key)
        if channel is None:
            channel = _QueueFanoutChannel(
                source_queue=source_queue,
                source_timeout=source_timeout,
            )
            _fanout_channels[channel_key] = channel
        elif channel.source_queue is not source_queue:
            # Same key, new queue object: follow the caller's current queue.
            channel.source_queue = source_queue
            channel.source_timeout = source_timeout

        # Start the distributor on first use, or restart it if it has died.
        worker = channel.distributor
        if worker is None or not worker.is_alive():
            channel.distributor = threading.Thread(
                target=_run_fanout,
                args=(channel,),
                daemon=True,
                name=f"sse-fanout-{channel_key}",
            )
            channel.distributor.start()
        return channel
def subscribe_fanout_queue(
    source_queue: queue.Queue,
    channel_key: str,
    source_timeout: float = 1.0,
    subscriber_queue_size: int = 500,
) -> tuple[queue.Queue, Callable[[], None]]:
    """
    Attach a new client queue to the fanout channel for ``channel_key``.

    Args:
        source_queue: Shared queue the channel's distributor drains
        channel_key: Key identifying the fanout channel
        source_timeout: Timeout in seconds for the distributor's source reads
        subscriber_queue_size: ``maxsize`` of the created client queue

    Returns:
        tuple: (subscriber_queue, unsubscribe_fn)
    """
    fanout = _ensure_fanout_channel(channel_key, source_queue, source_timeout)
    client_queue: queue.Queue = queue.Queue(maxsize=subscriber_queue_size)

    def _unsubscribe() -> None:
        # discard() makes repeated calls harmless.
        with fanout.lock:
            fanout.subscribers.discard(client_queue)

    with fanout.lock:
        fanout.subscribers.add(client_queue)
    return client_queue, _unsubscribe
def sse_stream_fanout(
    source_queue: queue.Queue,
    channel_key: str,
    timeout: float = 1.0,
    keepalive_interval: float = 30.0,
    stop_check: Callable[[], bool] | None = None,
    on_message: Callable[[dict[str, Any]], None] | None = None,
) -> Generator[str, None, None]:
    """
    Generate an SSE stream from a fanout channel backed by source_queue.

    Args:
        source_queue: Shared queue that feeds the fanout channel
        channel_key: Key identifying the fanout channel
        timeout: Seconds to wait for a message before re-checking stop_check
            and the keepalive timer (also used as the channel's source timeout)
        keepalive_interval: Seconds without traffic before a keepalive is sent
        stop_check: Optional callable that returns True to stop the stream
        on_message: Optional hook called with each dict message before it is
            yielded; exceptions it raises are ignored

    Yields:
        SSE formatted strings
    """
    subscriber, unsubscribe = subscribe_fanout_queue(
        source_queue=source_queue,
        channel_key=channel_key,
        source_timeout=timeout,
    )
    # Intervals are measured on the monotonic clock so wall-clock adjustments
    # (NTP steps, manual changes) cannot suppress or spam keepalives.
    last_keepalive = time.monotonic()
    try:
        while True:
            if stop_check and stop_check():
                break
            # Only the blocking get() can legitimately raise queue.Empty.
            try:
                msg = subscriber.get(timeout=timeout)
            except queue.Empty:
                now = time.monotonic()
                if now - last_keepalive >= keepalive_interval:
                    yield format_sse({'type': 'keepalive'})
                    last_keepalive = now
                continue

            # A real message counts as traffic, so it resets the keepalive timer.
            last_keepalive = time.monotonic()
            if on_message and isinstance(msg, dict):
                try:
                    on_message(msg)
                except Exception:
                    # Deliberately best-effort: a failing hook must not
                    # terminate the client's stream.
                    pass
            yield format_sse(msg)
    finally:
        # Runs when the generator is exhausted or closed, so the subscriber
        # queue is always detached from the channel.
        unsubscribe()
def sse_stream(
    data_queue: queue.Queue,
    timeout: float = 1.0,
    keepalive_interval: float = 30.0,
    stop_check: Callable[[], bool] | None = None,
    channel_key: str | None = None,
) -> Generator[str, None, None]:
    """
    Stream the contents of a queue as SSE messages via a fanout channel.

    Args:
        data_queue: Queue to read messages from
        timeout: Queue get timeout in seconds
        keepalive_interval: Seconds between keepalive messages
        stop_check: Optional callable that returns True to stop the stream
        channel_key: Optional fanout key; when missing or empty, a key derived
            from the queue object's id is used

    Yields:
        SSE formatted strings
    """
    if channel_key:
        fanout_key = channel_key
    else:
        # No usable key supplied: derive one from the queue object's identity.
        fanout_key = f"queue:{id(data_queue)}"

    yield from sse_stream_fanout(
        source_queue=data_queue,
        channel_key=fanout_key,
        timeout=timeout,
        keepalive_interval=keepalive_interval,
        stop_check=stop_check,
    )
def format_sse(data: dict[str, Any] | str, event: str | None = None) -> str:
    """
    Format data as SSE message.

    Args:
        data: Data to send (will be JSON encoded if dict). Payloads that
            contain line breaks are emitted as one ``data:`` line per payload
            line, as the event-stream format requires.
        event: Optional event name

    Returns:
        SSE formatted string, terminated by a blank line
    """
    payload = json.dumps(data) if isinstance(data, dict) else f"{data}"

    # A raw CR or LF inside a single "data:" field would end the field (or the
    # whole event) early. Normalise all line endings to LF and split, so each
    # payload line gets its own "data:" prefix; the client rejoins with "\n".
    payload_lines = payload.replace('\r\n', '\n').replace('\r', '\n').split('\n')

    lines = []
    if event:
        lines.append(f"event: {event}")
    lines.extend(f"data: {line}" for line in payload_lines)
    # Two empty entries make the joined text end with the blank line that
    # terminates an SSE event.
    lines.append("")
    lines.append("")
    return '\n'.join(lines)
def clear_queue(q: queue.Queue) -> int:
    """
    Remove every item currently in a queue without blocking.

    Args:
        q: Queue to clear

    Returns:
        Number of items cleared
    """
    removed = 0
    try:
        # Pull items until the queue reports it is empty.
        while True:
            q.get_nowait()
            removed += 1
    except queue.Empty:
        return removed