"""Tests for SwitchbotEncryptedDevice base class.""" from __future__ import annotations import asyncio from typing import Any from unittest.mock import AsyncMock, patch import pytest from bleak.exc import BleakDBusError from switchbot import SwitchbotModel from switchbot.devices.device import AESMode, SwitchbotEncryptedDevice from .test_adv_parser import generate_ble_device class MockEncryptedDevice(SwitchbotEncryptedDevice): """Mock encrypted device for testing.""" def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.update_count: int = 0 async def update(self, interface: int | None = None) -> None: self.update_count += 1 def create_encrypted_device( model: SwitchbotModel = SwitchbotModel.LOCK, ) -> MockEncryptedDevice: """Create an encrypted device for testing.""" ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device") return MockEncryptedDevice( ble_device, "01", "0123456789abcdef0123456789abcdef", model=model ) @pytest.mark.asyncio async def test_encrypted_device_init() -> None: """Test encrypted device initialization.""" device = create_encrypted_device() assert device._key_id == "01" assert device._encryption_key == bytearray.fromhex( "0123456789abcdef0123456789abcdef" ) assert device._iv is None assert device._cipher is None @pytest.mark.asyncio async def test_encrypted_device_init_validation() -> None: """Test encrypted device initialization with invalid parameters.""" ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device") # Test empty key_id with pytest.raises(ValueError, match="key_id is missing"): MockEncryptedDevice( ble_device, "", "0123456789abcdef0123456789abcdef", SwitchbotModel.LOCK ) # Test invalid key_id length with pytest.raises(ValueError, match="key_id is invalid"): MockEncryptedDevice( ble_device, "1", "0123456789abcdef0123456789abcdef", SwitchbotModel.LOCK ) # Test empty encryption_key with pytest.raises(ValueError, match="encryption_key is missing"): MockEncryptedDevice(ble_device, "01", "", SwitchbotModel.LOCK) # Test invalid encryption_key length with pytest.raises(ValueError, match="encryption_key is invalid"): MockEncryptedDevice(ble_device, "01", "0123456789abcdef", SwitchbotModel.LOCK) @pytest.mark.asyncio async def test_send_command_unencrypted() -> None: """Test sending unencrypted command.""" device = create_encrypted_device() with patch.object(device, "_send_command_locked_with_retry") as mock_send: mock_send.return_value = b"\x01\x00\x00\x00" result = await device._send_command("570200", encrypt=False) assert result == b"\x01\x00\x00\x00" mock_send.assert_called_once() # Verify the key was padded with zeros for unencrypted command call_args = mock_send.call_args[0] assert call_args[0] == "570000000200" # Original key with zeros inserted @pytest.mark.asyncio async def test_send_command_encrypted_success() -> None: """Test successful encrypted command.""" device = create_encrypted_device() # Mock the connection and command execution with ( patch.object(device, "_send_command_locked_with_retry") as mock_send, patch.object(device, "_decrypt") as mock_decrypt, ): mock_decrypt.return_value = b"decrypted_response" # First call is for IV initialization, second is for the actual command mock_send.side_effect = [ b"\x01\x00\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0", # IV response (16 bytes) b"\x01\x00\x00\x00encrypted_response", # Command response ] result = await device._send_command("570200", encrypt=True) assert result is not None assert mock_send.call_count == 2 # Verify IV was initialized assert device._iv is not None @pytest.mark.asyncio async def test_send_command_iv_already_initialized() -> None: """Test sending encrypted command when IV is already initialized.""" device = create_encrypted_device() # Pre-set the IV device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0" with ( patch.object(device, "_send_command_locked_with_retry") as mock_send, patch.object(device, "_encrypt") as mock_encrypt, patch.object(device, "_decrypt") as mock_decrypt, ): mock_encrypt.return_value = ( "656e637279707465645f64617461", # "encrypted_data" in hex "abcd", ) mock_decrypt.return_value = b"decrypted_response" mock_send.return_value = b"\x01\x00\x00\x00encrypted_response" result = await device._send_command("570200", encrypt=True) assert result == b"\x01decrypted_response" # Should only call once since IV is already initialized mock_send.assert_called_once() mock_encrypt.assert_called_once() mock_decrypt.assert_called_once() @pytest.mark.asyncio async def test_iv_race_condition_during_disconnect() -> None: """Test that commands during disconnect are handled properly.""" device = create_encrypted_device() # Pre-set the IV device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78" # Mock the connection mock_client = AsyncMock() mock_client.is_connected = True device._client = mock_client async def simulate_disconnect() -> None: """Simulate disconnect happening during command execution.""" await asyncio.sleep(0.01) # Small delay await device._execute_disconnect() with ( patch.object(device, "_send_command_locked_with_retry") as mock_send, patch.object(device, "_ensure_connected"), patch.object(device, "_encrypt") as mock_encrypt, patch.object(device, "_decrypt") as mock_decrypt, ): mock_encrypt.return_value = ("656e63727970746564", "abcd") mock_decrypt.return_value = b"response" mock_send.return_value = b"\x01\x00\x00\x00response" # Start command and disconnect concurrently command_task = asyncio.create_task(device._send_command("570200")) disconnect_task = asyncio.create_task(simulate_disconnect()) # Both should complete without error result, _ = await asyncio.gather( command_task, disconnect_task, return_exceptions=True ) # Command should have completed successfully assert isinstance(result, bytes) or result is None # IV should be cleared after disconnect assert device._iv is None @pytest.mark.asyncio async def test_ensure_encryption_initialized_with_lock_held() -> None: """Test that _ensure_encryption_initialized properly handles the operation lock.""" device = create_encrypted_device() # Acquire the operation lock async with device._operation_lock: with patch.object(device, "_send_command_locked_with_retry") as mock_send: mock_send.return_value = b"\x01\x00\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0" result = await device._ensure_encryption_initialized() assert result is True assert ( device._iv == b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0" ) assert device._cipher is None # Should be reset when IV changes @pytest.mark.asyncio async def test_ensure_encryption_initialized_sets_gcm_mode() -> None: """Test that GCM mode is detected from device response.""" device = create_encrypted_device() gcm_iv = b"\x01" * 12 response = b"\x01\x00\x01\x00" + gcm_iv + b"\x00\x00\x00\x00" async with device._operation_lock: with patch.object(device, "_send_command_locked_with_retry") as mock_send: mock_send.return_value = response result = await device._ensure_encryption_initialized() assert result is True assert device._encryption_mode == AESMode.GCM assert device._iv == gcm_iv @pytest.mark.asyncio async def test_ensure_encryption_initialized_invalid_iv_length_gcm() -> None: """Test that invalid IV length for GCM mode returns False.""" device = create_encrypted_device() # GCM expects 12 bytes IV, but response has wrong length (only 8 bytes after trimming) response = b"\x01\x00\x01\x00" + b"\x01" * 8 + b"\x00\x00\x00\x00" async with device._operation_lock: with patch.object(device, "_send_command_locked_with_retry") as mock_send: mock_send.return_value = response result = await device._ensure_encryption_initialized() assert result is False assert device._iv is None @pytest.mark.asyncio async def test_ensure_encryption_initialized_invalid_iv_length_ctr() -> None: """Test that invalid IV length for CTR mode returns False.""" device = create_encrypted_device() # CTR expects 16 bytes IV, but response has only 8 bytes response = b"\x01\x00\x00\x00" + b"\x01" * 8 async with device._operation_lock: with patch.object(device, "_send_command_locked_with_retry") as mock_send: mock_send.return_value = response result = await device._ensure_encryption_initialized() assert result is False assert device._iv is None @pytest.mark.asyncio async def test_device_with_gcm_mode() -> None: """Test that device initializes correctly in GCM mode and increments GCM IV.""" device = create_encrypted_device() device._encryption_mode = AESMode.GCM device._iv = b"\x01" * 12 with ( patch.object(device, "_ensure_encryption_initialized") as mock_ensure, patch.object(device, "_send_command_locked_with_retry") as mock_send, patch.object(device, "_decrypt") as mock_decrypt, patch.object(device, "_encrypt") as mock_encrypt, patch.object(device, "_increment_gcm_iv") as mock_inc_iv, ): mock_ensure.return_value = True mock_encrypt.return_value = ("10203040", "abcd") mock_send.return_value = b"\x01\x00\x00\x00\x10\x20\x30\x40" mock_decrypt.return_value = b"\x10\x20\x30\x40" await device._send_command("570200") mock_inc_iv.assert_called_once() @pytest.mark.asyncio async def test_resolve_encryption_mode_invalid() -> None: """Test that invalid mode byte raises error.""" device = create_encrypted_device() with pytest.raises(ValueError, match="Unsupported encryption mode"): device._resolve_encryption_mode(2) @pytest.mark.asyncio async def test_resolve_encryption_mode_missing() -> None: """Test that missing mode byte raises error.""" device = create_encrypted_device() with pytest.raises(ValueError, match="Encryption mode byte is missing"): device._resolve_encryption_mode(None) @pytest.mark.asyncio async def test_resolve_encryption_mode_conflict() -> None: """Test that conflicting encryption modes raise error.""" device = create_encrypted_device() device._encryption_mode = AESMode.CTR with pytest.raises( ValueError, match="Conflicting encryption modes detected: CTR vs GCM", ): device._resolve_encryption_mode(1) @pytest.mark.asyncio async def test_increment_gcm_iv() -> None: """Test GCM IV increment logic.""" device = create_encrypted_device() device._encryption_mode = AESMode.GCM device._iv = b"\x00" * 11 + b"\x01" device._increment_gcm_iv() assert device._iv == b"\x00" * 11 + b"\x02" assert device._cipher is None @pytest.mark.asyncio @pytest.mark.parametrize( ("initial_iv", "expected_exception", "expected_message"), [ (None, RuntimeError, "Cannot increment GCM IV: IV is None"), ( b"\x00" * 10, RuntimeError, "Cannot increment GCM IV: IV length is not 12 bytes", ), ], ) async def test_increment_gcm_iv_invalid( initial_iv, expected_exception, expected_message ) -> None: """Test GCM IV increment with invalid IV states.""" device = create_encrypted_device() device._encryption_mode = AESMode.GCM device._iv = initial_iv with pytest.raises(expected_exception, match=expected_message): device._increment_gcm_iv() @pytest.mark.asyncio async def test_gcm_encrypt_decrypt_without_finalize() -> None: """Test GCM encrypt/decrypt works without finalize in decrypt.""" device = create_encrypted_device() device._encryption_mode = AESMode.GCM device._iv = b"\x10" * 12 ciphertext_hex, _ = device._encrypt("48656c6c6f") decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex)) assert decrypted.hex() == "48656c6c6f" @pytest.mark.asyncio async def test_ensure_encryption_initialized_failure() -> None: """Test _ensure_encryption_initialized when IV initialization fails.""" device = create_encrypted_device() async with device._operation_lock: with patch.object(device, "_send_command_locked_with_retry") as mock_send: # Return failure response mock_send.return_value = b"\x00" result = await device._ensure_encryption_initialized() assert result is False assert device._iv is None @pytest.mark.asyncio async def test_encrypt_decrypt_with_valid_iv() -> None: """Test encryption and decryption with valid IV.""" device = create_encrypted_device() device._iv = b"\x00" * 16 # Use zeros for predictable test # Test encryption ciphertext_hex, header_hex = device._encrypt("48656c6c6f") # "Hello" in hex assert isinstance(ciphertext_hex, str) assert isinstance(header_hex, str) assert len(ciphertext_hex) > 0 # Test decryption decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex)) assert decrypted.hex() == "48656c6c6f" @pytest.mark.asyncio async def test_encrypt_with_none_iv() -> None: """Test that encryption raises error when IV is None.""" device = create_encrypted_device() device._iv = None with pytest.raises(RuntimeError, match="Cannot encrypt: IV is None"): device._encrypt("48656c6c6f") @pytest.mark.asyncio async def test_decrypt_with_none_iv() -> None: """Test that decryption raises error when IV is None.""" device = create_encrypted_device() device._iv = None with pytest.raises(RuntimeError, match="Cannot decrypt: IV is None"): device._decrypt(bytearray.fromhex("48656c6c6f")) @pytest.mark.asyncio async def test_get_cipher_with_none_iv() -> None: """Test that _get_cipher raises error when IV is None.""" device = create_encrypted_device() device._iv = None with pytest.raises(RuntimeError, match="Cannot create cipher: IV is None"): device._get_cipher() @pytest.mark.asyncio async def test_execute_disconnect_clears_encryption_state() -> None: """Test that disconnect properly clears encryption state.""" device = create_encrypted_device() device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0" device._cipher = None # type: ignore[assignment] device._encryption_mode = AESMode.CTR # Mock client mock_client = AsyncMock() device._client = mock_client with patch.object(device, "_execute_disconnect_with_lock") as mock_disconnect: await device._execute_disconnect() assert device._iv is None assert device._cipher is None assert device._encryption_mode is None mock_disconnect.assert_called_once() @pytest.mark.asyncio async def test_concurrent_commands_with_same_device() -> None: """Test multiple concurrent commands on the same device.""" device = create_encrypted_device() # Pre-initialize IV (16 bytes for AES CTR mode) device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0" with ( patch.object(device, "_send_command_locked_with_retry") as mock_send, patch.object(device, "_encrypt") as mock_encrypt, patch.object(device, "_decrypt") as mock_decrypt, ): mock_encrypt.return_value = ("656e63727970746564", "abcd") mock_decrypt.return_value = b"response" mock_send.return_value = b"\x01\x00\x00\x00data" # Send multiple commands concurrently tasks = [ device._send_command("570200"), device._send_command("570201"), device._send_command("570202"), ] results = await asyncio.gather(*tasks) # All commands should succeed assert all(result == b"\x01response" for result in results) assert mock_send.call_count == 3 @pytest.mark.asyncio async def test_command_retry_with_encryption() -> None: """Test command retry logic with encrypted commands.""" device = create_encrypted_device() device._retry_count = 2 # Pre-initialize IV (16 bytes for AES CTR mode) device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0" with ( patch.object(device, "_send_command_locked") as mock_send_locked, patch.object(device, "_ensure_connected"), patch.object(device, "_encrypt") as mock_encrypt, patch.object(device, "_decrypt") as mock_decrypt, ): mock_encrypt.return_value = ("656e63727970746564", "abcd") mock_decrypt.return_value = b"response" # First attempt fails, second succeeds mock_send_locked.side_effect = [ BleakDBusError("org.bluez.Error", []), b"\x01\x00\x00\x00data", ] result = await device._send_command("570200") assert result == b"\x01response" assert mock_send_locked.call_count == 2 @pytest.mark.asyncio async def test_empty_data_encryption_decryption() -> None: """Test encryption/decryption of empty data.""" device = create_encrypted_device() device._iv = b"\x00" * 16 # Test empty encryption encrypted = device._encrypt("") assert encrypted == ("", "") # Test empty decryption decrypted = device._decrypt(bytearray()) assert decrypted == b"" @pytest.mark.asyncio async def test_decrypt_with_none_iv_during_disconnect() -> None: """Test that decryption returns empty bytes when IV is None during expected disconnect.""" device = create_encrypted_device() # Simulate disconnection in progress device._expected_disconnect = True device._iv = None # Should return empty bytes instead of raising result = device._decrypt(bytearray(b"encrypted_data")) assert result == b"" # Verify it still raises when not disconnecting device._expected_disconnect = False with pytest.raises(RuntimeError, match="Cannot decrypt: IV is None"): device._decrypt(bytearray(b"encrypted_data"))