Files
Reticulum/tests/channel.py
T
Jeremy O'Brien 794e437f6d Channel: prevent sequence holes and ghost envelopes when sending on a dying outlet
RNSChannelOutlet.send() can return a packet that never reached the wire
(link not ACTIVE, no capable interface, etc). The old Channel.send()
queued the envelope in _tx_ring before calling outlet.send(), then
tried to rewind _next_sequence and remove the envelope if the outlet
returned a failed packet. Two problems:

- Between queueing and outlet.send() returning, _tx_ring held an
  envelope whose packet.raw was None. Any worker thread iterating the
  ring (timeout fire, proof callback) crashed in get_packet_id's
  packet.get_hash() with a TypeError, because raw was None.

- The rewind was only safe for a single-threaded sender: it checked
"is _next_sequence one past mine?" and skipped the rewind otherwise.
Under concurrent senders, the rewind silently failed, leaving a
hole in the on-wire sequence stream. The receiver's contiguous
seqnum rule then stalled the channel permanently with no error.

This fix serializes the reservation-and-transmit pair with a per-channel
_send_lock so the rewind is always correct, and defers queueing until
outlet.send() returns a real packet so _tx_ring never contains a
packet-less envelope. _packet_tx_op() and get_packet_id() now also
defensively skip/return-None for packet-less envelopes.

Also handle the small race where a proof arrives between outlet.send()
registering the receipt and us installing the delivery callback: after
registration, re-read the receipt status and synthesize the
_packet_delivered() call if it's already DELIVERED.
2026-05-19 14:52:59 -04:00

550 lines
18 KiB
Python

from __future__ import annotations
import threading
import RNS
from RNS.Channel import MessageState, ChannelOutletBase, Channel, MessageBase
import RNS.Buffer
from RNS.vendor import umsgpack
from typing import Callable
import contextlib
import typing
import types
import time
import uuid
import unittest
class Packet:
    """In-memory stand-in for a network packet, used by the channel tests.

    Records the delivery state and the number of send attempts, and emulates
    a receipt timeout by starting a daemon timer thread on every ``send()``.
    """

    # Default timeout in seconds. ProtocolHarness.__init__ overwrites this
    # class attribute; set_timeout() may shadow it per instance.
    timeout = 1.0

    def __init__(self, raw: bytes | None):
        """Create an unsent packet wrapping ``raw``.

        ``raw`` may be None: the failing-outlet test builds ``Packet(None)``
        to model a packet that never reached the wire.
        """
        self.state = MessageState.MSGSTATE_NEW
        self.raw = raw
        self.packet_id = uuid.uuid4()
        self.tries = 0            # incremented by every send()
        self.timeout_id = None    # token of the armed timer; None when disarmed
        self.lock = threading.RLock()
        self.instances = 0        # number of timer threads currently alive
        self.timeout_callback: Callable[[Packet], None] | None = None
        self.delivered_callback: Callable[[Packet], None] | None = None

    def set_timeout(self, callback: Callable[[Packet], None] | None, timeout: float | None):
        """Install the timeout callback; override the timeout when one is given.

        A ``timeout`` of None leaves the current timeout value untouched.
        """
        with self.lock:
            if timeout is not None:
                self.timeout = timeout
            self.timeout_callback = callback

    def send(self):
        """Count a send attempt, mark the packet SENT and arm a fresh timer."""
        self.tries += 1
        self.state = MessageState.MSGSTATE_SENT

        def elapsed(timeout: float, timeout_id: uuid.UUID):
            # Body of the timer thread. ``instances`` counts live timer
            # threads so tests can check that no timer is left running.
            with self.lock:
                self.instances += 1
            try:
                time.sleep(timeout)
                with self.lock:
                    # Only the most recently armed timer may fire; a token
                    # mismatch means the packet was re-sent, delivered or
                    # had its timeout cleared while this thread slept.
                    if self.timeout_id == timeout_id:
                        self.timeout_id = None
                        self.state = MessageState.MSGSTATE_FAILED
                        if self.timeout_callback:
                            self.timeout_callback(self)
            finally:
                with self.lock:
                    self.instances -= 1

        # A new token invalidates any timer armed by an earlier send().
        self.timeout_id = uuid.uuid4()
        threading.Thread(target=elapsed, name="Packet Timeout", args=[self.timeout, self.timeout_id],
                         daemon=True).start()

    def clear_timeout(self):
        """Disarm the pending timer; its thread will wake up and do nothing."""
        self.timeout_id = None

    def set_delivered_callback(self, callback: Callable[[Packet], None] | None):
        """Install (or clear, with None) the callback run by ``delivered()``."""
        self.delivered_callback = callback

    def delivered(self):
        """Mark the packet DELIVERED, disarm its timer and notify the listener."""
        with self.lock:
            self.state = MessageState.MSGSTATE_DELIVERED
            self.timeout_id = None
        if self.delivered_callback:
            self.delivered_callback(self)
