mirror of
https://github.com/smittix/intercept.git
synced 2026-04-24 22:59:59 -07:00
212 lines
5.9 KiB
Python
212 lines
5.9 KiB
Python
"""Server-Sent Events (SSE) utilities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import queue
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Generator
|
|
|
|
|
|
@dataclass
|
|
class _QueueFanoutChannel:
|
|
"""Internal fanout state for a source queue."""
|
|
source_queue: queue.Queue
|
|
source_timeout: float
|
|
subscribers: set[queue.Queue] = field(default_factory=set)
|
|
lock: threading.Lock = field(default_factory=threading.Lock)
|
|
distributor: threading.Thread | None = None
|
|
|
|
|
|
_fanout_channels: dict[str, _QueueFanoutChannel] = {}
|
|
_fanout_channels_lock = threading.Lock()
|
|
|
|
|
|
def _run_fanout(channel: _QueueFanoutChannel) -> None:
|
|
"""Drain source queue and fan out each message to all subscribers."""
|
|
while True:
|
|
try:
|
|
msg = channel.source_queue.get(timeout=channel.source_timeout)
|
|
except queue.Empty:
|
|
continue
|
|
|
|
with channel.lock:
|
|
subscribers = tuple(channel.subscribers)
|
|
|
|
for subscriber in subscribers:
|
|
try:
|
|
subscriber.put_nowait(msg)
|
|
except queue.Full:
|
|
# Drop oldest frame for this subscriber and retry once.
|
|
try:
|
|
subscriber.get_nowait()
|
|
subscriber.put_nowait(msg)
|
|
except (queue.Empty, queue.Full):
|
|
continue
|
|
|
|
|
|
def _ensure_fanout_channel(
|
|
channel_key: str,
|
|
source_queue: queue.Queue,
|
|
source_timeout: float,
|
|
) -> _QueueFanoutChannel:
|
|
"""Get/create a fanout channel and ensure distributor thread is running."""
|
|
with _fanout_channels_lock:
|
|
channel = _fanout_channels.get(channel_key)
|
|
if channel is None:
|
|
channel = _QueueFanoutChannel(source_queue=source_queue, source_timeout=source_timeout)
|
|
_fanout_channels[channel_key] = channel
|
|
|
|
if channel.distributor is None or not channel.distributor.is_alive():
|
|
channel.distributor = threading.Thread(
|
|
target=_run_fanout,
|
|
args=(channel,),
|
|
daemon=True,
|
|
name=f"sse-fanout-{channel_key}",
|
|
)
|
|
channel.distributor.start()
|
|
|
|
return channel
|
|
|
|
|
|
def subscribe_fanout_queue(
|
|
source_queue: queue.Queue,
|
|
channel_key: str,
|
|
source_timeout: float = 1.0,
|
|
subscriber_queue_size: int = 500,
|
|
) -> tuple[queue.Queue, Callable[[], None]]:
|
|
"""
|
|
Subscribe a client queue to a shared source queue fanout channel.
|
|
|
|
Returns:
|
|
tuple: (subscriber_queue, unsubscribe_fn)
|
|
"""
|
|
channel = _ensure_fanout_channel(channel_key, source_queue, source_timeout)
|
|
subscriber = queue.Queue(maxsize=subscriber_queue_size)
|
|
|
|
with channel.lock:
|
|
channel.subscribers.add(subscriber)
|
|
|
|
def _unsubscribe() -> None:
|
|
with channel.lock:
|
|
channel.subscribers.discard(subscriber)
|
|
|
|
return subscriber, _unsubscribe
|
|
|
|
|
|
def sse_stream_fanout(
|
|
source_queue: queue.Queue,
|
|
channel_key: str,
|
|
timeout: float = 1.0,
|
|
keepalive_interval: float = 30.0,
|
|
stop_check: Callable[[], bool] | None = None,
|
|
on_message: Callable[[dict[str, Any]], None] | None = None,
|
|
) -> Generator[str, None, None]:
|
|
"""
|
|
Generate an SSE stream from a fanout channel backed by source_queue.
|
|
"""
|
|
subscriber, unsubscribe = subscribe_fanout_queue(
|
|
source_queue=source_queue,
|
|
channel_key=channel_key,
|
|
source_timeout=timeout,
|
|
)
|
|
last_keepalive = time.time()
|
|
|
|
try:
|
|
while True:
|
|
if stop_check and stop_check():
|
|
break
|
|
|
|
try:
|
|
msg = subscriber.get(timeout=timeout)
|
|
last_keepalive = time.time()
|
|
if on_message and isinstance(msg, dict):
|
|
try:
|
|
on_message(msg)
|
|
except Exception:
|
|
pass
|
|
yield format_sse(msg)
|
|
except queue.Empty:
|
|
now = time.time()
|
|
if now - last_keepalive >= keepalive_interval:
|
|
yield format_sse({'type': 'keepalive'})
|
|
last_keepalive = now
|
|
finally:
|
|
unsubscribe()
|
|
|
|
|
|
def sse_stream(
|
|
data_queue: queue.Queue,
|
|
timeout: float = 1.0,
|
|
keepalive_interval: float = 30.0,
|
|
stop_check: Callable[[], bool] | None = None,
|
|
channel_key: str | None = None,
|
|
) -> Generator[str, None, None]:
|
|
"""
|
|
Generate SSE stream from a queue.
|
|
|
|
Args:
|
|
data_queue: Queue to read messages from
|
|
timeout: Queue get timeout in seconds
|
|
keepalive_interval: Seconds between keepalive messages
|
|
stop_check: Optional callable that returns True to stop the stream
|
|
channel_key: Optional fanout key; defaults to stable queue id
|
|
|
|
Yields:
|
|
SSE formatted strings
|
|
"""
|
|
key = channel_key or f"queue:{id(data_queue)}"
|
|
yield from sse_stream_fanout(
|
|
source_queue=data_queue,
|
|
channel_key=key,
|
|
timeout=timeout,
|
|
keepalive_interval=keepalive_interval,
|
|
stop_check=stop_check,
|
|
)
|
|
|
|
|
|
def format_sse(data: dict[str, Any] | str, event: str | None = None) -> str:
|
|
"""
|
|
Format data as SSE message.
|
|
|
|
Args:
|
|
data: Data to send (will be JSON encoded if dict)
|
|
event: Optional event name
|
|
|
|
Returns:
|
|
SSE formatted string
|
|
"""
|
|
if isinstance(data, dict):
|
|
data = json.dumps(data)
|
|
|
|
lines = []
|
|
if event:
|
|
lines.append(f"event: {event}")
|
|
lines.append(f"data: {data}")
|
|
lines.append("")
|
|
lines.append("")
|
|
|
|
return '\n'.join(lines)
|
|
|
|
|
|
def clear_queue(q: queue.Queue) -> int:
|
|
"""
|
|
Clear all items from a queue.
|
|
|
|
Args:
|
|
q: Queue to clear
|
|
|
|
Returns:
|
|
Number of items cleared
|
|
"""
|
|
count = 0
|
|
while True:
|
|
try:
|
|
q.get_nowait()
|
|
count += 1
|
|
except queue.Empty:
|
|
break
|
|
return count
|