mirror of
https://github.com/kc1awv/rrcd.git
synced 2026-06-08 14:11:53 -07:00
consolidate config management
This commit is contained in:
+129
-1
@@ -1,6 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import asdict, dataclass, replace
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .service import HubService
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -36,3 +42,125 @@ class HubRuntimeConfig:
|
||||
log_file: str | None = None
|
||||
log_format: str = "%(asctime)s %(levelname)s %(name)s[%(threadName)s]: %(message)s"
|
||||
log_datefmt: str | None = None
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""
|
||||
Manages hub configuration loading, reloading, and persistence.
|
||||
|
||||
Handles:
|
||||
- Loading TOML configuration files
|
||||
- Applying configuration updates
|
||||
- Reloading configuration at runtime
|
||||
- Config diffing and comparison
|
||||
- Config file path resolution
|
||||
"""
|
||||
|
||||
def __init__(self, hub: HubService) -> None:
|
||||
self.hub = hub
|
||||
self.log = hub.log
|
||||
self._write_lock = threading.Lock()
|
||||
|
||||
def load_toml(self, path: str) -> dict:
|
||||
"""Load a TOML file and return its contents as a dictionary."""
|
||||
import tomllib
|
||||
|
||||
with open(path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
def apply_config_data(
|
||||
self, base: HubRuntimeConfig, data: dict
|
||||
) -> HubRuntimeConfig:
|
||||
"""Apply configuration data from TOML to a runtime config instance."""
|
||||
hub = data.get("hub") if isinstance(data, dict) else None
|
||||
if isinstance(hub, dict):
|
||||
data = {**data, **hub}
|
||||
|
||||
log_table = data.get("logging") if isinstance(data, dict) else None
|
||||
if isinstance(log_table, dict):
|
||||
mapped: dict[str, object] = {}
|
||||
if "level" in log_table:
|
||||
mapped["log_level"] = log_table.get("level")
|
||||
if "rns_level" in log_table:
|
||||
mapped["log_rns_level"] = log_table.get("rns_level")
|
||||
if "console" in log_table:
|
||||
mapped["log_console"] = log_table.get("console")
|
||||
if "file" in log_table:
|
||||
mapped["log_file"] = log_table.get("file")
|
||||
if "format" in log_table:
|
||||
mapped["log_format"] = log_table.get("format")
|
||||
if "datefmt" in log_table:
|
||||
mapped["log_datefmt"] = log_table.get("datefmt")
|
||||
data = {**data, **mapped}
|
||||
|
||||
allowed = set(asdict(base).keys())
|
||||
# This identifies where to reload from; do not let the file override it.
|
||||
allowed.discard("config_path")
|
||||
|
||||
updates = {k: v for k, v in data.items() if k in allowed}
|
||||
|
||||
for list_key in ("trusted_identities", "banned_identities"):
|
||||
if list_key in updates and isinstance(updates[list_key], list):
|
||||
updates[list_key] = tuple(str(x) for x in updates[list_key])
|
||||
|
||||
if "announce" in data and "announce_on_start" not in updates:
|
||||
try:
|
||||
updates["announce_on_start"] = bool(data["announce"])
|
||||
except Exception:
|
||||
pass
|
||||
if "configdir" in updates and updates["configdir"] == "":
|
||||
updates["configdir"] = None
|
||||
if "greeting" in updates and updates["greeting"] == "":
|
||||
updates["greeting"] = None
|
||||
if "log_file" in updates and updates["log_file"] == "":
|
||||
updates["log_file"] = None
|
||||
if "log_datefmt" in updates and updates["log_datefmt"] == "":
|
||||
updates["log_datefmt"] = None
|
||||
|
||||
return replace(base, **updates) if updates else base
|
||||
|
||||
def format_reload_value(self, v: Any) -> str:
|
||||
"""Format a config value for display in reload summaries."""
|
||||
if v is None:
|
||||
return "(none)"
|
||||
if isinstance(v, (bool, int, float)):
|
||||
return str(v)
|
||||
if isinstance(v, (tuple, list, set)):
|
||||
return f"len={len(v)}"
|
||||
s = str(v)
|
||||
s = " ".join(s.split())
|
||||
if len(s) > 80:
|
||||
s = s[:77] + "..."
|
||||
return s
|
||||
|
||||
def diff_config_summary(
|
||||
self, old: HubRuntimeConfig, new: HubRuntimeConfig
|
||||
) -> list[str]:
|
||||
"""Generate a summary of differences between two config instances."""
|
||||
old_d = asdict(old)
|
||||
new_d = asdict(new)
|
||||
old_d.pop("config_path", None)
|
||||
new_d.pop("config_path", None)
|
||||
|
||||
changed: list[str] = []
|
||||
for k in sorted(new_d.keys()):
|
||||
if old_d.get(k) == new_d.get(k):
|
||||
continue
|
||||
changed.append(
|
||||
f"{k}: {self.format_reload_value(old_d.get(k))} -> {self.format_reload_value(new_d.get(k))}"
|
||||
)
|
||||
return changed
|
||||
|
||||
def get_config_path_for_writes(self) -> str | None:
|
||||
"""Get the resolved config file path for write operations."""
|
||||
from .util import expand_path
|
||||
|
||||
p = self.hub.config.config_path
|
||||
if not p:
|
||||
return None
|
||||
return expand_path(str(p))
|
||||
|
||||
def get_write_lock(self) -> threading.Lock:
|
||||
"""Get the lock used for config file write operations."""
|
||||
return self._write_lock
|
||||
|
||||
+8
-100
@@ -13,7 +13,7 @@ import RNS
|
||||
from . import __version__
|
||||
from .codec import encode
|
||||
from .commands import CommandHandler
|
||||
from .config import HubRuntimeConfig
|
||||
from .config import ConfigManager, HubRuntimeConfig
|
||||
from .constants import (
|
||||
B_WELCOME_HUB,
|
||||
B_WELCOME_VER,
|
||||
@@ -67,6 +67,9 @@ class HubService:
|
||||
|
||||
# Trust manager for trusted/banned identities
|
||||
self.trust_manager = TrustManager(self)
|
||||
|
||||
# Config manager for configuration loading and reloading
|
||||
self.config_manager = ConfigManager(self)
|
||||
|
||||
self.identity: RNS.Identity | None = None
|
||||
self.destination: RNS.Destination | None = None
|
||||
@@ -77,8 +80,6 @@ class HubService:
|
||||
self._announce_thread: threading.Thread | None = None
|
||||
self._resource_cleanup_thread: threading.Thread | None = None
|
||||
|
||||
self._config_write_lock = threading.Lock()
|
||||
|
||||
|
||||
|
||||
def _fmt_hash(self, h: Any, *, prefix: int = 12) -> str:
|
||||
@@ -467,93 +468,6 @@ class HubService:
|
||||
raise ValueError(f"identity hash too short: {text!r}")
|
||||
return b
|
||||
|
||||
def _load_toml(self, path: str) -> dict:
|
||||
import tomllib
|
||||
|
||||
with open(path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
def _apply_config_data(
|
||||
self, base: HubRuntimeConfig, data: dict
|
||||
) -> HubRuntimeConfig:
|
||||
hub = data.get("hub") if isinstance(data, dict) else None
|
||||
if isinstance(hub, dict):
|
||||
data = {**data, **hub}
|
||||
|
||||
log_table = data.get("logging") if isinstance(data, dict) else None
|
||||
if isinstance(log_table, dict):
|
||||
mapped: dict[str, object] = {}
|
||||
if "level" in log_table:
|
||||
mapped["log_level"] = log_table.get("level")
|
||||
if "rns_level" in log_table:
|
||||
mapped["log_rns_level"] = log_table.get("rns_level")
|
||||
if "console" in log_table:
|
||||
mapped["log_console"] = log_table.get("console")
|
||||
if "file" in log_table:
|
||||
mapped["log_file"] = log_table.get("file")
|
||||
if "format" in log_table:
|
||||
mapped["log_format"] = log_table.get("format")
|
||||
if "datefmt" in log_table:
|
||||
mapped["log_datefmt"] = log_table.get("datefmt")
|
||||
data = {**data, **mapped}
|
||||
|
||||
allowed = set(asdict(base).keys())
|
||||
# This identifies where to reload from; do not let the file override it.
|
||||
allowed.discard("config_path")
|
||||
|
||||
updates = {k: v for k, v in data.items() if k in allowed}
|
||||
|
||||
for list_key in ("trusted_identities", "banned_identities"):
|
||||
if list_key in updates and isinstance(updates[list_key], list):
|
||||
updates[list_key] = tuple(str(x) for x in updates[list_key])
|
||||
|
||||
if "announce" in data and "announce_on_start" not in updates:
|
||||
try:
|
||||
updates["announce_on_start"] = bool(data["announce"])
|
||||
except Exception:
|
||||
pass
|
||||
if "configdir" in updates and updates["configdir"] == "":
|
||||
updates["configdir"] = None
|
||||
if "greeting" in updates and updates["greeting"] == "":
|
||||
updates["greeting"] = None
|
||||
if "log_file" in updates and updates["log_file"] == "":
|
||||
updates["log_file"] = None
|
||||
if "log_datefmt" in updates and updates["log_datefmt"] == "":
|
||||
updates["log_datefmt"] = None
|
||||
|
||||
return replace(base, **updates) if updates else base
|
||||
|
||||
def _format_reload_value(self, v: Any) -> str:
|
||||
if v is None:
|
||||
return "(none)"
|
||||
if isinstance(v, (bool, int, float)):
|
||||
return str(v)
|
||||
if isinstance(v, (tuple, list, set)):
|
||||
return f"len={len(v)}"
|
||||
s = str(v)
|
||||
s = " ".join(s.split())
|
||||
if len(s) > 80:
|
||||
s = s[:77] + "..."
|
||||
return s
|
||||
|
||||
def _diff_config_summary(
|
||||
self, old: HubRuntimeConfig, new: HubRuntimeConfig
|
||||
) -> list[str]:
|
||||
old_d = asdict(old)
|
||||
new_d = asdict(new)
|
||||
old_d.pop("config_path", None)
|
||||
new_d.pop("config_path", None)
|
||||
|
||||
changed: list[str] = []
|
||||
for k in sorted(new_d.keys()):
|
||||
if old_d.get(k) == new_d.get(k):
|
||||
continue
|
||||
changed.append(
|
||||
f"{k}: {self._format_reload_value(old_d.get(k))} -> {self._format_reload_value(new_d.get(k))}"
|
||||
)
|
||||
return changed
|
||||
|
||||
def _ensure_worker_threads(self) -> None:
|
||||
# Announce loop
|
||||
if self._announce_thread is None or not self._announce_thread.is_alive():
|
||||
@@ -597,7 +511,7 @@ class HubService:
|
||||
room: str | None,
|
||||
outgoing: list[tuple[RNS.Link, bytes]] | None = None,
|
||||
) -> None:
|
||||
cfg_path = self._config_path_for_writes()
|
||||
cfg_path = self.config_manager.get_config_path_for_writes()
|
||||
if not cfg_path or not os.path.exists(cfg_path):
|
||||
self._emit_notice(
|
||||
outgoing, link, room, "reload failed: config_path not set or missing"
|
||||
@@ -612,8 +526,8 @@ class HubService:
|
||||
|
||||
# Stage config parse
|
||||
try:
|
||||
data = self._load_toml(cfg_path)
|
||||
new_cfg = self._apply_config_data(old_cfg, data)
|
||||
data = self.config_manager.load_toml(cfg_path)
|
||||
new_cfg = self.config_manager.apply_config_data(old_cfg, data)
|
||||
except Exception as e:
|
||||
self._emit_notice(
|
||||
outgoing, link, room, f"reload failed: config parse error: {e}"
|
||||
@@ -671,7 +585,7 @@ class HubService:
|
||||
except Exception:
|
||||
self.log.exception("Failed to reconfigure logging")
|
||||
|
||||
cfg_changes = self._diff_config_summary(old_cfg, new_cfg)
|
||||
cfg_changes = self.config_manager.diff_config_summary(old_cfg, new_cfg)
|
||||
room_changes = self.room_manager.diff_registry_summary(old_registry, new_registry)
|
||||
|
||||
lines: list[str] = []
|
||||
@@ -791,12 +705,6 @@ class HubService:
|
||||
for room in rooms_to_prune:
|
||||
self.log.info("Pruned unused registered room %s", room)
|
||||
|
||||
def _config_path_for_writes(self) -> str | None:
|
||||
p = self.config.config_path
|
||||
if not p:
|
||||
return None
|
||||
return expand_path(str(p))
|
||||
|
||||
def _notice_to(self, link: RNS.Link, room: str | None, text: str) -> None:
|
||||
if self.identity is None:
|
||||
return
|
||||
|
||||
+2
-2
@@ -110,7 +110,7 @@ class TrustManager:
|
||||
outgoing: list[tuple[RNS.Link, bytes]] | None = None,
|
||||
) -> None:
|
||||
"""Persist the current banned identities list to the config file."""
|
||||
cfg_path = self.hub._config_path_for_writes()
|
||||
cfg_path = self.hub.config_manager.get_config_path_for_writes()
|
||||
if not cfg_path:
|
||||
self.hub._emit_notice(
|
||||
outgoing, link, room, "ban updated (not persisted; no config_path)"
|
||||
@@ -129,7 +129,7 @@ class TrustManager:
|
||||
return
|
||||
|
||||
try:
|
||||
with self.hub._config_write_lock:
|
||||
with self.hub.config_manager.get_write_lock():
|
||||
st = None
|
||||
try:
|
||||
st = os.stat(cfg_path)
|
||||
|
||||
Reference in New Issue
Block a user