class ChannelOutletTest(ChannelOutletBase):
    """Channel outlet backed by an in-memory list instead of a real link.

    Every packet created by ``send()`` is appended to ``self.packets`` so the
    tests can inspect it and drive delivery or timeout by hand.
    """

    def __init__(self, mdu: int, rtt: float):
        """Create an outlet reporting the given ``mdu`` and ``rtt``."""
        self.link_id = uuid.uuid4()
        self.timeout_callbacks = 0  # number of timed_out() calls
        self._mdu = mdu
        self._rtt = rtt
        self._usable = True
        self.packets = []           # every packet created by send(), in order
        self.lock = threading.RLock()
        self.packet_callback: Callable[[ChannelOutletBase, bytes], None] | None = None

    def get_packet_state(self, packet: Packet) -> MessageState:
        """Return the state stored on ``packet``."""
        return packet.state

    def set_packet_timeout_callback(self, packet: Packet, callback: Callable[[Packet], None] | None,
                                    timeout: float | None = None):
        """Forward the timeout callback and optional timeout to ``packet``."""
        packet.set_timeout(callback, timeout)

    def set_packet_delivered_callback(self, packet: Packet, callback: Callable[[Packet], None] | None):
        """Forward the delivered callback to ``packet``."""
        packet.set_delivered_callback(callback)

    def get_packet_id(self, packet: Packet) -> typing.Any:
        """Return the identifier assigned to ``packet`` at construction."""
        return packet.packet_id

    def send(self, raw: bytes) -> Packet:
        """Wrap ``raw`` in a new Packet, send it, record it and return it."""
        with self.lock:
            packet = Packet(raw)
            packet.send()
            self.packets.append(packet)
        return packet

    def resend(self, packet: Packet) -> Packet:
        """Send an existing packet again; it is not recorded a second time."""
        with self.lock:
            packet.send()
        return packet

    @property
    def mdu(self):
        """The ``mdu`` value given to the constructor."""
        return self._mdu

    @property
    def rtt(self):
        """The ``rtt`` value given to the constructor."""
        return self._rtt

    @property
    def is_usable(self):
        """Usability flag; True unless a test changes ``_usable``."""
        return self._usable

    def timed_out(self):
        """Count a timeout notification.

        NOTE(review): presumably invoked by Channel when a message exhausts
        its retries - confirm against RNS.Channel.
        """
        self.timeout_callbacks += 1

    def __str__(self):
        return str(self.link_id)
class MessageTest(MessageBase):
    """Test message that carries an ``id`` and a ``data`` string."""

    MSGTYPE = 0xabcd

    def __init__(self):
        self.id = str(uuid.uuid4())
        self.data = "test"
        # Never packed or unpacked, so a received copy gets a different value;
        # the send/receive test relies on that to prove a new object was built.
        self.not_serialized = str(uuid.uuid4())

    def pack(self) -> bytes:
        """Serialize ``(id, data)`` with umsgpack."""
        payload = (self.id, self.data)
        return umsgpack.packb(payload)

    def unpack(self, raw):
        """Restore ``id`` and ``data`` from bytes produced by ``pack()``."""
        unpacked_id, unpacked_data = umsgpack.unpackb(raw)
        self.id = unpacked_id
        self.data = unpacked_data
