Channel: prevent sequence holes and ghost envelopes when sending on a dying outlet

RNSChannelOutlet.send() can return a packet that never reached the wire
(link not ACTIVE, no capable interface, etc). The old Channel.send()
queued the envelope in _tx_ring before calling outlet.send(), then
tried to rewind _next_sequence and remove the envelope if the outlet
returned a failed packet. Two problems:

- Between queueing and outlet.send() returning, _tx_ring held an
envelope with packet.raw=None. Any worker thread iterating the
ring (timeout fire, proof callback) crashed in get_packet_id's
packet.get_hash() with a TypeError on None.raw.

- The rewind was only safe for a single-threaded sender: it checked
"is _next_sequence one past mine?" and skipped the rewind otherwise.
Under concurrent senders, the rewind silently failed, leaving a
hole in the on-wire sequence stream. The receiver's contiguous
seqnum rule then stalled the channel permanently with no error.

This fix serializes the reservation-and-transmit pair with a per-channel
_send_lock so the rewind is always correct, and defers queueing until
outlet.send() returns a real packet so _tx_ring never contains a
packet-less envelope. _packet_tx_op() and get_packet_id() now also
defensively skip/return-None for packet-less envelopes.

Also handle the small race where a proof arrives between outlet.send()
registering the receipt and us installing the delivery callback: after
registration, re-read the receipt status and synthesize the
_packet_delivered() call if it's already DELIVERED.
This commit is contained in:
Jeremy O'Brien
2026-05-19 09:23:48 -04:00
parent ebf544d335
commit 794e437f6d
2 changed files with 131 additions and 57 deletions
+89 -56
View File
@@ -144,7 +144,7 @@ class MessageBase(abc.ABC):
MSGTYPE = None
"""
Defines a unique identifier for a message class.
* Must be unique within all classes registered with a ``Channel``
* Must be less than ``0xf000``. Values greater than or equal to ``0xf000`` are reserved.
"""
@@ -255,11 +255,11 @@ class Channel(contextlib.AbstractContextManager):
# The maximum window size for transfers on fast links
WINDOW_MAX_FAST = 48
# For calculating maps and guard segments, this
# must be set to the global maximum window.
WINDOW_MAX = WINDOW_MAX_FAST
# If the fast rate is sustained for this many request
# rounds, the fast link window size will be allowed.
FAST_RATE_THRESHOLD = 10
@@ -285,6 +285,7 @@ class Channel(contextlib.AbstractContextManager):
"""
self._outlet = outlet
self._lock = threading.RLock()
self._send_lock = threading.Lock()
self._tx_ring: collections.deque[Envelope] = collections.deque()
self._rx_ring: collections.deque[Envelope] = collections.deque()
self._message_callbacks: [MessageCallbackType] = []
@@ -382,27 +383,30 @@ class Channel(contextlib.AbstractContextManager):
if envelope.packet is not None:
self._outlet.set_packet_timeout_callback(envelope.packet, None)
self._outlet.set_packet_delivered_callback(envelope.packet, None)
envelope.tracked = False
for envelope in self._rx_ring:
envelope.tracked = False
self._tx_ring.clear()
self._rx_ring.clear()
def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool:
with self._lock:
i = 0
for existing in ring:
if envelope.sequence == existing.sequence:
RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME)
return False
if envelope.sequence < existing.sequence and not (self._next_rx_sequence - envelope.sequence) > (Channel.SEQ_MAX//2):
ring.insert(i, envelope)
envelope.tracked = True
return True
i += 1
envelope.tracked = True
ring.append(envelope)
@@ -457,7 +461,7 @@ class Channel(contextlib.AbstractContextManager):
m = e.unpack(self._message_factories)
else:
m = e.message
self._rx_ring.remove(e)
self._run_callbacks(m)
@@ -476,7 +480,7 @@ class Channel(contextlib.AbstractContextManager):
with self._lock:
outstanding = 0
for envelope in self._tx_ring:
if envelope.outlet == self._outlet:
if envelope.outlet == self._outlet:
if not envelope.packet or not self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_DELIVERED:
outstanding += 1
@@ -486,8 +490,10 @@ class Channel(contextlib.AbstractContextManager):
return True
def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]):
target_id = self._outlet.get_packet_id(packet)
with self._lock:
envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet),
envelope = next(filter(lambda e: e.packet is not None
and self._outlet.get_packet_id(e.packet) == target_id,
self._tx_ring), None)
if envelope and op(envelope):
@@ -516,7 +522,7 @@ class Channel(contextlib.AbstractContextManager):
# TODO: Remove at some point
# RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_DEBUG)
# RNS.log("Increased "+str(self)+" min window to "+str(self.window_min), RNS.LOG_DEBUG)
else:
self.fast_rate_rounds += 1
if self.window_max < Channel.WINDOW_MAX_FAST and self.fast_rate_rounds == Channel.FAST_RATE_THRESHOLD:
@@ -547,36 +553,48 @@ class Channel(contextlib.AbstractContextManager):
return to
def _packet_timeout(self, packet: TPacket):
def retry_envelope(envelope: Envelope) -> bool:
if self._outlet.get_packet_state(packet) == MessageState.MSGSTATE_DELIVERED:
return
target_id = self._outlet.get_packet_id(packet)
envelope_to_resend: Envelope | None = None
should_teardown = False
with self._lock:
envelope = next(filter(
lambda e: e.packet is not None and self._outlet.get_packet_id(e.packet) == target_id,
self._tx_ring), None)
if envelope is None:
return
if envelope.tries >= self._max_tries:
RNS.log("Retry count exceeded on "+str(self)+", tearing down Link.", RNS.LOG_ERROR)
self._shutdown() # start on separate thread?
self._outlet.timed_out()
return True
should_teardown = True
else:
envelope.tries += 1
envelope_to_resend = envelope
envelope.tries += 1
self._outlet.resend(envelope.packet)
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
self._update_packet_timeouts()
if self.window > self.window_min:
self.window -= 1
if self.window_max > (self.window_min+self.window_flexibility):
self.window_max -= 1
if self.window > self.window_min:
self.window -= 1
# TODO: Remove at some point
# RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_DEBUG)
if should_teardown:
RNS.log("Retry count exceeded on "+str(self)+", tearing down Link.", RNS.LOG_ERROR)
self._shutdown()
self._outlet.timed_out()
return
if self.window_max > (self.window_min+self.window_flexibility):
self.window_max -= 1
# TODO: Remove at some point
# RNS.log("Decreased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_DEBUG)
if envelope_to_resend is not None:
self._outlet.resend(envelope_to_resend.packet)
with self._lock:
self._outlet.set_packet_delivered_callback(envelope_to_resend.packet, self._packet_delivered)
self._outlet.set_packet_timeout_callback(
envelope_to_resend.packet, self._packet_timeout,
self._get_packet_timeout_time(envelope_to_resend.tries))
self._update_packet_timeouts()
already_delivered = (self._outlet.get_packet_state(envelope_to_resend.packet) == MessageState.MSGSTATE_DELIVERED)
# TODO: Remove at some point
# RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
return False
if self._outlet.get_packet_state(packet) != MessageState.MSGSTATE_DELIVERED:
self._packet_tx_op(packet, retry_envelope)
if already_delivered:
self._packet_delivered(envelope_to_resend.packet)
def send(self, message: MessageBase) -> Envelope:
"""
@@ -585,27 +603,39 @@ class Channel(contextlib.AbstractContextManager):
:param message: an instance of a ``MessageBase`` subclass
"""
envelope: Envelope | None = None
with self._lock:
if not self.is_ready_to_send():
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence)
self._next_sequence = (self._next_sequence + 1) % Channel.SEQ_MODULUS
self._emplace_envelope(envelope, self._tx_ring)
with self._send_lock:
with self._lock:
if not self.is_ready_to_send():
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
if envelope is None:
raise BlockingIOError()
reserved_sequence = self._next_sequence
envelope = Envelope(self._outlet, message=message, sequence=reserved_sequence)
envelope.pack()
if len(envelope.raw) > self._outlet.mdu:
raise ChannelException(CEType.ME_TOO_BIG,
f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}")
self._next_sequence = (reserved_sequence + 1) % Channel.SEQ_MODULUS
envelope.pack()
if len(envelope.raw) > self._outlet.mdu:
raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}")
envelope.packet = self._outlet.send(envelope.raw)
envelope.tries += 1
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
self._update_packet_timeouts()
envelope.packet = self._outlet.send(envelope.raw)
if (envelope.packet is None
or getattr(envelope.packet, "raw", None) is None
or (hasattr(envelope.packet, "receipt") and envelope.packet.receipt is None)):
with self._lock:
self._next_sequence = reserved_sequence
raise ChannelException(CEType.ME_LINK_NOT_READY, "Outlet did not transmit packet")
with self._lock:
self._emplace_envelope(envelope, self._tx_ring)
envelope.tries += 1
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
self._update_packet_timeouts()
already_delivered = (self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_DELIVERED)
# prevent _tx_ring envelope leak
if already_delivered:
self._packet_delivered(envelope.packet)
return envelope
@@ -699,7 +729,10 @@ class LinkChannelOutlet(ChannelOutletBase):
packet.receipt.set_delivery_callback(inner if callback else None)
def get_packet_id(self, packet: RNS.Packet) -> any:
if packet and hasattr(packet, "get_hash") and callable(packet.get_hash):
if (packet
and getattr(packet, "raw", None) is not None
and hasattr(packet, "get_hash")
and callable(packet.get_hash)):
return packet.get_hash()
else:
return None
+42 -1
View File
@@ -62,7 +62,7 @@ class Packet:
def set_delivered_callback(self, callback: Callable[[Packet], None]):
self.delivered_callback = callback
def delivered(self):
with self.lock:
self.state = MessageState.MSGSTATE_DELIVERED
@@ -265,6 +265,47 @@ class TestChannel(unittest.TestCase):
self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state)
self.assertFalse(envelope.tracked)
def test_send_on_failing_outlet_does_not_corrupt_state(self):
# if outlet.send() returns a packet that never reached
# the wire (LinkChannelOutlet does this when the link is not ACTIVE; the
# returned packet has raw=None), Channel.send() must not consume a
# sequence number or leave a packetless envelope in _tx_ring. Before
# the fix, the envelope was queued before outlet.send() returned, so a
# "dead" return left a raw=None envelope in the ring and silently
# advanced _next_sequence, stalling the channel on the other end.
print("Channel test send on failing outlet")
original_send = self.h.outlet.send
def ghost_send(raw):
with self.h.outlet.lock:
packet = Packet(None)
packet.state = MessageState.MSGSTATE_FAILED
self.h.outlet.packets.append(packet)
return packet
self.h.outlet.send = ghost_send
pre_sequence = self.h.channel._next_sequence
self.assertEqual(0, len(self.h.channel._tx_ring))
with self.assertRaises(RNS.Channel.ChannelException):
self.h.channel.send(MessageTest())
# Sequence must not have been consumed.
self.assertEqual(pre_sequence, self.h.channel._next_sequence)
# _tx_ring must not contain a packetless envelope.
self.assertEqual(0, len(self.h.channel._tx_ring))
# A subsequent successful send should use the same sequence number as
# was reserved for the failed attempt.
self.h.outlet.send = original_send
envelope = self.h.channel.send(MessageTest())
self.assertEqual(pre_sequence, envelope.sequence)
self.assertIsNotNone(envelope.packet)
self.assertIsNotNone(envelope.packet.raw)
self.assertTrue(envelope in self.h.channel._tx_ring)
def test_multiple_handler(self):
print("Channel test multiple handler short circuit")