diff --git a/rrcd/router.py b/rrcd/router.py new file mode 100644 index 0000000..5087e20 --- /dev/null +++ b/rrcd/router.py @@ -0,0 +1,825 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +import RNS + +from .codec import decode, encode +from .constants import ( + B_HELLO_CAPS, + B_HELLO_NICK_LEGACY, + B_RES_ENCODING, + B_RES_ID, + B_RES_KIND, + B_RES_SHA256, + B_RES_SIZE, + K_BODY, + K_NICK, + K_ROOM, + K_SRC, + K_T, + T_HELLO, + T_JOIN, + T_JOINED, + T_MSG, + T_NOTICE, + T_PART, + T_PARTED, + T_PING, + T_PONG, + T_RESOURCE_ENVELOPE, +) +from .envelope import make_envelope, validate_envelope +from .util import normalize_nick + +if TYPE_CHECKING: + from .service import HubService + + +class MessageRouter: + """ + Handles message routing and dispatching for the RRC hub. + + This class is responsible for: + - Decoding and validating incoming packets + - Dispatching messages by type (HELLO, JOIN, PART, MSG, NOTICE, PING, etc.) + - Forwarding messages to appropriate rooms/recipients + - Rate limiting + - Protocol validation + """ + + def __init__(self, hub: HubService) -> None: + self.hub = hub + self.log = logging.getLogger("rrcd.router") + + def route_packet( + self, + link: RNS.Link, + data: bytes, + outgoing: list[tuple[RNS.Link, bytes]], + ) -> None: + """ + Main entry point for routing an incoming packet. + + This method should be called with the state lock held. + """ + sess = self.hub.sessions.get(link) + if sess is None: + return + + self.hub._inc("pkts_in") + self.hub._inc("bytes_in", len(data)) + + peer_hash = sess.get("peer") + if peer_hash is None: + ri = link.get_remote_identity() + if ri is None: + # Per spec: the Link is the handshake. Ignore all traffic until it + # is identified. + return + peer_hash = ri.hash + sess["peer"] = peer_hash + + if not self.hub._refill_and_take(link, 1.0): + self.hub._inc("rate_limited") + if self.log.isEnabledFor(logging.DEBUG): + self.log.debug( + "Rate limited peer=%s link_id=%s", + self.hub._fmt_hash(peer_hash), + self.hub._fmt_link_id(link), + ) + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, link, src=self.hub.identity.hash, text="rate limited" + ) + return + + try: + env = decode(data) + validate_envelope(env) + except Exception as e: + self.hub._inc("pkts_bad") + self.log.debug( + "Bad packet peer=%s link_id=%s bytes=%s err=%s", + self.hub._fmt_hash(peer_hash), + self.hub._fmt_link_id(link), + len(data), + e, + ) + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, link, src=self.hub.identity.hash, text=f"bad message: {e}" + ) + return + + t = env.get(K_T) + room = env.get(K_ROOM) + body = env.get(K_BODY) + + if self.log.isEnabledFor(logging.DEBUG): + body_len = None + if isinstance(body, (bytes, bytearray)): + body_len = len(body) + elif isinstance(body, str): + body_len = len(body) + self.log.debug( + "RX peer=%s link_id=%s t=%s room=%r bytes=%s body_type=%s body_len=%s", + self.hub._fmt_hash(peer_hash), + self.hub._fmt_link_id(link), + t, + room, + len(data), + type(body).__name__, + body_len, + ) + + # Dispatch by message type + if t == T_PONG: + self._handle_pong(link, sess) + elif t == T_RESOURCE_ENVELOPE: + self._handle_resource_envelope(link, sess, env, outgoing) + elif not sess["welcomed"]: + self._handle_pre_welcome(link, sess, peer_hash, env, outgoing) + elif t == T_HELLO: + self._handle_re_hello(link, sess, peer_hash, env, outgoing) + elif t == T_JOIN: + self._handle_join(link, sess, peer_hash, env, outgoing) + elif t == T_PART: + self._handle_part(link, sess, peer_hash, env, outgoing) + elif t in (T_MSG, T_NOTICE): + self._handle_message(link, sess, peer_hash, env, outgoing) + elif t == T_PING: + self._handle_ping(link, env, outgoing) + + def _handle_pong(self, link: RNS.Link, sess: dict[str, Any]) -> None: + """Handle PONG message.""" + self.hub._inc("pongs_in") + sess["awaiting_pong"] = None + + def _handle_resource_envelope( + self, + link: RNS.Link, + sess: dict[str, Any], + env: dict, + outgoing: list[tuple[RNS.Link, bytes]], + ) -> None: + """Handle RESOURCE_ENVELOPE message.""" + room = env.get(K_ROOM) + body = env.get(K_BODY) + + if not self.hub.config.enable_resource_transfer: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="resource transfer disabled", + room=room, + ) + return + + if not isinstance(body, dict): + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="invalid resource envelope body", + room=room, + ) + return + + rid = body.get(B_RES_ID) + kind = body.get(B_RES_KIND) + size = body.get(B_RES_SIZE) + sha256 = body.get(B_RES_SHA256) + encoding = body.get(B_RES_ENCODING) + + # Validate required fields + if not isinstance(rid, (bytes, bytearray)): + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="resource envelope missing id", + room=room, + ) + return + + if not isinstance(kind, str) or not kind: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="resource envelope missing kind", + room=room, + ) + return + + if not isinstance(size, int) or size < 0: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="resource envelope invalid size", + room=room, + ) + return + + # Check size limit + if size > self.hub.config.max_resource_bytes: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text=f"resource too large: {size} > {self.hub.config.max_resource_bytes}", + room=room, + ) + return + + # Validate optional fields + if sha256 is not None and not isinstance(sha256, (bytes, bytearray)): + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="resource envelope invalid sha256", + room=room, + ) + return + + if encoding is not None and not isinstance(encoding, str): + encoding = None + + # Add expectation + if not self.hub._add_resource_expectation( + link, + rid=bytes(rid), + kind=kind, + size=size, + sha256=bytes(sha256) if sha256 else None, + encoding=encoding, + room=room, + ): + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="too many pending resource expectations", + room=room, + ) + + def _handle_pre_welcome( + self, + link: RNS.Link, + sess: dict[str, Any], + peer_hash: bytes, + env: dict, + outgoing: list[tuple[RNS.Link, bytes]], + ) -> None: + """Handle messages before WELCOME (only HELLO is allowed).""" + t = env.get(K_T) + nick = env.get(K_NICK) + body = env.get(K_BODY) + + if t != T_HELLO: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, link, src=self.hub.identity.hash, text="send HELLO first" + ) + return + + old_nick = sess.get("nick") + new_nick = None + + if isinstance(nick, str): + n = normalize_nick(nick, max_chars=self.hub.config.nick_max_chars) + if n is not None: + new_nick = n + sess["nick"] = n + + if isinstance(body, dict): + sess["peer_caps"] = self._extract_caps(body) + + # Back-compat: if a legacy client put nick in HELLO body, accept it. + if new_nick is None: + legacy_nick = body.get(B_HELLO_NICK_LEGACY) + n2 = normalize_nick( + legacy_nick, max_chars=self.hub.config.nick_max_chars + ) + if n2 is not None: + new_nick = n2 + sess["nick"] = n2 + + # Update nick index if nick changed + if old_nick != new_nick: + self.hub._update_nick_index(link, old_nick, new_nick) + + self.log.info( + "HELLO peer=%s nick=%r link_id=%s", + self.hub._fmt_hash(peer_hash), + sess.get("nick"), + self.hub._fmt_link_id(link), + ) + + sess["welcomed"] = True + self.hub._queue_welcome( + outgoing, + link, + peer_hash=peer_hash, + motd=self.hub.config.greeting, + ) + + def _handle_re_hello( + self, + link: RNS.Link, + sess: dict[str, Any], + peer_hash: bytes, + env: dict, + outgoing: list[tuple[RNS.Link, bytes]], + ) -> None: + """Handle re-authentication (HELLO after already welcomed).""" + nick = env.get(K_NICK) + body = env.get(K_BODY) + + if self.hub.identity is None: + return + + # Reset session state and process as new HELLO + old_nick = sess.get("nick") + old_rooms = set(sess.get("rooms", set())) + sess["welcomed"] = False + sess["rooms"] = set() + sess["nick"] = None + sess["peer_caps"] = {} + + # Remove this link from all room membership sets and prune empties. + for r in old_rooms: + self.hub.rooms.get(r, set()).discard(link) + if r in self.hub.rooms and not self.hub.rooms[r]: + self.hub.rooms.pop(r, None) + st = self.hub._room_state_get(r) + if st is not None and not st.get("registered"): + self.hub._room_state.pop(r, None) + + new_nick = None + + # Process the HELLO message + if isinstance(nick, str): + n = normalize_nick(nick, max_chars=self.hub.config.nick_max_chars) + if n is not None: + new_nick = n + sess["nick"] = n + + if isinstance(body, dict): + sess["peer_caps"] = self._extract_caps(body) + if new_nick is None: + legacy_nick = body.get(B_HELLO_NICK_LEGACY) + n2 = normalize_nick( + legacy_nick, max_chars=self.hub.config.nick_max_chars + ) + if n2 is not None: + new_nick = n2 + sess["nick"] = n2 + + # Update nick index if nick changed + if old_nick != new_nick: + self.hub._update_nick_index(link, old_nick, new_nick) + + self.log.info( + "Re-HELLO peer=%s nick=%r link_id=%s", + self.hub._fmt_hash(peer_hash), + sess.get("nick"), + self.hub._fmt_link_id(link), + ) + + sess["welcomed"] = True + self.hub._queue_welcome( + outgoing, + link, + peer_hash=peer_hash, + motd=self.hub.config.greeting, + ) + + def _handle_join( + self, + link: RNS.Link, + sess: dict[str, Any], + peer_hash: bytes, + env: dict, + outgoing: list[tuple[RNS.Link, bytes]], + ) -> None: + """Handle JOIN message.""" + room = env.get(K_ROOM) + body = env.get(K_BODY) + + self.hub._inc("joins") + if not isinstance(room, str) or not room: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="JOIN requires room name", + ) + return + + if len(sess["rooms"]) >= int(self.hub.config.max_rooms_per_session): + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, link, src=self.hub.identity.hash, text="too many rooms" + ) + return + + try: + r = self.hub._norm_room(room) + except Exception as e: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, link, src=self.hub.identity.hash, text=str(e) + ) + return + + # If room is registered, load its state now. + if r in self.hub._room_registry: + self.hub._room_state_ensure(r) + + st = self.hub._room_state_ensure(r) + + # +i invite-only + if bool(st.get("invite_only", False)): + is_invited = self.hub._is_invited(st, peer_hash) + if not self.hub._is_room_op(r, peer_hash) and not is_invited: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="invite-only (+i)", + room=r, + ) + return + + # +k key/password (JOIN body must be the key string) + key = st.get("key") + if isinstance(key, str) and key: + is_invited = self.hub._is_invited(st, peer_hash) + if not self.hub._is_room_op(r, peer_hash) and not is_invited: + provided = body if isinstance(body, str) else None + if provided != key: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="bad key (+k)", + room=r, + ) + return + + # Room bans are room-local and apply to JOIN. + if self.hub._is_room_banned(r, peer_hash): + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="banned from room", + room=r, + ) + return + + # If the room doesn't exist yet (in-memory), the first joiner is the founder. + if r not in self.hub.rooms: + self.hub.rooms[r] = set() + self.hub._room_state_ensure(r, founder=peer_hash) + + sess["rooms"].add(r) + self.hub.rooms.setdefault(r, set()).add(link) + + self.log.info( + "JOIN peer=%s nick=%r room=%s link_id=%s", + self.hub._fmt_hash(peer_hash), + sess.get("nick"), + r, + self.hub._fmt_link_id(link), + ) + + self.hub._touch_room(r) + + joined_body = None + if self.hub.config.include_joined_member_list: + members: list[bytes] = [] + for member_link in self.hub.rooms.get(r, set()): + s = self.hub.sessions.get(member_link) + ph = s.get("peer") if s else None + if isinstance(ph, (bytes, bytearray)): + members.append(bytes(ph)) + joined_body = members + + joined = make_envelope( + T_JOINED, src=self.hub.identity.hash, room=r, body=joined_body + ) + self.hub._queue_env(outgoing, link, joined) + + # Consume invite on successful join. + try: + inv = st.get("invited") + if isinstance(inv, dict) and peer_hash in inv: + inv.pop(peer_hash, None) + if bool(st.get("registered")): + self.hub._persist_room_state_to_registry(link, r) + except Exception: + pass + + try: + registered = bool(st.get("registered", False)) + topic = st.get("topic") if isinstance(st.get("topic"), str) else None + mode_txt = self.hub._room_mode_string(r) + topic_txt = topic if topic else "(none)" + reg_txt = "registered" if registered else "unregistered" + self.hub._emit_notice( + outgoing, + link, + r, + f"room {r}: {reg_txt}; mode={mode_txt}; topic={topic_txt}", + ) + except Exception: + pass + + def _handle_part( + self, + link: RNS.Link, + sess: dict[str, Any], + peer_hash: bytes, + env: dict, + outgoing: list[tuple[RNS.Link, bytes]], + ) -> None: + """Handle PART message.""" + room = env.get(K_ROOM) + + self.hub._inc("parts") + if not isinstance(room, str) or not room: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="PART requires room name", + ) + return + + try: + r = self.hub._norm_room(room) + except Exception as e: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, link, src=self.hub.identity.hash, text=str(e) + ) + return + + sess["rooms"].discard(r) + if r in self.hub.rooms: + self.hub.rooms[r].discard(link) + if not self.hub.rooms[r]: + self.hub.rooms.pop(r, None) + st = self.hub._room_state_get(r) + if st is not None: + self.hub._touch_room(r) + if st.get("registered"): + self.hub._persist_room_state_to_registry(link, r) + if st is not None and not st.get("registered"): + self.hub._room_state.pop(r, None) + + # Per spec: acknowledge PART with PARTED. + parted_body = None + if self.hub.config.include_joined_member_list: + members: list[bytes] = [] + for member_link in self.hub.rooms.get(r, set()): + s = self.hub.sessions.get(member_link) + ph = s.get("peer") if s else None + if isinstance(ph, (bytes, bytearray)): + members.append(bytes(ph)) + parted_body = members + + if self.hub.identity is not None: + parted = make_envelope( + T_PARTED, src=self.hub.identity.hash, room=r, body=parted_body + ) + self.hub._queue_env(outgoing, link, parted) + + self.log.info( + "PART peer=%s nick=%r room=%s link_id=%s", + self.hub._fmt_hash(peer_hash), + sess.get("nick"), + r, + self.hub._fmt_link_id(link), + ) + + def _handle_message( + self, + link: RNS.Link, + sess: dict[str, Any], + peer_hash: bytes, + env: dict, + outgoing: list[tuple[RNS.Link, bytes]], + ) -> None: + """Handle MSG and NOTICE messages.""" + t = env.get(K_T) + room = env.get(K_ROOM) + body = env.get(K_BODY) + + # Check for slash commands first, as they may not require a room. + # Per RRC spec, the room field is optional and may be empty. + if isinstance(body, str): + cmdline = body.strip() + if cmdline.startswith("/"): + # It's a slash command - attempt to handle it + if self.log.isEnabledFor(logging.DEBUG): + self.log.debug( + "Slash command peer=%s link_id=%s cmd=%r room=%r", + self.hub._fmt_hash(peer_hash), + self.hub._fmt_link_id(link), + cmdline, + room, + ) + handled = self.hub._handle_operator_command( + link, peer_hash=peer_hash, room=room, text=body, outgoing=outgoing + ) + if handled: + if self.log.isEnabledFor(logging.DEBUG): + self.log.debug( + "Slash command handled, queued=%d responses", + len(outgoing), + ) + return + # Unrecognized slash command - send error + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="unrecognized command", + room=room, + ) + return + + # NOTICE messages are informational/non-conversational and don't require a room. + # MSG messages require a room for delivery. + if t == T_MSG: + if not isinstance(room, str) or not room: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="message requires room name", + ) + return + elif t == T_NOTICE: + # NOTICE without a room is allowed - just don't forward it anywhere + if not isinstance(room, str) or not room: + return + + try: + r = self.hub._norm_room(room) + except Exception as e: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, link, src=self.hub.identity.hash, text=str(e) + ) + return + + if r not in sess["rooms"]: + # +n (no outside messages): when enabled, require membership. + # When disabled (-n), allow sending to existing/registered rooms. + st = None + if r in self.hub._room_registry: + st = self.hub._room_state_ensure(r) + elif r in self.hub.rooms: + st = self.hub._room_state_ensure(r) + + if st is None: + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="no such room", + room=r, + ) + return + + if bool(st.get("no_outside_msgs", False)): + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="no outside messages (+n)", + room=r, + ) + return + + # Per-room moderation: bans and moderated mode. + if self.hub._is_room_banned(r, peer_hash): + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="banned from room", + room=r, + ) + return + if self.hub._room_moderated(r) and not self.hub._is_room_voiced(r, peer_hash): + if self.hub.identity is not None: + self.hub._emit_error( + outgoing, + link, + src=self.hub.identity.hash, + text="room is moderated (+m)", + room=r, + ) + return + + if peer_hash is not None: + env[K_SRC] = ( + bytes(peer_hash) + if isinstance(peer_hash, (bytes, bytearray)) + else peer_hash + ) + env[K_ROOM] = r + + # Preserve the nickname from the incoming envelope if present. + # Fall back to session nickname (from HELLO) if client didn't provide one. + # This allows clients to update their nickname mid-session. + incoming_nick = env.get(K_NICK) + if incoming_nick is not None: + # Client provided a nickname in this message - validate and preserve it + n = normalize_nick(incoming_nick, max_chars=self.hub.config.nick_max_chars) + if n is not None: + # Update session nick and index if it changed + old_session_nick = sess.get("nick") + if old_session_nick != n: + sess["nick"] = n + self.hub._update_nick_index(link, old_session_nick, n) + env[K_NICK] = n + else: + # Invalid nickname provided - remove it + env.pop(K_NICK, None) + else: + # No nickname in message - use session nickname from HELLO if available + nick = sess.get("nick") + n = normalize_nick(nick, max_chars=self.hub.config.nick_max_chars) + if n is not None: + env[K_NICK] = n + + payload = encode(env) + for other in list(self.hub.rooms.get(r, set())): + self.hub._queue_payload(outgoing, other, payload) + + if self.log.isEnabledFor(logging.DEBUG): + self.log.debug( + "Forwarded t=%s peer=%s nick=%r room=%s recipients=%s body_type=%s", + t, + self.hub._fmt_hash(peer_hash), + sess.get("nick"), + r, + len(self.hub.rooms.get(r, set())), + type(body).__name__, + ) + + if t == T_MSG: + self.hub._inc("msgs_forwarded") + else: + self.hub._inc("notices_forwarded") + + def _handle_ping( + self, + link: RNS.Link, + env: dict, + outgoing: list[tuple[RNS.Link, bytes]], + ) -> None: + """Handle PING message.""" + body = env.get(K_BODY) + + self.hub._inc("pings_in") + if self.hub.identity is not None: + pong = make_envelope(T_PONG, src=self.hub.identity.hash, body=body) + self.hub._inc("pongs_out") + self.hub._queue_env(outgoing, link, pong) + + def _extract_caps(self, body: Any) -> dict[int, Any]: + """Extract capabilities from HELLO body.""" + if not isinstance(body, dict): + return {} + caps = body.get(B_HELLO_CAPS) + return caps if isinstance(caps, dict) else {} diff --git a/rrcd/service.py b/rrcd/service.py index 8547d52..23dcd7c 100644 --- a/rrcd/service.py +++ b/rrcd/service.py @@ -12,11 +12,9 @@ from typing import Any import RNS from . import __version__ -from .codec import decode, encode +from .codec import encode from .config import HubRuntimeConfig from .constants import ( - B_HELLO_CAPS, - B_HELLO_NICK_LEGACY, B_RES_ENCODING, B_RES_ID, B_RES_KIND, @@ -24,36 +22,20 @@ from .constants import ( B_RES_SIZE, B_WELCOME_HUB, B_WELCOME_VER, - K_BODY, - K_NICK, - K_ROOM, - K_SRC, - K_T, RES_KIND_BLOB, RES_KIND_MOTD, RES_KIND_NOTICE, T_ERROR, - T_HELLO, - T_JOIN, - T_JOINED, - T_MSG, T_NOTICE, - T_PART, - T_PARTED, T_PING, - T_PONG, T_RESOURCE_ENVELOPE, T_WELCOME, ) -from .envelope import make_envelope, validate_envelope +from .envelope import make_envelope from .logging_config import configure_logging -from .util import expand_path, normalize_nick - - -@dataclass -class _RateState: - tokens: float - last_refill: float +from .router import MessageRouter +from .session import SessionManager +from .util import expand_path @dataclass @@ -81,12 +63,31 @@ class HubService: self._shutdown = threading.Event() + # Message router for handling protocol messages + self.router = MessageRouter(self) + + # Session manager for connection lifecycle + self.session_manager = SessionManager(self) + + @property + def sessions(self) -> dict[RNS.Link, dict[str, Any]]: + """Delegate to session_manager.sessions for backward compatibility.""" + return self.session_manager.sessions + + @property + def _index_by_hash(self) -> dict[bytes, RNS.Link]: + """Delegate to session_manager for backward compatibility.""" + return self.session_manager._index_by_hash + + @property + def _index_by_nick(self) -> dict[str, set[RNS.Link]]: + """Delegate to session_manager for backward compatibility.""" + return self.session_manager._index_by_nick + self.identity: RNS.Identity | None = None self.destination: RNS.Destination | None = None self.rooms: dict[str, set[RNS.Link]] = {} - self.sessions: dict[RNS.Link, dict[str, Any]] = {} - self._rate: dict[RNS.Link, _RateState] = {} # Resource transfer state self._resource_expectations: dict[RNS.Link, dict[bytes, _ResourceExpectation]] = {} @@ -97,11 +98,6 @@ class HubService: self._trusted: set[bytes] = set() self._banned: set[bytes] = set() - # Secondary indexes for efficient link lookups (O(1) instead of O(n)). - # These are maintained alongside sessions and must stay in sync. - self._index_by_hash: dict[bytes, RNS.Link] = {} # identity hash -> link - self._index_by_nick: dict[str, set[RNS.Link]] = {} # normalized nick (lowercase) -> links - # Room state (hub-local conventions; no new on-wire message types). # _room_state holds active in-memory state (and registered state for empty rooms). # _room_registry holds registered rooms loaded from config. @@ -144,11 +140,7 @@ class HubService: "resource_bytes_received": 0, } - def _extract_caps(self, body: Any) -> dict[int, Any]: - if not isinstance(body, dict): - return {} - caps = body.get(B_HELLO_CAPS) - return caps if isinstance(caps, dict) else {} + def _fmt_hash(self, h: Any, *, prefix: int = 12) -> str: if isinstance(h, (bytes, bytearray)): @@ -277,19 +269,8 @@ class HubService: pass def _update_nick_index(self, link: RNS.Link, old_nick: str | None, new_nick: str | None) -> None: - """Update nick index when a nick changes. Must be called under _state_lock.""" - # Remove old nick mapping - if old_nick: - old_key = old_nick.strip().lower() - if old_key in self._index_by_nick: - self._index_by_nick[old_key].discard(link) - if not self._index_by_nick[old_key]: - self._index_by_nick.pop(old_key, None) - - # Add new nick mapping - if new_nick: - new_key = new_nick.strip().lower() - self._index_by_nick.setdefault(new_key, set()).add(link) + """Update nick index when a nick changes. Delegates to SessionManager.""" + self.session_manager.update_nick_index(link, old_nick, new_nick) # Resource transfer methods @@ -982,10 +963,8 @@ class HubService: self._shutdown.set() with self._state_lock: - links = list(self.sessions.keys()) - self.sessions.clear() + links = self.session_manager.clear_all() self.rooms.clear() - self._rate.clear() self._resource_expectations.clear() self._active_resources.clear() @@ -2056,13 +2035,10 @@ class HubService: uptime_s = (now_mono - started_mono) if started_mono is not None else 0.0 with self._state_lock: - sessions_total = len(self.sessions) - sessions_welcomed = sum( - 1 for s in self.sessions.values() if s.get("welcomed") - ) - sessions_identified = sum( - 1 for s in self.sessions.values() if s.get("peer") is not None - ) + session_stats = self.session_manager.get_stats() + sessions_total = session_stats["total"] + sessions_welcomed = session_stats["welcomed"] + sessions_identified = session_stats["identified"] rooms_total = len(self.rooms) memberships = sum(len(v) for v in self.rooms.values()) @@ -3194,19 +3170,7 @@ class HubService: def _on_link(self, link: RNS.Link) -> None: with self._state_lock: - self.sessions[link] = { - "welcomed": False, - "rooms": set(), - "peer": None, - "nick": None, - "peer_caps": {}, - "awaiting_pong": None, - } - - self._rate[link] = _RateState( - tokens=float(self.config.rate_limit_msgs_per_minute), - last_refill=time.monotonic(), - ) + self.session_manager.on_link_established(link) # Initialize resource tracking for this link self._resource_expectations[link] = {} @@ -3243,19 +3207,9 @@ class HubService: self, link: RNS.Link, identity: RNS.Identity | None ) -> None: banned = False + peer_hash = None with self._state_lock: - sess = self.sessions.get(link) - if sess is None: - return - - if identity is not None: - sess["peer"] = identity.hash - - peer_hash = sess.get("peer") - banned = ( - isinstance(peer_hash, (bytes, bytearray)) - and bytes(peer_hash) in self._banned - ) + banned, peer_hash = self.session_manager.on_remote_identified(link, identity) if banned: self.log.warning( @@ -3272,14 +3226,6 @@ class HubService: link.teardown() except Exception: pass - return - - if identity is not None: - self.log.info( - "Remote identified peer=%s link_id=%s", - self._fmt_hash(identity.hash), - self._fmt_link_id(link), - ) def _welcome(self, link: RNS.Link, sess: dict[str, Any]) -> None: if self.identity is None: @@ -3320,34 +3266,12 @@ class HubService: rooms_count = 0 with self._state_lock: - sess = self.sessions.pop(link, None) - self._rate.pop(link, None) - # Clean up resource state self._resource_expectations.pop(link, None) self._active_resources.pop(link, None) - if not sess: - return - - peer = sess.get("peer") - nick = sess.get("nick") - rooms_count = len(sess.get("rooms") or ()) - - # Clean up indexes - if isinstance(peer, (bytes, bytearray)): - self._index_by_hash.pop(bytes(peer), None) - - if nick: - self._update_nick_index(link, nick, None) - - for room in list(sess["rooms"]): - self.rooms.get(room, set()).discard(link) - if room in self.rooms and not self.rooms[room]: - self.rooms.pop(room, None) - st = self._room_state_get(room) - if st is not None and not st.get("registered"): - self._room_state.pop(room, None) + # Clean up session (also handles room cleanup) + peer, nick, rooms_count = self.session_manager.on_link_closed(link) self.log.info( "Link closed peer=%s nick=%r rooms=%s link_id=%s", @@ -3392,23 +3316,8 @@ class HubService: return r def _refill_and_take(self, link: RNS.Link, cost: float = 1.0) -> bool: - with self._state_lock: - state = self._rate.get(link) - if state is None: - return True - - now = time.monotonic() - per_min = float(max(1, int(self.config.rate_limit_msgs_per_minute))) - rate_per_s = per_min / 60.0 - elapsed = max(0.0, now - state.last_refill) - state.tokens = min(per_min, state.tokens + elapsed * rate_per_s) - state.last_refill = now - - if state.tokens < cost: - return False - - state.tokens -= cost - return True + """Token bucket rate limiting. Delegates to SessionManager.""" + return self.session_manager.refill_and_take(link, cost) def _on_packet(self, link: RNS.Link, data: bytes) -> None: # Packet callbacks can occur concurrently with other link callbacks and @@ -3450,677 +3359,12 @@ class HubService: data: bytes, outgoing: list[tuple[RNS.Link, bytes]], ) -> None: - sess = self.sessions.get(link) - if sess is None: - return - - self._inc("pkts_in") - self._inc("bytes_in", len(data)) - - peer_hash = sess.get("peer") - if peer_hash is None: - ri = link.get_remote_identity() - if ri is None: - # Per spec: the Link is the handshake. Ignore all traffic until it - # is identified. - return - peer_hash = ri.hash - sess["peer"] = peer_hash - - if not self._refill_and_take(link, 1.0): - self._inc("rate_limited") - if self.log.isEnabledFor(logging.DEBUG): - self.log.debug( - "Rate limited peer=%s link_id=%s", - self._fmt_hash(peer_hash), - self._fmt_link_id(link), - ) - if self.identity is not None: - self._emit_error( - outgoing, link, src=self.identity.hash, text="rate limited" - ) - return - - try: - env = decode(data) - validate_envelope(env) - except Exception as e: - self._inc("pkts_bad") - self.log.debug( - "Bad packet peer=%s link_id=%s bytes=%s err=%s", - self._fmt_hash(peer_hash), - self._fmt_link_id(link), - len(data), - e, - ) - if self.identity is not None: - self._emit_error( - outgoing, link, src=self.identity.hash, text=f"bad message: {e}" - ) - return - - t = env.get(K_T) - room = env.get(K_ROOM) - body = env.get(K_BODY) - nick = env.get(K_NICK) - - if self.log.isEnabledFor(logging.DEBUG): - body_len = None - if isinstance(body, (bytes, bytearray)): - body_len = len(body) - elif isinstance(body, str): - body_len = len(body) - self.log.debug( - "RX peer=%s link_id=%s t=%s room=%r bytes=%s body_type=%s body_len=%s", - self._fmt_hash(peer_hash), - self._fmt_link_id(link), - t, - room, - len(data), - type(body).__name__, - body_len, - ) - - if t == T_PONG: - self._inc("pongs_in") - sess["awaiting_pong"] = None - return - - if t == T_RESOURCE_ENVELOPE: - # Handle resource envelope announcement - if not self.config.enable_resource_transfer: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="resource transfer disabled", - room=room, - ) - return - - if not isinstance(body, dict): - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="invalid resource envelope body", - room=room, - ) - return - - rid = body.get(B_RES_ID) - kind = body.get(B_RES_KIND) - size = body.get(B_RES_SIZE) - sha256 = body.get(B_RES_SHA256) - encoding = body.get(B_RES_ENCODING) - - # Validate required fields - if not isinstance(rid, (bytes, bytearray)): - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="resource envelope missing id", - room=room, - ) - return - - if not isinstance(kind, str) or not kind: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="resource envelope missing kind", - room=room, - ) - return - - if not isinstance(size, int) or size < 0: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="resource envelope invalid size", - room=room, - ) - return - - # Check size limit - if size > self.config.max_resource_bytes: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text=f"resource too large: {size} > {self.config.max_resource_bytes}", - room=room, - ) - return - - # Validate optional fields - if sha256 is not None and not isinstance(sha256, (bytes, bytearray)): - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="resource envelope invalid sha256", - room=room, - ) - return - - if encoding is not None and not isinstance(encoding, str): - encoding = None - - # Add expectation - if not self._add_resource_expectation( - link, - rid=bytes(rid), - kind=kind, - size=size, - sha256=bytes(sha256) if sha256 else None, - encoding=encoding, - room=room, - ): - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="too many pending resource expectations", - room=room, - ) - return - - if not sess["welcomed"]: - if t != T_HELLO: - if self.identity is not None: - self._emit_error( - outgoing, link, src=self.identity.hash, text="send HELLO first" - ) - return - - old_nick = sess.get("nick") - new_nick = None - - if isinstance(nick, str): - n = normalize_nick(nick, max_chars=self.config.nick_max_chars) - if n is not None: - new_nick = n - sess["nick"] = n - - if isinstance(body, dict): - sess["peer_caps"] = self._extract_caps(body) - - # Back-compat: if a legacy client put nick in HELLO body, accept it. - if new_nick is None: - legacy_nick = body.get(B_HELLO_NICK_LEGACY) - n2 = normalize_nick( - legacy_nick, max_chars=self.config.nick_max_chars - ) - if n2 is not None: - new_nick = n2 - sess["nick"] = n2 - - # Update nick index if nick changed - if old_nick != new_nick: - self._update_nick_index(link, old_nick, new_nick) - - self.log.info( - "HELLO peer=%s nick=%r link_id=%s", - self._fmt_hash(peer_hash), - sess.get("nick"), - self._fmt_link_id(link), - ) - - sess["welcomed"] = True - self._queue_welcome( - outgoing, - link, - peer_hash=peer_hash, - motd=self.config.greeting, - ) - return - - if t == T_HELLO: - # Allow re-authentication if client reconnects with same Link ID - # (can happen when client restarts but RNS reuses deterministic link_id) - if self.identity is not None: - # Reset session state and process as new HELLO - old_nick = sess.get("nick") - old_rooms = set(sess.get("rooms", set())) - sess["welcomed"] = False - sess["rooms"] = set() - sess["nick"] = None - sess["peer_caps"] = {} - - # Remove this link from all room membership sets and prune empties. - for r in old_rooms: - self.rooms.get(r, set()).discard(link) - if r in self.rooms and not self.rooms[r]: - self.rooms.pop(r, None) - st = self._room_state_get(r) - if st is not None and not st.get("registered"): - self._room_state.pop(r, None) - - new_nick = None - - # Process the HELLO message - if isinstance(nick, str): - n = normalize_nick(nick, max_chars=self.config.nick_max_chars) - if n is not None: - new_nick = n - sess["nick"] = n - - if isinstance(body, dict): - sess["peer_caps"] = self._extract_caps(body) - if new_nick is None: - legacy_nick = body.get(B_HELLO_NICK_LEGACY) - n2 = normalize_nick( - legacy_nick, max_chars=self.config.nick_max_chars - ) - if n2 is not None: - new_nick = n2 - sess["nick"] = n2 - - # Update nick index if nick changed - if old_nick != new_nick: - self._update_nick_index(link, old_nick, new_nick) - - self.log.info( - "Re-HELLO peer=%s nick=%r link_id=%s", - self._fmt_hash(peer_hash), - sess.get("nick"), - self._fmt_link_id(link), - ) - - sess["welcomed"] = True - self._queue_welcome( - outgoing, - link, - peer_hash=peer_hash, - motd=self.config.greeting, - ) - return - - if t == T_JOIN: - self._inc("joins") - if not isinstance(room, str) or not room: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="JOIN requires room name", - ) - return - - if len(sess["rooms"]) >= int(self.config.max_rooms_per_session): - if self.identity is not None: - self._emit_error( - outgoing, link, src=self.identity.hash, text="too many rooms" - ) - return - - try: - r = self._norm_room(room) - except Exception as e: - if self.identity is not None: - self._emit_error( - outgoing, link, src=self.identity.hash, text=str(e) - ) - return - - # If room is registered, load its state now. - if r in self._room_registry: - self._room_state_ensure(r) - - st = self._room_state_ensure(r) - - # +i invite-only - if bool(st.get("invite_only", False)): - is_invited = self._is_invited(st, peer_hash) - if not self._is_room_op(r, peer_hash) and not is_invited: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="invite-only (+i)", - room=r, - ) - return - - # +k key/password (JOIN body must be the key string) - key = st.get("key") - if isinstance(key, str) and key: - is_invited = self._is_invited(st, peer_hash) - if not self._is_room_op(r, peer_hash) and not is_invited: - provided = body if isinstance(body, str) else None - if provided != key: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="bad key (+k)", - room=r, - ) - return - - # Room bans are room-local and apply to JOIN. - if self._is_room_banned(r, peer_hash): - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="banned from room", - room=r, - ) - return - - # If the room doesn't exist yet (in-memory), the first joiner is the founder. - if r not in self.rooms: - self.rooms[r] = set() - self._room_state_ensure(r, founder=peer_hash) - - sess["rooms"].add(r) - self.rooms.setdefault(r, set()).add(link) - - self.log.info( - "JOIN peer=%s nick=%r room=%s link_id=%s", - self._fmt_hash(peer_hash), - sess.get("nick"), - r, - self._fmt_link_id(link), - ) - - self._touch_room(r) - - joined_body = None - if self.config.include_joined_member_list: - members: list[bytes] = [] - for member_link in self.rooms.get(r, set()): - s = self.sessions.get(member_link) - ph = s.get("peer") if s else None - if isinstance(ph, (bytes, bytearray)): - members.append(bytes(ph)) - joined_body = members - - joined = make_envelope( - T_JOINED, src=self.identity.hash, room=r, body=joined_body - ) - self._queue_env(outgoing, link, joined) - - # Consume invite on successful join. - try: - inv = st.get("invited") - if isinstance(inv, dict) and peer_hash in inv: - inv.pop(peer_hash, None) - if bool(st.get("registered")): - self._persist_room_state_to_registry(link, r) - except Exception: - pass - - try: - registered = bool(st.get("registered", False)) - topic = st.get("topic") if isinstance(st.get("topic"), str) else None - mode_txt = self._room_mode_string(r) - topic_txt = topic if topic else "(none)" - reg_txt = "registered" if registered else "unregistered" - self._emit_notice( - outgoing, - link, - r, - f"room {r}: {reg_txt}; mode={mode_txt}; topic={topic_txt}", - ) - except Exception: - pass - return - - if t == T_PART: - self._inc("parts") - if not isinstance(room, str) or not room: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="PART requires room name", - ) - return - - try: - r = self._norm_room(room) - except Exception as e: - if self.identity is not None: - self._emit_error( - outgoing, link, src=self.identity.hash, text=str(e) - ) - return - - sess["rooms"].discard(r) - if r in self.rooms: - self.rooms[r].discard(link) - if not self.rooms[r]: - self.rooms.pop(r, None) - st = self._room_state_get(r) - if st is not None: - self._touch_room(r) - if st.get("registered"): - self._persist_room_state_to_registry(link, r) - if st is not None and not st.get("registered"): - self._room_state.pop(r, None) - - # Per spec: acknowledge PART with PARTED. - parted_body = None - if self.config.include_joined_member_list: - members: list[bytes] = [] - for member_link in self.rooms.get(r, set()): - s = self.sessions.get(member_link) - ph = s.get("peer") if s else None - if isinstance(ph, (bytes, bytearray)): - members.append(bytes(ph)) - parted_body = members - - if self.identity is not None: - parted = make_envelope( - T_PARTED, src=self.identity.hash, room=r, body=parted_body - ) - self._queue_env(outgoing, link, parted) - - self.log.info( - "PART peer=%s nick=%r room=%s link_id=%s", - self._fmt_hash(peer_hash), - sess.get("nick"), - r, - self._fmt_link_id(link), - ) - return - - if t in (T_MSG, T_NOTICE): - # Check for slash commands first, as they may not require a room. - # Per RRC spec, the room field is optional and may be empty. - if isinstance(body, str): - cmdline = body.strip() - if cmdline.startswith("/"): - # It's a slash command - attempt to handle it - if self.log.isEnabledFor(logging.DEBUG): - self.log.debug( - "Slash command peer=%s link_id=%s cmd=%r room=%r", - self._fmt_hash(peer_hash), - self._fmt_link_id(link), - cmdline, - room, - ) - handled = self._handle_operator_command( - link, peer_hash=peer_hash, room=room, text=body, outgoing=outgoing - ) - if handled: - if self.log.isEnabledFor(logging.DEBUG): - self.log.debug( - "Slash command handled, queued=%d responses", - len(outgoing), - ) - return - # Unrecognized slash command - send error - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="unrecognized command", - room=room, - ) - return - - # NOTICE messages are informational/non-conversational and don't require a room. - # MSG messages require a room for delivery. - if t == T_MSG: - if not isinstance(room, str) or not room: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="message requires room name", - ) - return - elif t == T_NOTICE: - # NOTICE without a room is allowed - just don't forward it anywhere - if not isinstance(room, str) or not room: - return - - try: - r = self._norm_room(room) - except Exception as e: - if self.identity is not None: - self._emit_error( - outgoing, link, src=self.identity.hash, text=str(e) - ) - return - - if r not in sess["rooms"]: - # +n (no outside messages): when enabled, require membership. - # When disabled (-n), allow sending to existing/registered rooms. - st = None - if r in self._room_registry: - st = self._room_state_ensure(r) - elif r in self.rooms: - st = self._room_state_ensure(r) - - if st is None: - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="no such room", - room=r, - ) - return - - if bool(st.get("no_outside_msgs", False)): - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="no outside messages (+n)", - room=r, - ) - return - - # Per-room moderation: bans and moderated mode. - if self._is_room_banned(r, peer_hash): - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="banned from room", - room=r, - ) - return - if self._room_moderated(r) and not self._is_room_voiced(r, peer_hash): - if self.identity is not None: - self._emit_error( - outgoing, - link, - src=self.identity.hash, - text="room is moderated (+m)", - room=r, - ) - return - - if peer_hash is not None: - env[K_SRC] = ( - bytes(peer_hash) - if isinstance(peer_hash, (bytes, bytearray)) - else peer_hash - ) - env[K_ROOM] = r - - # Preserve the nickname from the incoming envelope if present. - # Fall back to session nickname (from HELLO) if client didn't provide one. - # This allows clients to update their nickname mid-session. - incoming_nick = env.get(K_NICK) - if incoming_nick is not None: - # Client provided a nickname in this message - validate and preserve it - n = normalize_nick(incoming_nick, max_chars=self.config.nick_max_chars) - if n is not None: - # Update session nick and index if it changed - old_session_nick = sess.get("nick") - if old_session_nick != n: - sess["nick"] = n - self._update_nick_index(link, old_session_nick, n) - env[K_NICK] = n - else: - # Invalid nickname provided - remove it - env.pop(K_NICK, None) - else: - # No nickname in message - use session nickname from HELLO if available - nick = sess.get("nick") - n = normalize_nick(nick, max_chars=self.config.nick_max_chars) - if n is not None: - env[K_NICK] = n - - payload = encode(env) - for other in list(self.rooms.get(r, set())): - self._queue_payload(outgoing, other, payload) - - if self.log.isEnabledFor(logging.DEBUG): - self.log.debug( - "Forwarded t=%s peer=%s nick=%r room=%s recipients=%s body_type=%s", - t, - self._fmt_hash(peer_hash), - sess.get("nick"), - r, - len(self.rooms.get(r, set())), - type(body).__name__, - ) - - if t == T_MSG: - self._inc("msgs_forwarded") - else: - self._inc("notices_forwarded") - return - - if t == T_PING: - self._inc("pings_in") - if self.identity is not None: - pong = make_envelope(T_PONG, src=self.identity.hash, body=body) - self._inc("pongs_out") - self._queue_env(outgoing, link, pong) - return - - return + """ + Handle incoming packet with state lock held. + + Delegates to MessageRouter for message routing and dispatching. + """ + self.router.route_packet(link, data, outgoing) def _ping_loop(self) -> None: while not self._shutdown.is_set(): diff --git a/rrcd/session.py b/rrcd/session.py new file mode 100644 index 0000000..4ed4aad --- /dev/null +++ b/rrcd/session.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import RNS + +if TYPE_CHECKING: + from .service import HubService + + +@dataclass +class _RateState: + """Token bucket state for rate limiting.""" + tokens: float + last_refill: float + + +class SessionManager: + """ + Manages session lifecycle for RRC hub connections. + + This class is responsible for: + - Session creation and initialization + - Session state management (nicknames, rooms, capabilities) + - Nickname indexing for efficient lookups + - Rate limiting with token bucket algorithm + - Session cleanup and teardown + - Remote identity tracking + """ + + def __init__(self, hub: HubService) -> None: + self.hub = hub + self.log = logging.getLogger("rrcd.session") + + # Session state storage (keyed by RNS.Link) + self.sessions: dict[RNS.Link, dict[str, Any]] = {} + + # Rate limiting state + self._rate: dict[RNS.Link, _RateState] = {} + + # Secondary indexes for efficient lookups + self._index_by_hash: dict[bytes, RNS.Link] = {} # identity hash -> link + self._index_by_nick: dict[str, set[RNS.Link]] = {} # normalized nick -> links + + def on_link_established(self, link: RNS.Link) -> None: + """ + Handle new link establishment. + + Creates session state and sets up callbacks. + Must be called with state lock held. + """ + self.sessions[link] = { + "welcomed": False, + "rooms": set(), + "peer": None, + "nick": None, + "peer_caps": {}, + "awaiting_pong": None, + } + + self._rate[link] = _RateState( + tokens=float(self.hub.config.rate_limit_msgs_per_minute), + last_refill=time.monotonic(), + ) + + self.log.info("Session created link_id=%s", self.hub._fmt_link_id(link)) + + def on_remote_identified( + self, link: RNS.Link, identity: RNS.Identity | None + ) -> tuple[bool, bytes | None]: + """ + Handle remote identity being established. + + Returns: + (is_banned, peer_hash) tuple + Must be called with state lock held. + """ + sess = self.sessions.get(link) + if sess is None: + return False, None + + if identity is not None: + peer_hash = identity.hash + sess["peer"] = peer_hash + + # Update hash index + self._index_by_hash[bytes(peer_hash)] = link + + # Check if banned + banned = bytes(peer_hash) in self.hub._banned + + if not banned: + self.log.info( + "Remote identified peer=%s link_id=%s", + self.hub._fmt_hash(peer_hash), + self.hub._fmt_link_id(link), + ) + + return banned, peer_hash + + return False, None + + def on_link_closed(self, link: RNS.Link) -> tuple[bytes | None, str | None, int]: + """ + Handle link closure and cleanup. + + Returns: + (peer_hash, nick, rooms_count) for logging + Must be called with state lock held. + """ + sess = self.sessions.pop(link, None) + self._rate.pop(link, None) + + if not sess: + return None, None, 0 + + peer = sess.get("peer") + nick = sess.get("nick") + rooms_count = len(sess.get("rooms") or ()) + + # Clean up indexes + if isinstance(peer, (bytes, bytearray)): + self._index_by_hash.pop(bytes(peer), None) + + if nick: + self.update_nick_index(link, nick, None) + + # Clean up room memberships + for room in list(sess["rooms"]): + self.hub.rooms.get(room, set()).discard(link) + if room in self.hub.rooms and not self.hub.rooms[room]: + self.hub.rooms.pop(room, None) + st = self.hub._room_state_get(room) + if st is not None and not st.get("registered"): + self.hub._room_state.pop(room, None) + + return peer, nick, rooms_count + + def update_nick_index( + self, link: RNS.Link, old_nick: str | None, new_nick: str | None + ) -> None: + """ + Update nickname index when a nick changes. + + Must be called with state lock held. + """ + # Remove old nick mapping + if old_nick: + old_key = old_nick.strip().lower() + if old_key in self._index_by_nick: + self._index_by_nick[old_key].discard(link) + if not self._index_by_nick[old_key]: + self._index_by_nick.pop(old_key, None) + + # Add new nick mapping + if new_nick: + new_key = new_nick.strip().lower() + self._index_by_nick.setdefault(new_key, set()).add(link) + + def refill_and_take(self, link: RNS.Link, cost: float = 1.0) -> bool: + """ + Token bucket rate limiting. + + Refills tokens based on elapsed time and attempts to take `cost` tokens. + Returns True if tokens were available and taken, False if rate limited. + + Must be called with state lock held. + """ + state = self._rate.get(link) + if state is None: + return True + + now = time.monotonic() + per_min = float(max(1, int(self.hub.config.rate_limit_msgs_per_minute))) + rate_per_s = per_min / 60.0 + elapsed = max(0.0, now - state.last_refill) + state.tokens = min(per_min, state.tokens + elapsed * rate_per_s) + state.last_refill = now + + if state.tokens < cost: + return False + + state.tokens -= cost + return True + + def get_session(self, link: RNS.Link) -> dict[str, Any] | None: + """Get session state for a link.""" + return self.sessions.get(link) + + def get_link_by_hash(self, peer_hash: bytes) -> RNS.Link | None: + """Look up link by peer identity hash (O(1)).""" + return self._index_by_hash.get(bytes(peer_hash)) + + def get_links_by_nick(self, nick: str) -> set[RNS.Link]: + """Look up links by normalized nickname (O(1)).""" + key = nick.strip().lower() + return self._index_by_nick.get(key, set()).copy() + + def clear_all(self) -> list[RNS.Link]: + """ + Clear all sessions and return list of links for teardown. + + Must be called with state lock held. + """ + links = list(self.sessions.keys()) + self.sessions.clear() + self._rate.clear() + self._index_by_hash.clear() + self._index_by_nick.clear() + return links + + def get_stats(self) -> dict[str, Any]: + """Get session statistics for monitoring.""" + total = len(self.sessions) + welcomed = sum(1 for s in self.sessions.values() if s.get("welcomed")) + identified = sum(1 for s in self.sessions.values() if s.get("peer") is not None) + + return { + "total": total, + "welcomed": welcomed, + "identified": identified, + "indexed_by_hash": len(self._index_by_hash), + "indexed_by_nick": len(self._index_by_nick), + }