class SystemMessage(MessageBase):
    """Empty message whose MSGTYPE the public registration call must reject.

    test_system_message_check expects ``register_message_type`` to raise
    ChannelException for this class, while ``_register_message_type`` with
    ``is_system_type=True`` accepts it.
    """

    MSGTYPE = 0xf000

    def pack(self) -> bytes:
        """Return an empty payload."""
        return b""

    def unpack(self, raw):
        """Ignore ``raw``; this message carries no data."""
        return None
class ProtocolHarness(contextlib.AbstractContextManager):
    """Bundle a Channel with its test outlet and shut it down on exit."""

    def __init__(self, rtt: float):
        """Build a 500-byte-MDU test outlet and a Channel on top of it."""
        test_outlet = ChannelOutletTest(mdu=500, rtt=rtt)
        self.outlet = test_outlet
        self.channel = Channel(test_outlet)
        # Align the fake packets' default timeout with the channel's
        # first-try timeout. This writes the Packet *class* attribute.
        first_try_timeout = self.channel._get_packet_timeout_time(1)
        Packet.timeout = first_try_timeout

    def cleanup(self):
        """Shut the channel down."""
        self.channel._shutdown()

    def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
                 __traceback: types.TracebackType) -> bool:
        """Clean up on leaving the ``with`` block; never suppress exceptions."""
        self.cleanup()
        return False
