mirror of
https://github.com/markqvist/Reticulum.git
synced 2026-06-08 14:11:53 -07:00
Channel: prevent sequence holes and ghost envelopes when sending on a dying outlet
RNSChannelOutlet.send() can return a packet that never reached the wire (link not ACTIVE, no capable interface, etc). The old Channel.send() queued the envelope in _tx_ring before calling outlet.send(), then tried to rewind _next_sequence and remove the envelope if the outlet returned a failed packet. Two problems: - Between queueing and outlet.send() returning, _tx_ring held an envelope with packet.raw=None. Any worker thread iterating the ring (timeout fire, proof callback) crashed in get_packet_id's packet.get_hash() with a TypeError on None.raw. - The rewind was only safe for a single-threaded sender: it checked "is _next_sequence one past mine?" and skipped the rewind otherwise. Under concurrent senders, the rewind silently failed, leaving a hole in the on-wire sequence stream. The receiver's contiguous seqnum rule then stalled the channel permanently with no error. This fix serializes the reservation-and-transmit pair with a per-channel _send_lock so the rewind is always correct, and defers queueing until outlet.send() returns a real packet so _tx_ring never contains a packet-less envelope. _packet_tx_op() and get_packet_id() now also defensively skip/return-None for packet-less envelopes. Also handle the small race where a proof arrives between outlet.send() registering the receipt and us installing the delivery callback: after registration, re-read the receipt status and synthesize the _packet_delivered() call if it's already DELIVERED.
This commit is contained in:
+89
-56
@@ -144,7 +144,7 @@ class MessageBase(abc.ABC):
|
||||
MSGTYPE = None
|
||||
"""
|
||||
Defines a unique identifier for a message class.
|
||||
|
||||
|
||||
* Must be unique within all classes registered with a ``Channel``
|
||||
* Must be less than ``0xf000``. Values greater than or equal to ``0xf000`` are reserved.
|
||||
"""
|
||||
@@ -255,11 +255,11 @@ class Channel(contextlib.AbstractContextManager):
|
||||
|
||||
# The maximum window size for transfers on fast links
|
||||
WINDOW_MAX_FAST = 48
|
||||
|
||||
|
||||
# For calculating maps and guard segments, this
|
||||
# must be set to the global maximum window.
|
||||
WINDOW_MAX = WINDOW_MAX_FAST
|
||||
|
||||
|
||||
# If the fast rate is sustained for this many request
|
||||
# rounds, the fast link window size will be allowed.
|
||||
FAST_RATE_THRESHOLD = 10
|
||||
@@ -285,6 +285,7 @@ class Channel(contextlib.AbstractContextManager):
|
||||
"""
|
||||
self._outlet = outlet
|
||||
self._lock = threading.RLock()
|
||||
self._send_lock = threading.Lock()
|
||||
self._tx_ring: collections.deque[Envelope] = collections.deque()
|
||||
self._rx_ring: collections.deque[Envelope] = collections.deque()
|
||||
self._message_callbacks: [MessageCallbackType] = []
|
||||
@@ -382,27 +383,30 @@ class Channel(contextlib.AbstractContextManager):
|
||||
if envelope.packet is not None:
|
||||
self._outlet.set_packet_timeout_callback(envelope.packet, None)
|
||||
self._outlet.set_packet_delivered_callback(envelope.packet, None)
|
||||
envelope.tracked = False
|
||||
for envelope in self._rx_ring:
|
||||
envelope.tracked = False
|
||||
self._tx_ring.clear()
|
||||
self._rx_ring.clear()
|
||||
|
||||
def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool:
|
||||
with self._lock:
|
||||
i = 0
|
||||
|
||||
|
||||
for existing in ring:
|
||||
|
||||
if envelope.sequence == existing.sequence:
|
||||
RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME)
|
||||
return False
|
||||
|
||||
|
||||
if envelope.sequence < existing.sequence and not (self._next_rx_sequence - envelope.sequence) > (Channel.SEQ_MAX//2):
|
||||
ring.insert(i, envelope)
|
||||
|
||||
envelope.tracked = True
|
||||
return True
|
||||
|
||||
|
||||
i += 1
|
||||
|
||||
|
||||
envelope.tracked = True
|
||||
ring.append(envelope)
|
||||
|
||||
@@ -457,7 +461,7 @@ class Channel(contextlib.AbstractContextManager):
|
||||
m = e.unpack(self._message_factories)
|
||||
else:
|
||||
m = e.message
|
||||
|
||||
|
||||
self._rx_ring.remove(e)
|
||||
self._run_callbacks(m)
|
||||
|
||||
@@ -476,7 +480,7 @@ class Channel(contextlib.AbstractContextManager):
|
||||
with self._lock:
|
||||
outstanding = 0
|
||||
for envelope in self._tx_ring:
|
||||
if envelope.outlet == self._outlet:
|
||||
if envelope.outlet == self._outlet:
|
||||
if not envelope.packet or not self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_DELIVERED:
|
||||
outstanding += 1
|
||||
|
||||
@@ -486,8 +490,10 @@ class Channel(contextlib.AbstractContextManager):
|
||||
return True
|
||||
|
||||
def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]):
|
||||
target_id = self._outlet.get_packet_id(packet)
|
||||
with self._lock:
|
||||
envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet),
|
||||
envelope = next(filter(lambda e: e.packet is not None
|
||||
and self._outlet.get_packet_id(e.packet) == target_id,
|
||||
self._tx_ring), None)
|
||||
|
||||
if envelope and op(envelope):
|
||||
@@ -516,7 +522,7 @@ class Channel(contextlib.AbstractContextManager):
|
||||
# TODO: Remove at some point
|
||||
# RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_DEBUG)
|
||||
# RNS.log("Increased "+str(self)+" min window to "+str(self.window_min), RNS.LOG_DEBUG)
|
||||
|
||||
|
||||
else:
|
||||
self.fast_rate_rounds += 1
|
||||
if self.window_max < Channel.WINDOW_MAX_FAST and self.fast_rate_rounds == Channel.FAST_RATE_THRESHOLD:
|
||||
@@ -547,36 +553,48 @@ class Channel(contextlib.AbstractContextManager):
|
||||
return to
|
||||
|
||||
def _packet_timeout(self, packet: TPacket):
|
||||
def retry_envelope(envelope: Envelope) -> bool:
|
||||
if self._outlet.get_packet_state(packet) == MessageState.MSGSTATE_DELIVERED:
|
||||
return
|
||||
|
||||
target_id = self._outlet.get_packet_id(packet)
|
||||
envelope_to_resend: Envelope | None = None
|
||||
should_teardown = False
|
||||
with self._lock:
|
||||
envelope = next(filter(
|
||||
lambda e: e.packet is not None and self._outlet.get_packet_id(e.packet) == target_id,
|
||||
self._tx_ring), None)
|
||||
if envelope is None:
|
||||
return
|
||||
|
||||
if envelope.tries >= self._max_tries:
|
||||
RNS.log("Retry count exceeded on "+str(self)+", tearing down Link.", RNS.LOG_ERROR)
|
||||
self._shutdown() # start on separate thread?
|
||||
self._outlet.timed_out()
|
||||
return True
|
||||
should_teardown = True
|
||||
else:
|
||||
envelope.tries += 1
|
||||
envelope_to_resend = envelope
|
||||
|
||||
envelope.tries += 1
|
||||
self._outlet.resend(envelope.packet)
|
||||
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
|
||||
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
|
||||
self._update_packet_timeouts()
|
||||
if self.window > self.window_min:
|
||||
self.window -= 1
|
||||
if self.window_max > (self.window_min+self.window_flexibility):
|
||||
self.window_max -= 1
|
||||
|
||||
if self.window > self.window_min:
|
||||
self.window -= 1
|
||||
# TODO: Remove at some point
|
||||
# RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_DEBUG)
|
||||
if should_teardown:
|
||||
RNS.log("Retry count exceeded on "+str(self)+", tearing down Link.", RNS.LOG_ERROR)
|
||||
self._shutdown()
|
||||
self._outlet.timed_out()
|
||||
return
|
||||
|
||||
if self.window_max > (self.window_min+self.window_flexibility):
|
||||
self.window_max -= 1
|
||||
# TODO: Remove at some point
|
||||
# RNS.log("Decreased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_DEBUG)
|
||||
if envelope_to_resend is not None:
|
||||
self._outlet.resend(envelope_to_resend.packet)
|
||||
with self._lock:
|
||||
self._outlet.set_packet_delivered_callback(envelope_to_resend.packet, self._packet_delivered)
|
||||
self._outlet.set_packet_timeout_callback(
|
||||
envelope_to_resend.packet, self._packet_timeout,
|
||||
self._get_packet_timeout_time(envelope_to_resend.tries))
|
||||
self._update_packet_timeouts()
|
||||
already_delivered = (self._outlet.get_packet_state(envelope_to_resend.packet) == MessageState.MSGSTATE_DELIVERED)
|
||||
|
||||
# TODO: Remove at some point
|
||||
# RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
|
||||
|
||||
return False
|
||||
|
||||
if self._outlet.get_packet_state(packet) != MessageState.MSGSTATE_DELIVERED:
|
||||
self._packet_tx_op(packet, retry_envelope)
|
||||
if already_delivered:
|
||||
self._packet_delivered(envelope_to_resend.packet)
|
||||
|
||||
def send(self, message: MessageBase) -> Envelope:
|
||||
"""
|
||||
@@ -585,27 +603,39 @@ class Channel(contextlib.AbstractContextManager):
|
||||
|
||||
:param message: an instance of a ``MessageBase`` subclass
|
||||
"""
|
||||
envelope: Envelope | None = None
|
||||
with self._lock:
|
||||
if not self.is_ready_to_send():
|
||||
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
|
||||
|
||||
envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence)
|
||||
self._next_sequence = (self._next_sequence + 1) % Channel.SEQ_MODULUS
|
||||
self._emplace_envelope(envelope, self._tx_ring)
|
||||
with self._send_lock:
|
||||
with self._lock:
|
||||
if not self.is_ready_to_send():
|
||||
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
|
||||
|
||||
if envelope is None:
|
||||
raise BlockingIOError()
|
||||
reserved_sequence = self._next_sequence
|
||||
envelope = Envelope(self._outlet, message=message, sequence=reserved_sequence)
|
||||
envelope.pack()
|
||||
if len(envelope.raw) > self._outlet.mdu:
|
||||
raise ChannelException(CEType.ME_TOO_BIG,
|
||||
f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}")
|
||||
self._next_sequence = (reserved_sequence + 1) % Channel.SEQ_MODULUS
|
||||
|
||||
envelope.pack()
|
||||
if len(envelope.raw) > self._outlet.mdu:
|
||||
raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}")
|
||||
|
||||
envelope.packet = self._outlet.send(envelope.raw)
|
||||
envelope.tries += 1
|
||||
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
|
||||
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
|
||||
self._update_packet_timeouts()
|
||||
envelope.packet = self._outlet.send(envelope.raw)
|
||||
|
||||
if (envelope.packet is None
|
||||
or getattr(envelope.packet, "raw", None) is None
|
||||
or (hasattr(envelope.packet, "receipt") and envelope.packet.receipt is None)):
|
||||
with self._lock:
|
||||
self._next_sequence = reserved_sequence
|
||||
raise ChannelException(CEType.ME_LINK_NOT_READY, "Outlet did not transmit packet")
|
||||
|
||||
with self._lock:
|
||||
self._emplace_envelope(envelope, self._tx_ring)
|
||||
envelope.tries += 1
|
||||
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
|
||||
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
|
||||
self._update_packet_timeouts()
|
||||
already_delivered = (self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_DELIVERED)
|
||||
|
||||
# prevent _tx_ring envelope leak
|
||||
if already_delivered:
|
||||
self._packet_delivered(envelope.packet)
|
||||
|
||||
return envelope
|
||||
|
||||
@@ -699,7 +729,10 @@ class LinkChannelOutlet(ChannelOutletBase):
|
||||
packet.receipt.set_delivery_callback(inner if callback else None)
|
||||
|
||||
def get_packet_id(self, packet: RNS.Packet) -> any:
|
||||
if packet and hasattr(packet, "get_hash") and callable(packet.get_hash):
|
||||
if (packet
|
||||
and getattr(packet, "raw", None) is not None
|
||||
and hasattr(packet, "get_hash")
|
||||
and callable(packet.get_hash)):
|
||||
return packet.get_hash()
|
||||
else:
|
||||
return None
|
||||
|
||||
+42
-1
@@ -62,7 +62,7 @@ class Packet:
|
||||
|
||||
def set_delivered_callback(self, callback: Callable[[Packet], None]):
|
||||
self.delivered_callback = callback
|
||||
|
||||
|
||||
def delivered(self):
|
||||
with self.lock:
|
||||
self.state = MessageState.MSGSTATE_DELIVERED
|
||||
@@ -265,6 +265,47 @@ class TestChannel(unittest.TestCase):
|
||||
self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state)
|
||||
self.assertFalse(envelope.tracked)
|
||||
|
||||
def test_send_on_failing_outlet_does_not_corrupt_state(self):
|
||||
# if outlet.send() returns a packet that never reached
|
||||
# the wire (LinkChannelOutlet does this when the link is not ACTIVE; the
|
||||
# returned packet has raw=None), Channel.send() must not consume a
|
||||
# sequence number or leave a packetless envelope in _tx_ring. Before
|
||||
# the fix, the envelope was queued before outlet.send() returned, so a
|
||||
# "dead" return left a raw=None envelope in the ring and silently
|
||||
# advanced _next_sequence, stalling the channel on the other end.
|
||||
print("Channel test send on failing outlet")
|
||||
|
||||
original_send = self.h.outlet.send
|
||||
|
||||
def ghost_send(raw):
|
||||
with self.h.outlet.lock:
|
||||
packet = Packet(None)
|
||||
packet.state = MessageState.MSGSTATE_FAILED
|
||||
self.h.outlet.packets.append(packet)
|
||||
return packet
|
||||
|
||||
self.h.outlet.send = ghost_send
|
||||
|
||||
pre_sequence = self.h.channel._next_sequence
|
||||
self.assertEqual(0, len(self.h.channel._tx_ring))
|
||||
|
||||
with self.assertRaises(RNS.Channel.ChannelException):
|
||||
self.h.channel.send(MessageTest())
|
||||
|
||||
# Sequence must not have been consumed.
|
||||
self.assertEqual(pre_sequence, self.h.channel._next_sequence)
|
||||
# _tx_ring must not contain a packetless envelope.
|
||||
self.assertEqual(0, len(self.h.channel._tx_ring))
|
||||
|
||||
# A subsequent successful send should use the same sequence number as
|
||||
# was reserved for the failed attempt.
|
||||
self.h.outlet.send = original_send
|
||||
envelope = self.h.channel.send(MessageTest())
|
||||
self.assertEqual(pre_sequence, envelope.sequence)
|
||||
self.assertIsNotNone(envelope.packet)
|
||||
self.assertIsNotNone(envelope.packet.raw)
|
||||
self.assertTrue(envelope in self.h.channel._tx_ring)
|
||||
|
||||
def test_multiple_handler(self):
|
||||
print("Channel test multiple handler short circuit")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user