from __future__ import annotations import threading import RNS from RNS.Channel import MessageState, ChannelOutletBase, Channel, MessageBase import RNS.Buffer from RNS.vendor import umsgpack from typing import Callable import contextlib import typing import types import time import uuid import unittest class Packet: timeout = 1.0 def __init__(self, raw: bytes): self.state = MessageState.MSGSTATE_NEW self.raw = raw self.packet_id = uuid.uuid4() self.tries = 0 self.timeout_id = None self.lock = threading.RLock() self.instances = 0 self.timeout_callback: Callable[[Packet], None] | None = None self.delivered_callback: Callable[[Packet], None] | None = None def set_timeout(self, callback: Callable[[Packet], None] | None, timeout: float): with self.lock: if timeout is not None: self.timeout = timeout self.timeout_callback = callback def send(self): self.tries += 1 self.state = MessageState.MSGSTATE_SENT def elapsed(timeout: float, timeout_id: uuid.uuid4): with self.lock: self.instances += 1 try: time.sleep(timeout) with self.lock: if self.timeout_id == timeout_id: self.timeout_id = None self.state = MessageState.MSGSTATE_FAILED if self.timeout_callback: self.timeout_callback(self) finally: with self.lock: self.instances -= 1 self.timeout_id = uuid.uuid4() threading.Thread(target=elapsed, name="Packet Timeout", args=[self.timeout, self.timeout_id], daemon=True).start() def clear_timeout(self): self.timeout_id = None def set_delivered_callback(self, callback: Callable[[Packet], None]): self.delivered_callback = callback def delivered(self): with self.lock: self.state = MessageState.MSGSTATE_DELIVERED self.timeout_id = None if self.delivered_callback: self.delivered_callback(self) class ChannelOutletTest(ChannelOutletBase): def get_packet_state(self, packet: Packet) -> MessageState: return packet.state def set_packet_timeout_callback(self, packet: Packet, callback: Callable[[Packet], None] | None, timeout: float | None = None): packet.set_timeout(callback, timeout) def set_packet_delivered_callback(self, packet: Packet, callback: Callable[[Packet], None] | None): packet.set_delivered_callback(callback) def get_packet_id(self, packet: Packet) -> any: return packet.packet_id def __init__(self, mdu: int, rtt: float): self.link_id = uuid.uuid4() self.timeout_callbacks = 0 self._mdu = mdu self._rtt = rtt self._usable = True self.packets = [] self.lock = threading.RLock() self.packet_callback: Callable[[ChannelOutletBase, bytes], None] | None = None def send(self, raw: bytes) -> Packet: with self.lock: packet = Packet(raw) packet.send() self.packets.append(packet) return packet def resend(self, packet: Packet) -> Packet: with self.lock: packet.send() return packet @property def mdu(self): return self._mdu @property def rtt(self): return self._rtt @property def is_usable(self): return self._usable def timed_out(self): self.timeout_callbacks += 1 def __str__(self): return str(self.link_id) class MessageTest(MessageBase): MSGTYPE = 0xabcd def __init__(self): self.id = str(uuid.uuid4()) self.data = "test" self.not_serialized = str(uuid.uuid4()) def pack(self) -> bytes: return umsgpack.packb((self.id, self.data)) def unpack(self, raw): self.id, self.data = umsgpack.unpackb(raw) class SystemMessage(MessageBase): MSGTYPE = 0xf000 def pack(self) -> bytes: return bytes() def unpack(self, raw): pass class ProtocolHarness(contextlib.AbstractContextManager): def __init__(self, rtt: float): self.outlet = ChannelOutletTest(mdu=500, rtt=rtt) self.channel = Channel(self.outlet) Packet.timeout = self.channel._get_packet_timeout_time(1) def cleanup(self): self.channel._shutdown() def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, __traceback: types.TracebackType) -> bool: # self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})") self.cleanup() return False class TestChannel(unittest.TestCase): def setUp(self) -> None: print("") self.rtt = 0.01 self.h = ProtocolHarness(self.rtt) def tearDown(self) -> None: self.h.cleanup() def test_send_one_retry(self): print("Channel test one retry") message = MessageTest() self.assertEqual(0, len(self.h.outlet.packets)) envelope = self.h.channel.send(message) self.assertIsNotNone(envelope) self.assertIsNotNone(envelope.raw) self.assertEqual(1, len(self.h.outlet.packets)) self.assertIsNotNone(envelope.packet) self.assertTrue(envelope in self.h.channel._tx_ring) self.assertTrue(envelope.tracked) packet = self.h.outlet.packets[0] self.assertEqual(envelope.packet, packet) self.assertEqual(1, envelope.tries) self.assertEqual(1, packet.tries) self.assertEqual(1, packet.instances) self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) self.assertEqual(envelope.raw, packet.raw) time.sleep(self.h.channel._get_packet_timeout_time(1) * 1.1) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(2, envelope.tries) self.assertEqual(2, packet.tries) self.assertEqual(1, packet.instances) time.sleep(self.h.channel._get_packet_timeout_time(2) * 1.1) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(self.h.outlet.packets[0], packet) self.assertEqual(3, envelope.tries) self.assertEqual(3, packet.tries) self.assertEqual(1, packet.instances) self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) packet.delivered() self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) time.sleep(self.h.channel._get_packet_timeout_time(3) * 1.1) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(3, envelope.tries) self.assertEqual(3, packet.tries) self.assertEqual(0, packet.instances) self.assertFalse(envelope.tracked) def test_send_timeout(self): print("Channel test retry count exceeded") message = MessageTest() self.assertEqual(0, len(self.h.outlet.packets)) envelope = self.h.channel.send(message) self.assertIsNotNone(envelope) self.assertIsNotNone(envelope.raw) self.assertEqual(1, len(self.h.outlet.packets)) self.assertIsNotNone(envelope.packet) self.assertTrue(envelope in self.h.channel._tx_ring) self.assertTrue(envelope.tracked) packet = self.h.outlet.packets[0] self.assertEqual(envelope.packet, packet) self.assertEqual(1, envelope.tries) self.assertEqual(1, packet.tries) self.assertEqual(1, packet.instances) self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) self.assertEqual(envelope.raw, packet.raw) time.sleep(self.h.channel._get_packet_timeout_time(1)) time.sleep(self.h.channel._get_packet_timeout_time(2)) time.sleep(self.h.channel._get_packet_timeout_time(3)) time.sleep(self.h.channel._get_packet_timeout_time(4)) time.sleep(self.h.channel._get_packet_timeout_time(5) * 1.1) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(5, envelope.tries) self.assertEqual(5, packet.tries) self.assertEqual(0, packet.instances) self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state) self.assertFalse(envelope.tracked) def test_send_on_failing_outlet_does_not_corrupt_state(self): # if outlet.send() returns a packet that never reached # the wire (LinkChannelOutlet does this when the link is not ACTIVE; the # returned packet has raw=None), Channel.send() must not consume a # sequence number or leave a packetless envelope in _tx_ring. Before # the fix, the envelope was queued before outlet.send() returned, so a # "dead" return left a raw=None envelope in the ring and silently # advanced _next_sequence, stalling the channel on the other end. print("Channel test send on failing outlet") original_send = self.h.outlet.send def ghost_send(raw): with self.h.outlet.lock: packet = Packet(None) packet.state = MessageState.MSGSTATE_FAILED self.h.outlet.packets.append(packet) return packet self.h.outlet.send = ghost_send pre_sequence = self.h.channel._next_sequence self.assertEqual(0, len(self.h.channel._tx_ring)) with self.assertRaises(RNS.Channel.ChannelException): self.h.channel.send(MessageTest()) # Sequence must not have been consumed. self.assertEqual(pre_sequence, self.h.channel._next_sequence) # _tx_ring must not contain a packetless envelope. self.assertEqual(0, len(self.h.channel._tx_ring)) # A subsequent successful send should use the same sequence number as # was reserved for the failed attempt. self.h.outlet.send = original_send envelope = self.h.channel.send(MessageTest()) self.assertEqual(pre_sequence, envelope.sequence) self.assertIsNotNone(envelope.packet) self.assertIsNotNone(envelope.packet.raw) self.assertTrue(envelope in self.h.channel._tx_ring) def test_multiple_handler(self): print("Channel test multiple handler short circuit") handler1_called = 0 handler1_return = True handler2_called = 0 def handler1(msg: MessageBase): nonlocal handler1_called, handler1_return self.assertIsInstance(msg, MessageTest) handler1_called += 1 return handler1_return def handler2(msg: MessageBase): nonlocal handler2_called self.assertIsInstance(msg, MessageTest) handler2_called += 1 message = MessageTest() self.h.channel.register_message_type(MessageTest) self.h.channel.add_message_handler(handler1) self.h.channel.add_message_handler(handler2) envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0) raw = envelope.pack() self.h.channel._receive(raw) time.sleep(0.5) self.assertEqual(1, handler1_called) self.assertEqual(0, handler2_called) handler1_return = False envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1) raw = envelope.pack() self.h.channel._receive(raw) time.sleep(0.5) self.assertEqual(2, handler1_called) self.assertEqual(1, handler2_called) def test_system_message_check(self): print("Channel test register system message") with self.assertRaises(RNS.Channel.ChannelException): self.h.channel.register_message_type(SystemMessage) self.h.channel._register_message_type(SystemMessage, is_system_type=True) def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): decoded: [MessageBase] = [] def handle_message(message: MessageBase): decoded.append(message) self.h.channel.register_message_type(message.__class__) self.h.channel.add_message_handler(handle_message) self.assertEqual(len(self.h.outlet.packets), 0) envelope = self.h.channel.send(message) time.sleep(self.h.channel._get_packet_timeout_time(1) * 0.5) self.assertIsNotNone(envelope) self.assertIsNotNone(envelope.raw) self.assertEqual(1, len(self.h.outlet.packets)) self.assertIsNotNone(envelope.packet) self.assertTrue(envelope in self.h.channel._tx_ring) self.assertTrue(envelope.tracked) packet = self.h.outlet.packets[0] self.assertEqual(envelope.packet, packet) self.assertEqual(1, envelope.tries) self.assertEqual(1, packet.tries) self.assertEqual(1, packet.instances) self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) self.assertEqual(envelope.raw, packet.raw) packet.delivered() self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) time.sleep(self.h.channel._get_packet_timeout_time(1)) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(1, envelope.tries) self.assertEqual(1, packet.tries) self.assertEqual(0, packet.instances) self.assertFalse(envelope.tracked) self.assertEqual(len(self.h.outlet.packets), 1) self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) self.assertFalse(envelope.tracked) self.assertEqual(0, len(decoded)) self.h.channel._receive(packet.raw) time.sleep(0.5) self.assertEqual(1, len(decoded)) rx_message = decoded[0] self.assertIsNotNone(rx_message) self.assertIsInstance(rx_message, message.__class__) checker(rx_message) def test_send_receive_message_test(self): print("Channel test send and receive message") message = MessageTest() def check(rx_message: MessageBase): self.assertIsInstance(rx_message, message.__class__) self.assertEqual(message.id, rx_message.id) self.assertEqual(message.data, rx_message.data) self.assertNotEqual(message.not_serialized, rx_message.not_serialized) self.eat_own_dog_food(message, check) def test_buffer_small_bidirectional(self): data = "Hello\n" with RNS.Buffer.create_bidirectional_buffer(0, 0, self.h.channel) as buffer: count = buffer.write(data.encode("utf-8")) buffer.flush() self.assertEqual(len(data), count) self.assertEqual(1, len(self.h.outlet.packets)) packet = self.h.outlet.packets[0] self.h.channel._receive(packet.raw) time.sleep(0.2) result = buffer.readline() self.assertIsNotNone(result) self.assertEqual(len(result), len(data)) decoded = result.decode("utf-8") self.assertEqual(data, decoded) def test_buffer_big(self): writer = RNS.Buffer.create_writer(15, self.h.channel) reader = RNS.Buffer.create_reader(15, self.h.channel) data = "01234556789"*1024*5 # 50 KB count = 0 write_finished = False def write_thread(): nonlocal count, write_finished count = writer.write(data.encode("utf-8")) writer.flush() writer.close() # TODO: Workaround for https://github.com/python/cpython/issues/138720 write_finished = True threading.Thread(target=write_thread, name="Write Thread", daemon=True).start() while not write_finished or next(filter(lambda x: x.state != MessageState.MSGSTATE_DELIVERED, self.h.outlet.packets), None) is not None: with self.h.outlet.lock: for packet in self.h.outlet.packets: if packet.state != MessageState.MSGSTATE_DELIVERED: self.h.channel._receive(packet.raw) packet.delivered() time.sleep(0.0001) self.assertEqual(len(data), count) read_finished = False result = bytes() def read_thread(): nonlocal read_finished, result result = reader.read() read_finished = True threading.Thread(target=read_thread, name="Read Thread", daemon=True).start() timeout_at = time.time() + 7 while not read_finished and time.time() < timeout_at: time.sleep(0.001) self.assertTrue(read_finished) self.assertEqual(len(data), len(result)) decoded = result.decode("utf-8") self.assertSequenceEqual(data, decoded) def test_buffer_small_with_callback(self): callbacks = 0 last_cb_value = None def callback(ready: int): nonlocal callbacks, last_cb_value callbacks += 1 last_cb_value = ready data = "Hello\n" with RNS.RawChannelWriter(0, self.h.channel) as writer, RNS.RawChannelReader(0, self.h.channel) as reader: reader.add_ready_callback(callback) count = writer.write(data.encode("utf-8")) writer.flush() self.assertEqual(len(data), count) self.assertEqual(1, len(self.h.outlet.packets)) packet = self.h.outlet.packets[0] self.h.channel._receive(packet.raw) packet.delivered() self.assertEqual(1, callbacks) self.assertEqual(len(data), last_cb_value) result = reader.readline() self.assertIsNotNone(result) self.assertEqual(len(result), len(data)) decoded = result.decode("utf-8") self.assertEqual(data, decoded) self.assertEqual(1, len(self.h.outlet.packets)) result = reader.read(1) self.assertIsNone(result) self.assertTrue(self.h.channel.is_ready_to_send()) writer.close() self.assertEqual(2, len(self.h.outlet.packets)) packet = self.h.outlet.packets[1] self.h.channel._receive(packet.raw) packet.delivered() result = reader.read(1) self.assertIsNotNone(result) self.assertTrue(len(result) == 0) if __name__ == '__main__': unittest.main(verbosity=2)