class TestChannel(unittest.TestCase):
    """Tests for RNS.Channel and RNS.Buffer driven through ChannelOutletTest.

    Test methods carry ``#`` comments rather than docstrings on purpose:
    unittest prints the first docstring line instead of the test name when
    run verbosely, which would change the test report.
    """

    def setUp(self) -> None:
        print("")
        self.rtt = 0.01
        # A fresh harness (outlet + channel) for every test.
        self.h = ProtocolHarness(self.rtt)

    def tearDown(self) -> None:
        self.h.cleanup()

    def test_send_one_retry(self):
        # Two timeouts cause two resends of the same packet; delivery then
        # stops further retries and untracks the envelope.
        print("Channel test one retry")
        message = MessageTest()

        self.assertEqual(0, len(self.h.outlet.packets))

        envelope = self.h.channel.send(message)

        self.assertIsNotNone(envelope)
        self.assertIsNotNone(envelope.raw)
        self.assertEqual(1, len(self.h.outlet.packets))
        self.assertIsNotNone(envelope.packet)
        self.assertIn(envelope, self.h.channel._tx_ring)
        self.assertTrue(envelope.tracked)

        packet = self.h.outlet.packets[0]

        self.assertEqual(envelope.packet, packet)
        self.assertEqual(1, envelope.tries)
        self.assertEqual(1, packet.tries)
        self.assertEqual(1, packet.instances)
        self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
        self.assertEqual(envelope.raw, packet.raw)

        # First timeout: expect one resend of the same packet object.
        time.sleep(self.h.channel._get_packet_timeout_time(1) * 1.1)

        self.assertEqual(1, len(self.h.outlet.packets))
        self.assertEqual(2, envelope.tries)
        self.assertEqual(2, packet.tries)
        self.assertEqual(1, packet.instances)

        # Second timeout: expect a second resend.
        time.sleep(self.h.channel._get_packet_timeout_time(2) * 1.1)

        self.assertEqual(1, len(self.h.outlet.packets))
        self.assertEqual(self.h.outlet.packets[0], packet)
        self.assertEqual(3, envelope.tries)
        self.assertEqual(3, packet.tries)
        self.assertEqual(1, packet.instances)
        self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)

        packet.delivered()

        self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)

        # After delivery the pending timer must expire without a resend.
        time.sleep(self.h.channel._get_packet_timeout_time(3) * 1.1)

        self.assertEqual(1, len(self.h.outlet.packets))
        self.assertEqual(3, envelope.tries)
        self.assertEqual(3, packet.tries)
        self.assertEqual(0, packet.instances)
        self.assertFalse(envelope.tracked)

    def test_send_timeout(self):
        # With no delivery, the message is tried five times and then fails.
        print("Channel test retry count exceeded")
        message = MessageTest()

        self.assertEqual(0, len(self.h.outlet.packets))

        envelope = self.h.channel.send(message)

        self.assertIsNotNone(envelope)
        self.assertIsNotNone(envelope.raw)
        self.assertEqual(1, len(self.h.outlet.packets))
        self.assertIsNotNone(envelope.packet)
        self.assertIn(envelope, self.h.channel._tx_ring)
        self.assertTrue(envelope.tracked)

        packet = self.h.outlet.packets[0]

        self.assertEqual(envelope.packet, packet)
        self.assertEqual(1, envelope.tries)
        self.assertEqual(1, packet.tries)
        self.assertEqual(1, packet.instances)
        self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
        self.assertEqual(envelope.raw, packet.raw)

        # Wait out the timeout of every try, with slack on the last one.
        time.sleep(self.h.channel._get_packet_timeout_time(1))
        time.sleep(self.h.channel._get_packet_timeout_time(2))
        time.sleep(self.h.channel._get_packet_timeout_time(3))
        time.sleep(self.h.channel._get_packet_timeout_time(4))
        time.sleep(self.h.channel._get_packet_timeout_time(5) * 1.1)

        self.assertEqual(1, len(self.h.outlet.packets))
        self.assertEqual(5, envelope.tries)
        self.assertEqual(5, packet.tries)
        self.assertEqual(0, packet.instances)
        self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state)
        self.assertFalse(envelope.tracked)

    def test_send_on_failing_outlet_does_not_corrupt_state(self):
        # If outlet.send() returns a packet that never reached the wire
        # (LinkChannelOutlet does this when the link is not ACTIVE; the
        # returned packet has raw=None), Channel.send() must not consume a
        # sequence number or leave a packetless envelope in _tx_ring. Before
        # the fix, the envelope was queued before outlet.send() returned, so a
        # "dead" return left a raw=None envelope in the ring and silently
        # advanced _next_sequence, stalling the channel on the other end.
        print("Channel test send on failing outlet")

        original_send = self.h.outlet.send

        def ghost_send(raw):
            # Emulate a send that produced a packet which never left.
            with self.h.outlet.lock:
                packet = Packet(None)
                packet.state = MessageState.MSGSTATE_FAILED
                self.h.outlet.packets.append(packet)
                return packet

        self.h.outlet.send = ghost_send

        pre_sequence = self.h.channel._next_sequence
        self.assertEqual(0, len(self.h.channel._tx_ring))

        with self.assertRaises(RNS.Channel.ChannelException):
            self.h.channel.send(MessageTest())

        # Sequence must not have been consumed.
        self.assertEqual(pre_sequence, self.h.channel._next_sequence)
        # _tx_ring must not contain a packetless envelope.
        self.assertEqual(0, len(self.h.channel._tx_ring))

        # A subsequent successful send should use the same sequence number as
        # was reserved for the failed attempt.
        self.h.outlet.send = original_send
        envelope = self.h.channel.send(MessageTest())
        self.assertEqual(pre_sequence, envelope.sequence)
        self.assertIsNotNone(envelope.packet)
        self.assertIsNotNone(envelope.packet.raw)
        self.assertIn(envelope, self.h.channel._tx_ring)

    def test_multiple_handler(self):
        # A handler returning True stops the chain; returning False lets the
        # next registered handler see the message.
        print("Channel test multiple handler short circuit")

        handler1_called = 0
        handler1_return = True
        handler2_called = 0

        def handler1(msg: MessageBase):
            nonlocal handler1_called, handler1_return
            self.assertIsInstance(msg, MessageTest)
            handler1_called += 1
            return handler1_return

        def handler2(msg: MessageBase):
            nonlocal handler2_called
            self.assertIsInstance(msg, MessageTest)
            handler2_called += 1

        message = MessageTest()
        self.h.channel.register_message_type(MessageTest)
        self.h.channel.add_message_handler(handler1)
        self.h.channel.add_message_handler(handler2)
        envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0)
        raw = envelope.pack()
        self.h.channel._receive(raw)
        time.sleep(0.5)

        self.assertEqual(1, handler1_called)
        self.assertEqual(0, handler2_called)

        handler1_return = False
        envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1)
        raw = envelope.pack()
        self.h.channel._receive(raw)
        time.sleep(0.5)

        self.assertEqual(2, handler1_called)
        self.assertEqual(1, handler2_called)

    def test_system_message_check(self):
        # The public API rejects SystemMessage; the private API accepts it
        # when explicitly flagged as a system type.
        print("Channel test register system message")
        with self.assertRaises(RNS.Channel.ChannelException):
            self.h.channel.register_message_type(SystemMessage)
        self.h.channel._register_message_type(SystemMessage, is_system_type=True)

    def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]):
        """Send ``message``, deliver it, loop its bytes back in and check it.

        The packet produced by the channel is marked delivered and its raw
        bytes are fed to ``_receive()`` on the same channel; the single
        decoded message is then passed to ``checker``.
        """
        decoded: list[MessageBase] = []

        def handle_message(received: MessageBase):
            decoded.append(received)

        self.h.channel.register_message_type(message.__class__)
        self.h.channel.add_message_handler(handle_message)
        self.assertEqual(len(self.h.outlet.packets), 0)

        envelope = self.h.channel.send(message)
        time.sleep(self.h.channel._get_packet_timeout_time(1) * 0.5)

        self.assertIsNotNone(envelope)
        self.assertIsNotNone(envelope.raw)
        self.assertEqual(1, len(self.h.outlet.packets))
        self.assertIsNotNone(envelope.packet)
        self.assertIn(envelope, self.h.channel._tx_ring)
        self.assertTrue(envelope.tracked)

        packet = self.h.outlet.packets[0]

        self.assertEqual(envelope.packet, packet)
        self.assertEqual(1, envelope.tries)
        self.assertEqual(1, packet.tries)
        self.assertEqual(1, packet.instances)
        self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
        self.assertEqual(envelope.raw, packet.raw)

        packet.delivered()

        self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)

        # Let the now-disarmed timer thread run out.
        time.sleep(self.h.channel._get_packet_timeout_time(1))

        self.assertEqual(1, len(self.h.outlet.packets))
        self.assertEqual(1, envelope.tries)
        self.assertEqual(1, packet.tries)
        self.assertEqual(0, packet.instances)
        self.assertFalse(envelope.tracked)

        self.assertEqual(len(self.h.outlet.packets), 1)
        self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
        self.assertFalse(envelope.tracked)
        self.assertEqual(0, len(decoded))

        # Loop the sent bytes back into the same channel.
        self.h.channel._receive(packet.raw)
        time.sleep(0.5)

        self.assertEqual(1, len(decoded))

        rx_message = decoded[0]

        self.assertIsNotNone(rx_message)
        self.assertIsInstance(rx_message, message.__class__)
        checker(rx_message)

    def test_send_receive_message_test(self):
        # Serialized fields survive the round trip; not_serialized does not.
        print("Channel test send and receive message")
        message = MessageTest()

        def check(rx_message: MessageBase):
            self.assertIsInstance(rx_message, message.__class__)
            self.assertEqual(message.id, rx_message.id)
            self.assertEqual(message.data, rx_message.data)
            self.assertNotEqual(message.not_serialized, rx_message.not_serialized)

        self.eat_own_dog_food(message, check)

    def test_buffer_small_bidirectional(self):
        # One short line written to a bidirectional buffer is read back
        # after its packet is looped into the channel.
        data = "Hello\n"
        with RNS.Buffer.create_bidirectional_buffer(0, 0, self.h.channel) as buffer:
            count = buffer.write(data.encode("utf-8"))
            buffer.flush()

            self.assertEqual(len(data), count)
            self.assertEqual(1, len(self.h.outlet.packets))

            packet = self.h.outlet.packets[0]
            self.h.channel._receive(packet.raw)
            time.sleep(0.2)
            result = buffer.readline()

            self.assertIsNotNone(result)
            self.assertEqual(len(result), len(data))

            decoded = result.decode("utf-8")

            self.assertEqual(data, decoded)

    def test_buffer_big(self):
        # A payload spanning many packets is written on one thread while
        # this thread loops packets back and marks them delivered.
        writer = RNS.Buffer.create_writer(15, self.h.channel)
        reader = RNS.Buffer.create_reader(15, self.h.channel)
        data = "01234556789"*1024*5  # 11 chars * 5120 = 56,320 bytes (~55 KiB)
        count = 0
        write_finished = False

        def write_thread():
            nonlocal count, write_finished
            count = writer.write(data.encode("utf-8"))
            writer.flush()
            writer.close()  # TODO: Workaround for https://github.com/python/cpython/issues/138720
            write_finished = True

        threading.Thread(target=write_thread, name="Write Thread", daemon=True).start()

        # Keep pumping until the writer is done and nothing is undelivered.
        while not write_finished or any(p.state != MessageState.MSGSTATE_DELIVERED
                                        for p in self.h.outlet.packets):
            with self.h.outlet.lock:
                for packet in self.h.outlet.packets:
                    if packet.state != MessageState.MSGSTATE_DELIVERED:
                        self.h.channel._receive(packet.raw)
                        packet.delivered()
            time.sleep(0.0001)

        self.assertEqual(len(data), count)

        read_finished = False
        result = bytes()

        def read_thread():
            nonlocal read_finished, result
            result = reader.read()
            read_finished = True

        threading.Thread(target=read_thread, name="Read Thread", daemon=True).start()

        # Give the reader at most seven seconds to finish.
        timeout_at = time.time() + 7
        while not read_finished and time.time() < timeout_at:
            time.sleep(0.001)

        self.assertTrue(read_finished)
        self.assertEqual(len(data), len(result))

        decoded = result.decode("utf-8")

        self.assertSequenceEqual(data, decoded)

    def test_buffer_small_with_callback(self):
        # The ready callback fires once with the number of readable bytes;
        # closing the writer sends a second packet after which read()
        # returns an empty result.
        callbacks = 0
        last_cb_value = None

        def callback(ready: int):
            nonlocal callbacks, last_cb_value
            callbacks += 1
            last_cb_value = ready

        data = "Hello\n"
        with RNS.RawChannelWriter(0, self.h.channel) as writer, RNS.RawChannelReader(0, self.h.channel) as reader:
            reader.add_ready_callback(callback)
            count = writer.write(data.encode("utf-8"))
            writer.flush()

            self.assertEqual(len(data), count)
            self.assertEqual(1, len(self.h.outlet.packets))

            packet = self.h.outlet.packets[0]
            self.h.channel._receive(packet.raw)
            packet.delivered()

            self.assertEqual(1, callbacks)
            self.assertEqual(len(data), last_cb_value)

            result = reader.readline()

            self.assertIsNotNone(result)
            self.assertEqual(len(result), len(data))

            decoded = result.decode("utf-8")

            self.assertEqual(data, decoded)
            self.assertEqual(1, len(self.h.outlet.packets))

            # Nothing left to read, and the stream is still open.
            result = reader.read(1)

            self.assertIsNone(result)
            self.assertTrue(self.h.channel.is_ready_to_send())

            writer.close()

            self.assertEqual(2, len(self.h.outlet.packets))

            packet = self.h.outlet.packets[1]
            self.h.channel._receive(packet.raw)
            packet.delivered()

            result = reader.read(1)

            self.assertIsNotNone(result)
            self.assertEqual(0, len(result))
# Run the whole suite verbosely when this module is executed as a script.
if __name__ == '__main__':
    unittest.main(verbosity=2)