Files
intercept/tests/test_bluetooth_aggregator.py
T
James Smith 276b151e9e test: repair stale assertions in bluetooth group and deauth detector
Bluetooth aggregator/api/heuristics tests updated to current behavior;
deauth detector integration test rewritten to exercise the tracker and
alert path directly instead of patching __globals__ (read-only on
Python 3.14).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 15:44:59 +01:00

595 lines
20 KiB
Python

"""Unit tests for Bluetooth device aggregation."""
import dataclasses
from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
from utils.bluetooth.aggregator import DeviceAggregator
from utils.bluetooth.constants import (
DEVICE_STALE_TIMEOUT as DEVICE_STALE_SECONDS,
)
from utils.bluetooth.constants import (
MAX_RSSI_SAMPLES,
)
from utils.bluetooth.models import BTObservation
@pytest.fixture
def aggregator():
"""Create a fresh DeviceAggregator for testing."""
return DeviceAggregator()
@pytest.fixture
def sample_observation():
"""Create a sample BLE observation."""
return BTObservation(
timestamp=datetime.now(),
address="AA:BB:CC:DD:EE:FF",
address_type="public",
rssi=-55,
tx_power=None,
name="Test Device",
manufacturer_id=76, # Apple
manufacturer_data=None,
service_uuids=["0000180f-0000-1000-8000-00805f9b34fb"],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
class TestDeviceAggregator:
"""Tests for DeviceAggregator class."""
def test_ingest_single_observation(self, aggregator, sample_observation):
"""Test ingesting a single observation creates device aggregate."""
aggregator.ingest(sample_observation)
devices = aggregator.get_all_devices()
assert len(devices) == 1
device = devices[0]
assert device.address == "AA:BB:CC:DD:EE:FF"
assert device.name == "Test Device"
assert device.rssi_current == -55
assert device.seen_count == 1
def test_ingest_multiple_observations_same_device(self, aggregator, sample_observation):
"""Test multiple observations for same device aggregate correctly."""
# Ingest multiple observations with varying RSSI
rssi_values = [-55, -60, -50, -58, -52]
for rssi in rssi_values:
obs = BTObservation(
timestamp=datetime.now(),
address=sample_observation.address,
address_type=sample_observation.address_type,
rssi=rssi,
tx_power=None,
name=sample_observation.name,
manufacturer_id=sample_observation.manufacturer_id,
manufacturer_data=None,
service_uuids=sample_observation.service_uuids,
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(obs)
devices = aggregator.get_all_devices()
assert len(devices) == 1
device = devices[0]
assert device.seen_count == 5
assert device.rssi_current == rssi_values[-1]
assert len(device.rssi_samples) == 5
# Check RSSI stats
assert device.rssi_min == -60
assert device.rssi_max == -50
def test_rssi_median_calculation(self, aggregator, sample_observation):
"""Test RSSI median is calculated correctly."""
rssi_values = [-70, -60, -50, -55, -65] # Sorted: -70, -65, -60, -55, -50 -> median -60
for rssi in rssi_values:
obs = BTObservation(
timestamp=datetime.now(),
address=sample_observation.address,
address_type="public",
rssi=rssi,
tx_power=None,
name="Test",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(obs)
device = aggregator.get_all_devices()[0]
assert device.rssi_median == -60.0
def test_rssi_samples_limited(self, aggregator, sample_observation):
"""Test RSSI samples are limited to MAX_RSSI_SAMPLES."""
for i in range(MAX_RSSI_SAMPLES + 50):
obs = BTObservation(
timestamp=datetime.now(),
address=sample_observation.address,
address_type="public",
rssi=-50 - (i % 30),
tx_power=None,
name="Test",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(obs)
device = aggregator.get_all_devices()[0]
assert len(device.rssi_samples) <= MAX_RSSI_SAMPLES
def test_protocol_detection_ble(self, aggregator):
"""Test BLE protocol detection."""
obs = BTObservation(
timestamp=datetime.now(),
address="AA:BB:CC:DD:EE:FF",
address_type="random", # Random address indicates BLE
rssi=-60,
tx_power=-8,
name="BLE Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=["0000180a-0000-1000-8000-00805f9b34fb"],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(obs)
device = aggregator.get_all_devices()[0]
assert device.protocol == "ble"
def test_protocol_detection_classic(self, aggregator):
"""Test Classic Bluetooth protocol detection."""
obs = BTObservation(
timestamp=datetime.now(),
address="AA:BB:CC:DD:EE:FF",
address_type="public",
rssi=-60,
tx_power=None,
name="Classic Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=0x240404, # Audio device
major_class="audio_video",
minor_class="headphones",
)
aggregator.ingest(obs)
device = aggregator.get_all_devices()[0]
assert device.protocol == "classic"
class TestRangeBandEstimation:
"""Tests for range band estimation."""
def test_range_band_very_close(self, aggregator):
"""Test very close range band detection."""
obs = BTObservation(
timestamp=datetime.now(),
address="AA:BB:CC:DD:EE:FF",
address_type="public",
rssi=-35, # Very strong signal
tx_power=None,
name="Close Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
# Add multiple samples to build confidence
for _ in range(10):
aggregator.ingest(obs)
device = aggregator.get_all_devices()[0]
assert device.range_band == "very_close"
def test_range_band_close(self, aggregator):
"""Test close range band detection."""
for rssi in [-45, -48, -50, -47, -49]:
obs = BTObservation(
timestamp=datetime.now(),
address="AA:BB:CC:DD:EE:FF",
address_type="public",
rssi=rssi,
tx_power=None,
name="Close Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(obs)
device = aggregator.get_all_devices()[0]
assert device.range_band in ["very_close", "close"]
def test_range_band_far(self, aggregator):
"""Test far range band detection."""
for rssi in [-75, -78, -80, -77, -79]:
obs = BTObservation(
timestamp=datetime.now(),
address="AA:BB:CC:DD:EE:FF",
address_type="public",
rssi=rssi,
tx_power=None,
name="Far Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(obs)
device = aggregator.get_all_devices()[0]
assert device.range_band in ["nearby", "far"]
def test_range_band_unknown_low_confidence(self, aggregator):
"""Test unknown range band with insufficient data."""
obs = BTObservation(
timestamp=datetime.now(),
address="AA:BB:CC:DD:EE:FF",
address_type="public",
rssi=-60,
tx_power=None,
name="Unknown Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(obs)
device = aggregator.get_all_devices()[0]
# With only one sample, confidence is low
assert device.rssi_confidence < 0.5
class TestBaselineManagement:
"""Tests for baseline functionality."""
def test_set_baseline(self, aggregator, sample_observation):
"""Test setting a baseline from current devices."""
aggregator.ingest(sample_observation)
count = aggregator.set_baseline()
assert count == 1
assert aggregator.has_baseline is True
def test_clear_baseline(self, aggregator, sample_observation):
"""Test clearing the baseline."""
aggregator.ingest(sample_observation)
aggregator.set_baseline()
aggregator.clear_baseline()
assert aggregator.has_baseline is False
def test_is_new_device(self, aggregator, sample_observation):
"""Test detection of new devices vs baseline."""
# Add first device and set baseline
aggregator.ingest(sample_observation)
aggregator.set_baseline()
# Add new device
new_obs = BTObservation(
timestamp=datetime.now(),
address="11:22:33:44:55:66",
address_type="public",
rssi=-60,
tx_power=None,
name="New Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(new_obs)
devices = aggregator.get_all_devices()
new_device = next(d for d in devices if d.address == "11:22:33:44:55:66")
assert new_device.is_new is True
# Original device should not be new
original = next(d for d in devices if d.address == sample_observation.address)
assert original.is_new is False
class TestDevicePruning:
"""Tests for stale device pruning."""
def test_prune_stale_devices(self, aggregator):
"""Test that stale devices are removed."""
# Create an old observation
old_time = datetime.now() - timedelta(seconds=DEVICE_STALE_SECONDS + 60)
old_obs = BTObservation(
timestamp=old_time,
address="AA:BB:CC:DD:EE:FF",
address_type="public",
rssi=-60,
tx_power=None,
name="Old Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(old_obs)
# Create a recent observation for different device
recent_obs = BTObservation(
timestamp=datetime.now(),
address="11:22:33:44:55:66",
address_type="public",
rssi=-55,
tx_power=None,
name="Recent Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(recent_obs)
# Prune stale devices
pruned = aggregator.prune_stale_devices()
assert pruned == 1
devices = aggregator.get_all_devices()
assert len(devices) == 1
assert devices[0].address == "11:22:33:44:55:66"
class TestDeviceFiltering:
"""Tests for device filtering and sorting."""
def test_filter_by_protocol(self, aggregator):
"""Test filtering devices by protocol."""
# Add BLE device
ble_obs = BTObservation(
timestamp=datetime.now(),
address="AA:BB:CC:DD:EE:FF",
address_type="random",
rssi=-60,
tx_power=-8,
name="BLE Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=["0000180a-0000-1000-8000-00805f9b34fb"],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(ble_obs)
# Add Classic device
classic_obs = BTObservation(
timestamp=datetime.now(),
address="11:22:33:44:55:66",
address_type="public",
rssi=-55,
tx_power=None,
name="Classic Device",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=0x240404,
major_class="audio_video",
minor_class=None,
)
aggregator.ingest(classic_obs)
# Filter by BLE — get_all_devices() takes no args; filter in test
all_devices = aggregator.get_all_devices()
ble_devices = [d for d in all_devices if d.protocol == "ble"]
assert len(ble_devices) == 1
assert ble_devices[0].protocol == "ble"
# Filter by Classic
classic_devices = [d for d in all_devices if d.protocol == "classic"]
assert len(classic_devices) == 1
assert classic_devices[0].protocol == "classic"
def test_filter_by_min_rssi(self, aggregator):
"""Test filtering devices by minimum RSSI."""
for i, rssi in enumerate([-50, -70, -90]):
obs = BTObservation(
timestamp=datetime.now(),
address=f"AA:BB:CC:DD:EE:{i:02X}",
address_type="public",
rssi=rssi,
tx_power=None,
name=f"Device {i}",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(obs)
# Filter by min RSSI -60 — get_all_devices() takes no args; filter in test
all_devices = aggregator.get_all_devices()
strong_devices = [d for d in all_devices if d.rssi_current is not None and d.rssi_current >= -60]
assert len(strong_devices) == 1
assert strong_devices[0].rssi_current == -50
def test_sort_by_rssi(self, aggregator):
"""Test sorting devices by RSSI."""
for rssi in [-70, -50, -90, -60]:
obs = BTObservation(
timestamp=datetime.now(),
address=f"AA:BB:CC:DD:{abs(rssi):02X}:FF",
address_type="public",
rssi=rssi,
tx_power=None,
name=f"Device RSSI {rssi}",
manufacturer_id=None,
manufacturer_data=None,
service_uuids=[],
service_data={},
appearance=None,
is_connectable=True,
is_paired=False,
is_connected=False,
class_of_device=None,
major_class=None,
minor_class=None,
)
aggregator.ingest(obs)
# Sort by RSSI (strongest first) — get_all_devices() takes no args; sort in test
all_devices = aggregator.get_all_devices()
devices = sorted(all_devices, key=lambda d: d.rssi_current or -999, reverse=True)
rssi_values = [d.rssi_current for d in devices]
assert rssi_values == [-50, -60, -70, -90]
class TestTrackerDetectionOptimization:
"""Tests for tracker detection payload fingerprint optimization."""
def test_tracker_detection_skipped_when_payload_unchanged(self, aggregator, sample_observation):
"""detect_tracker must not be called on the second observation if payload is identical."""
aggregator.ingest(sample_observation)
with patch.object(
aggregator._tracker_engine, "detect_tracker", wraps=aggregator._tracker_engine.detect_tracker
) as mock_detect:
device = aggregator.ingest(sample_observation)
assert mock_detect.call_count == 0, (
"detect_tracker should not be called when the device payload fingerprint is unchanged"
)
# Tracker fields from first ingest must be preserved on skip
assert device.payload_fingerprint_id is not None
def test_tracker_detection_runs_when_payload_changes(self, aggregator, sample_observation):
"""detect_tracker must be called again when the device's manufacturer data changes."""
aggregator.ingest(sample_observation)
changed = dataclasses.replace(
sample_observation,
manufacturer_data=bytes([0xDE, 0xAD, 0xBE, 0xEF]),
)
with patch.object(
aggregator._tracker_engine, "detect_tracker", wraps=aggregator._tracker_engine.detect_tracker
) as mock_detect:
aggregator.ingest(changed)
assert mock_detect.call_count == 1, "detect_tracker must be called when manufacturer data changes"