From 794e437f6dbe25ab6b9ea35cbe11f60fac288fc8 Mon Sep 17 00:00:00 2001 From: Jeremy O'Brien Date: Tue, 19 May 2026 09:23:48 -0400 Subject: [PATCH] Channel: prevent sequence holes and ghost envelopes when sending on a dying outlet LinkChannelOutlet.send() can return a packet that never reached the wire (link not ACTIVE, no capable interface, etc.). The old Channel.send() queued the envelope in _tx_ring before calling outlet.send(), then tried to rewind _next_sequence and remove the envelope if the outlet returned a failed packet. Two problems: - Between queueing and outlet.send() returning, _tx_ring held an envelope with packet.raw=None. Any worker thread iterating the ring (timeout fire, proof callback) crashed in get_packet_id's packet.get_hash() with a TypeError on None.raw. - The rewind was only safe for a single-threaded sender: it checked "is _next_sequence one past mine?" and skipped the rewind otherwise. Under concurrent senders, the rewind silently failed, leaving a hole in the on-wire sequence stream. The receiver's contiguous seqnum rule then stalled the channel permanently with no error. This fix serializes the reservation-and-transmit pair with a per-channel _send_lock so the rewind is always correct, and defers queueing until outlet.send() returns a real packet so _tx_ring never contains a packet-less envelope. _packet_tx_op() and get_packet_id() now also defensively skip/return-None for packet-less envelopes. Also handle the small race where a proof arrives between outlet.send() registering the receipt and us installing the delivery callback: after registration, re-read the receipt status and synthesize the _packet_delivered() call if it's already DELIVERED. 
--- RNS/Channel.py | 145 +++++++++++++++++++++++++++++------------------ tests/channel.py | 43 +++++++++++++- 2 files changed, 131 insertions(+), 57 deletions(-) diff --git a/RNS/Channel.py b/RNS/Channel.py index 011fde17..3072b852 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -144,7 +144,7 @@ class MessageBase(abc.ABC): MSGTYPE = None """ Defines a unique identifier for a message class. - + * Must be unique within all classes registered with a ``Channel`` * Must be less than ``0xf000``. Values greater than or equal to ``0xf000`` are reserved. """ @@ -255,11 +255,11 @@ class Channel(contextlib.AbstractContextManager): # The maximum window size for transfers on fast links WINDOW_MAX_FAST = 48 - + # For calculating maps and guard segments, this # must be set to the global maximum window. WINDOW_MAX = WINDOW_MAX_FAST - + # If the fast rate is sustained for this many request # rounds, the fast link window size will be allowed. FAST_RATE_THRESHOLD = 10 @@ -285,6 +285,7 @@ class Channel(contextlib.AbstractContextManager): """ self._outlet = outlet self._lock = threading.RLock() + self._send_lock = threading.Lock() self._tx_ring: collections.deque[Envelope] = collections.deque() self._rx_ring: collections.deque[Envelope] = collections.deque() self._message_callbacks: [MessageCallbackType] = [] @@ -382,27 +383,30 @@ class Channel(contextlib.AbstractContextManager): if envelope.packet is not None: self._outlet.set_packet_timeout_callback(envelope.packet, None) self._outlet.set_packet_delivered_callback(envelope.packet, None) + envelope.tracked = False + for envelope in self._rx_ring: + envelope.tracked = False self._tx_ring.clear() self._rx_ring.clear() def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool: with self._lock: i = 0 - + for existing in ring: if envelope.sequence == existing.sequence: RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME) return False - + if 
envelope.sequence < existing.sequence and not (self._next_rx_sequence - envelope.sequence) > (Channel.SEQ_MAX//2): ring.insert(i, envelope) envelope.tracked = True return True - + i += 1 - + envelope.tracked = True ring.append(envelope) @@ -457,7 +461,7 @@ class Channel(contextlib.AbstractContextManager): m = e.unpack(self._message_factories) else: m = e.message - + self._rx_ring.remove(e) self._run_callbacks(m) @@ -476,7 +480,7 @@ class Channel(contextlib.AbstractContextManager): with self._lock: outstanding = 0 for envelope in self._tx_ring: - if envelope.outlet == self._outlet: + if envelope.outlet == self._outlet: if not envelope.packet or not self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_DELIVERED: outstanding += 1 @@ -486,8 +490,10 @@ class Channel(contextlib.AbstractContextManager): return True def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]): + target_id = self._outlet.get_packet_id(packet) with self._lock: - envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet), + envelope = next(filter(lambda e: e.packet is not None + and self._outlet.get_packet_id(e.packet) == target_id, self._tx_ring), None) if envelope and op(envelope): @@ -516,7 +522,7 @@ class Channel(contextlib.AbstractContextManager): # TODO: Remove at some point # RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_DEBUG) # RNS.log("Increased "+str(self)+" min window to "+str(self.window_min), RNS.LOG_DEBUG) - + else: self.fast_rate_rounds += 1 if self.window_max < Channel.WINDOW_MAX_FAST and self.fast_rate_rounds == Channel.FAST_RATE_THRESHOLD: @@ -547,36 +553,48 @@ class Channel(contextlib.AbstractContextManager): return to def _packet_timeout(self, packet: TPacket): - def retry_envelope(envelope: Envelope) -> bool: + if self._outlet.get_packet_state(packet) == MessageState.MSGSTATE_DELIVERED: + return + + target_id = self._outlet.get_packet_id(packet) + 
envelope_to_resend: Envelope | None = None + should_teardown = False + with self._lock: + envelope = next(filter( + lambda e: e.packet is not None and self._outlet.get_packet_id(e.packet) == target_id, + self._tx_ring), None) + if envelope is None: + return + if envelope.tries >= self._max_tries: - RNS.log("Retry count exceeded on "+str(self)+", tearing down Link.", RNS.LOG_ERROR) - self._shutdown() # start on separate thread? - self._outlet.timed_out() - return True + should_teardown = True + else: + envelope.tries += 1 + envelope_to_resend = envelope - envelope.tries += 1 - self._outlet.resend(envelope.packet) - self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) - self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries)) - self._update_packet_timeouts() + if self.window > self.window_min: + self.window -= 1 + if self.window_max > (self.window_min+self.window_flexibility): + self.window_max -= 1 - if self.window > self.window_min: - self.window -= 1 - # TODO: Remove at some point - # RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_DEBUG) + if should_teardown: + RNS.log("Retry count exceeded on "+str(self)+", tearing down Link.", RNS.LOG_ERROR) + self._shutdown() + self._outlet.timed_out() + return - if self.window_max > (self.window_min+self.window_flexibility): - self.window_max -= 1 - # TODO: Remove at some point - # RNS.log("Decreased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_DEBUG) + if envelope_to_resend is not None: + self._outlet.resend(envelope_to_resend.packet) + with self._lock: + self._outlet.set_packet_delivered_callback(envelope_to_resend.packet, self._packet_delivered) + self._outlet.set_packet_timeout_callback( + envelope_to_resend.packet, self._packet_timeout, + self._get_packet_timeout_time(envelope_to_resend.tries)) + self._update_packet_timeouts() + already_delivered = 
(self._outlet.get_packet_state(envelope_to_resend.packet) == MessageState.MSGSTATE_DELIVERED) - # TODO: Remove at some point - # RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME) - - return False - - if self._outlet.get_packet_state(packet) != MessageState.MSGSTATE_DELIVERED: - self._packet_tx_op(packet, retry_envelope) + if already_delivered: + self._packet_delivered(envelope_to_resend.packet) def send(self, message: MessageBase) -> Envelope: """ @@ -585,27 +603,39 @@ class Channel(contextlib.AbstractContextManager): :param message: an instance of a ``MessageBase`` subclass """ - envelope: Envelope | None = None - with self._lock: - if not self.is_ready_to_send(): - raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready") - - envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence) - self._next_sequence = (self._next_sequence + 1) % Channel.SEQ_MODULUS - self._emplace_envelope(envelope, self._tx_ring) + with self._send_lock: + with self._lock: + if not self.is_ready_to_send(): + raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready") - if envelope is None: - raise BlockingIOError() + reserved_sequence = self._next_sequence + envelope = Envelope(self._outlet, message=message, sequence=reserved_sequence) + envelope.pack() + if len(envelope.raw) > self._outlet.mdu: + raise ChannelException(CEType.ME_TOO_BIG, + f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}") + self._next_sequence = (reserved_sequence + 1) % Channel.SEQ_MODULUS - envelope.pack() - if len(envelope.raw) > self._outlet.mdu: - raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}") - - envelope.packet = self._outlet.send(envelope.raw) - envelope.tries += 1 - self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) - self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, 
self._get_packet_timeout_time(envelope.tries)) - self._update_packet_timeouts() + envelope.packet = self._outlet.send(envelope.raw) + + if (envelope.packet is None + or getattr(envelope.packet, "raw", None) is None + or (hasattr(envelope.packet, "receipt") and envelope.packet.receipt is None)): + with self._lock: + self._next_sequence = reserved_sequence + raise ChannelException(CEType.ME_LINK_NOT_READY, "Outlet did not transmit packet") + + with self._lock: + self._emplace_envelope(envelope, self._tx_ring) + envelope.tries += 1 + self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) + self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries)) + self._update_packet_timeouts() + already_delivered = (self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_DELIVERED) + + # prevent _tx_ring envelope leak + if already_delivered: + self._packet_delivered(envelope.packet) return envelope @@ -699,7 +729,10 @@ class LinkChannelOutlet(ChannelOutletBase): packet.receipt.set_delivery_callback(inner if callback else None) def get_packet_id(self, packet: RNS.Packet) -> any: - if packet and hasattr(packet, "get_hash") and callable(packet.get_hash): + if (packet + and getattr(packet, "raw", None) is not None + and hasattr(packet, "get_hash") + and callable(packet.get_hash)): return packet.get_hash() else: return None diff --git a/tests/channel.py b/tests/channel.py index ab6a5948..89946dc6 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -62,7 +62,7 @@ class Packet: def set_delivered_callback(self, callback: Callable[[Packet], None]): self.delivered_callback = callback - + def delivered(self): with self.lock: self.state = MessageState.MSGSTATE_DELIVERED @@ -265,6 +265,47 @@ class TestChannel(unittest.TestCase): self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state) self.assertFalse(envelope.tracked) + def 
test_send_on_failing_outlet_does_not_corrupt_state(self): + # if outlet.send() returns a packet that never reached + # the wire (LinkChannelOutlet does this when the link is not ACTIVE; the + # returned packet has raw=None), Channel.send() must not consume a + # sequence number or leave a packetless envelope in _tx_ring. Before + # the fix, the envelope was queued before outlet.send() returned, so a + # "dead" return left a raw=None envelope in the ring and silently + # advanced _next_sequence, stalling the channel on the other end. + print("Channel test send on failing outlet") + + original_send = self.h.outlet.send + + def ghost_send(raw): + with self.h.outlet.lock: + packet = Packet(None) + packet.state = MessageState.MSGSTATE_FAILED + self.h.outlet.packets.append(packet) + return packet + + self.h.outlet.send = ghost_send + + pre_sequence = self.h.channel._next_sequence + self.assertEqual(0, len(self.h.channel._tx_ring)) + + with self.assertRaises(RNS.Channel.ChannelException): + self.h.channel.send(MessageTest()) + + # Sequence must not have been consumed. + self.assertEqual(pre_sequence, self.h.channel._next_sequence) + # _tx_ring must not contain a packetless envelope. + self.assertEqual(0, len(self.h.channel._tx_ring)) + + # A subsequent successful send should use the same sequence number as + # was reserved for the failed attempt. + self.h.outlet.send = original_send + envelope = self.h.channel.send(MessageTest()) + self.assertEqual(pre_sequence, envelope.sequence) + self.assertIsNotNone(envelope.packet) + self.assertIsNotNone(envelope.packet.raw) + self.assertTrue(envelope in self.h.channel._tx_ring) + def test_multiple_handler(self): print("Channel test multiple handler short circuit")