Files
2026-04-27 00:12:27 +02:00

441 lines
18 KiB
Python

# Based on the original rnsh program by Aaron Heise (@acehoss)
# https://github.com/acehoss/rnsh - MIT License - Copyright (c) 2023 Aaron Heise
# This version of rnsh is included in RNS under the Reticulum License
#
# Reticulum License
#
# Copyright (c) 2016-2026 Mark Qvist
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# - The Software shall not be used in any kind of system which includes amongst
# its functions the ability to purposefully do harm to human beings.
#
# - The Software shall not be used, directly or indirectly, in the creation of
# an artificial intelligence, machine learning or language model training
# dataset, including but not limited to any use that contributes to the
# training or development of such a model or algorithm.
#
# - The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations
import contextlib
import functools
import asyncio
import RNS.Utilities.rnsh.exception as exception
import RNS.Utilities.rnsh.process as process
import RNS.Utilities.rnsh.helpers as helpers
import RNS.Utilities.rnsh.protocol as protocol
import enum
from typing import TypeVar, Generic, Callable, List
from abc import abstractmethod, ABC
from multiprocessing import Manager
import os
import bz2
import RNS
_TLink = TypeVar("_TLink")
_TIdentity = TypeVar("_TIdentity")
class SEType(enum.IntEnum):
SE_LINK_CLOSED = 0
class SessionException(Exception):
def __init__(self, setype: SEType, msg: str, *args):
super().__init__(msg, args)
self.type = setype
class LSState(enum.IntEnum):
LSSTATE_WAIT_IDENT = 1
LSSTATE_WAIT_VERS = 2
LSSTATE_WAIT_CMD = 3
LSSTATE_RUNNING = 4
LSSTATE_ERROR = 5
LSSTATE_TEARDOWN = 6
class LSOutletBase(ABC):
@abstractmethod
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]): raise NotImplemented()
@abstractmethod
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]): raise NotImplemented()
@abstractmethod
def unset_link_closed_callback(self): raise NotImplemented()
@property
@abstractmethod
def rtt(self): raise NotImplemented()
@abstractmethod
def teardown(self): raise NotImplemented()
class ListenerSession:
sessions: List[ListenerSession] = []
allowed_identity_hashes: [any] = []
allowed_file_identity_hashes: [any] = []
allow_all: bool = False
allow_remote_command: bool = False
default_command: [str] = []
remote_cmd_as_args = False
def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop):
RNS.log(f"Session started for {outlet}", RNS.LOG_INFO)
self.outlet = outlet
self.channel = channel
self.outlet.set_initiator_identified_callback(self._initiator_identified)
self.outlet.set_link_closed_callback(self._link_closed)
self.loop = loop
self.state: LSState = None
self.remote_identity = None
self.term: str | None = None
self.stdin_is_pipe: bool = False
self.stdout_is_pipe: bool = False
self.stderr_is_pipe: bool = False
self.tcflags: [any] = None
self.cmdline: [str] = None
self.rows: int = 0
self.cols: int = 0
self.hpix: int = 0
self.vpix: int = 0
self.stdout_buf = bytearray()
self.stdout_eof_sent = False
self.stderr_buf = bytearray()
self.stderr_eof_sent = False
self.return_code: int | None = None
self.return_code_sent = False
self.process: process.CallbackSubprocess | None = None
if self.allow_all: self._set_state(LSState.LSSTATE_WAIT_VERS)
else: self._set_state(LSState.LSSTATE_WAIT_IDENT)
self.sessions.append(self)
protocol.register_message_types(self.channel)
self.channel.add_message_handler(self._handle_message)
def _terminated(self, return_code: int):
self.return_code = return_code
def _set_state(self, state: LSState, timeout_factor: float = 10.0):
timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None
RNS.log(f"Set state: {state.name}, timeout {timeout}", RNS.LOG_DEBUG)
orig_state = self.state
self.state = state
if timeout_factor is not None:
self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout)
def _call(self, func: callable, delay: float = 0):
def call_inner():
if delay == 0: func()
else: self.loop.call_later(delay, func)
self.loop.call_soon_threadsafe(call_inner)
def send(self, message: RNS.MessageBase):
self.channel.send(message)
def _protocol_error(self, name: str):
self.terminate(f"Protocol error ({name})")
def _protocol_timeout_error(self, name: str):
self.terminate(f"Protocol timeout error: {name}")
def terminate(self, error: str = None):
with contextlib.suppress(Exception):
RNS.log("Terminating session" + (f": {error}" if error else ""), RNS.LOG_DEBUG)
if error and self.state != LSState.LSSTATE_TEARDOWN:
with contextlib.suppress(Exception):
self.send(protocol.ErrorMessage(error, True))
self.state = LSState.LSSTATE_ERROR
self._terminate_process()
self._call(self._prune, max(self.outlet.rtt * 3, process.CallbackSubprocess.PROCESS_PIPE_TIME+5))
def _prune(self):
self.state = LSState.LSSTATE_TEARDOWN
RNS.log("Pruning session", RNS.LOG_DEBUG)
with contextlib.suppress(ValueError):
self.sessions.remove(self)
with contextlib.suppress(Exception):
self.outlet.teardown()
def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str):
timeout = True
try: timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition()
except Exception as e: RNS.log(f"Error in protocol timeout: {e}", RNS.LOG_ERROR)
if timeout: self._protocol_timeout_error(name)
def _link_closed(self, outlet: LSOutletBase):
outlet.unset_link_closed_callback()
if outlet != self.outlet:
RNS.log("Link closed received from incorrect outlet", RNS.LOG_DEBUG)
return
RNS.log(f"link_closed {outlet}", RNS.LOG_DEBUG)
self.terminate()
def _initiator_identified(self, outlet, identity):
if outlet != self.outlet:
RNS.log("Identity received from incorrect outlet", RNS.LOG_DEBUG)
return
RNS.log(f"initiator_identified {identity} on link {outlet}", RNS.LOG_INFO)
if self.state not in [LSState.LSSTATE_WAIT_IDENT, LSState.LSSTATE_WAIT_VERS]:
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
if not self.allow_all and identity.hash not in self.allowed_identity_hashes and identity.hash not in self.allowed_file_identity_hashes:
self.terminate("Identity is not allowed.")
self.remote_identity = identity
self._set_state(LSState.LSSTATE_WAIT_VERS)
@classmethod
async def pump_all(cls) -> True:
processed_any = False
for session in cls.sessions:
processed = session.pump()
processed_any = processed_any or processed
await asyncio.sleep(0)
@classmethod
async def terminate_all(cls, reason: str):
for session in cls.sessions:
session.terminate(reason)
await asyncio.sleep(0)
def pump(self) -> bool:
def compress_adaptive(buf: bytes):
comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES
comp_try = 1
comp_success = False
chunk_len = len(buf)
if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN:
chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN
chunk_segment = None
chunk_segment = None
max_data_len = self.channel.mdu - protocol.StreamDataMessage.OVERHEAD
while chunk_len > 32 and comp_try < comp_tries:
chunk_segment_length = int(chunk_len/comp_try)
compressed_chunk = bz2.compress(buf[:chunk_segment_length])
compressed_length = len(compressed_chunk)
if compressed_length < max_data_len and compressed_length < chunk_segment_length:
comp_success = True
break
else:
comp_try += 1
if comp_success:
diff = max_data_len - len(compressed_chunk)
chunk = compressed_chunk
processed_length = chunk_segment_length
else:
chunk = bytes(buf[:max_data_len])
processed_length = len(chunk)
return comp_success, processed_length, chunk
try:
if self.state != LSState.LSSTATE_RUNNING:
return False
elif not self.channel.is_ready_to_send():
return False
elif len(self.stderr_buf) > 0:
comp_success, processed_length, data = compress_adaptive(self.stderr_buf)
self.stderr_buf = self.stderr_buf[processed_length:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
self.stderr_eof_sent = self.stderr_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
data, send_eof, comp_success)
self.send(msg)
if send_eof:
self.stderr_eof_sent = True
return True
elif len(self.stdout_buf) > 0:
comp_success, processed_length, data = compress_adaptive(self.stdout_buf)
self.stdout_buf = self.stdout_buf[processed_length:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
self.stdout_eof_sent = self.stdout_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
data, send_eof, comp_success)
self.send(msg)
if send_eof:
self.stdout_eof_sent = True
return True
elif self.return_code is not None and not self.return_code_sent:
msg = protocol.CommandExitedMessage(self.return_code)
self.send(msg)
self.return_code_sent = True
self._call(functools.partial(self._check_protocol_timeout,
lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"),
max(self.outlet.rtt * 5, 10))
return False
except Exception as e: RNS.log(f"Error during pump: {e}", RNS.LOG_ERROR)
return False
def _terminate_process(self):
with contextlib.suppress(Exception):
if self.process and self.process.running:
self.process.terminate()
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
term: str | None, rows: int, cols: int, hpix: int, vpix: int):
self.cmdline = self.default_command
if not self.allow_remote_command and cmdline and len(cmdline) > 0:
self.terminate("Remote command line not allowed by listener")
return
if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
self.cmdline.extend(cmdline)
elif cmdline and len(cmdline) > 0:
self.cmdline = cmdline
self.stdin_is_pipe = pipe_stdin
self.stdout_is_pipe = pipe_stdout
self.stderr_is_pipe = pipe_stderr
self.tcflags = tcflags
self.term = term
def stdout(data: bytes):
self.stdout_buf.extend(data)
def stderr(data: bytes):
self.stderr_buf.extend(data)
try:
self.process = process.CallbackSubprocess(argv=self.cmdline,
env={"TERM": self.term or os.environ.get("TERM") or "xterm",
"RNS_REMOTE_IDENTITY": (RNS.prettyhexrep(self.remote_identity.hash)
if self.remote_identity and self.remote_identity.hash else "")},
loop=self.loop,
stdout_callback=stdout,
stderr_callback=stderr,
terminated_callback=self._terminated,
stdin_is_pipe=self.stdin_is_pipe,
stdout_is_pipe=self.stdout_is_pipe,
stderr_is_pipe=self.stderr_is_pipe)
self.process.start()
self._set_window_size(rows, cols, hpix, vpix)
except Exception as e:
RNS.log(f"Unable to start process for link {self.outlet}: {e}", RNS.LOG_ERROR)
self.terminate("Unable to start process")
def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int):
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
with contextlib.suppress(Exception):
self.process.set_winsize(rows, cols, hpix, vpix)
def _received_stdin(self, data: bytes, eof: bool):
if data and len(data) > 0:
self.process.write(data)
if eof:
self.process.close_stdin()
def _handle_message(self, message: RNS.MessageBase):
if self.state == LSState.LSSTATE_WAIT_IDENT:
# Ignore any messages until the initiator has identified to avoid race conditions
# between identity announcement and early protocol messages.
RNS.log("Ignoring message while waiting for identification", RNS.LOG_DEBUG)
return
if self.state == LSState.LSSTATE_WAIT_VERS:
if not isinstance(message, protocol.VersionInfoMessage):
self._protocol_error(self.state.name)
return
RNS.log(f"Version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}", RNS.LOG_VERBOSE)
if message.protocol_version != protocol.PROTOCOL_VERSION:
self.terminate("Incompatible protocol")
return
self.send(protocol.VersionInfoMessage())
self._set_state(LSState.LSSTATE_WAIT_CMD)
return
elif self.state == LSState.LSSTATE_WAIT_CMD:
if not isinstance(message, protocol.ExecuteCommandMesssage):
return self._protocol_error(self.state.name)
RNS.log(f"Execute command message on link {self.outlet}: {message.cmdline}", RNS.LOG_VERBOSE)
self._set_state(LSState.LSSTATE_RUNNING)
self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr,
message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix)
return
elif self.state == LSState.LSSTATE_RUNNING:
if isinstance(message, protocol.WindowSizeMessage):
self._set_window_size(message.rows, message.cols, message.hpix, message.vpix)
elif isinstance(message, protocol.StreamDataMessage):
if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN:
RNS.log(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}", RNS.LOG_ERROR)
return self._protocol_error(self.state.name)
self._received_stdin(message.data, message.eof)
return
elif isinstance(message, protocol.NoopMessage):
# echo noop only on listener--used for keepalive/connectivity check
self.send(message)
return
elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]:
RNS.log(f"Received packet, but in state {self.state.name}", RNS.LOG_ERROR)
return
else:
self._protocol_error("unexpected message")
return
class RNSOutlet(LSOutletBase):
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
def inner_cb(link, identity: _TIdentity):
cb(self, identity)
self.link.set_remote_identified_callback(inner_cb)
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
def inner_cb(link):
cb(self)
self.link.set_link_closed_callback(inner_cb)
def unset_link_closed_callback(self):
self.link.set_link_closed_callback(None)
def teardown(self):
self.link.teardown()
@property
def rtt(self) -> float:
return self.link.rtt
def __str__(self):
return f"Outlet RNS Link {self.link}"
def __init__(self, link: RNS.Link):
self.link = link
link.lsoutlet = self
@staticmethod
def get_outlet(link: RNS.Link):
if hasattr(link, "lsoutlet"):
return link.lsoutlet
return RNSOutlet(link)