"""Tests for SwitchbotEncryptedDevice base class."""

from __future__ import annotations

import asyncio
from typing import Any
from unittest.mock import AsyncMock, patch

import pytest
from bleak.exc import BleakDBusError

from switchbot import SwitchbotModel
from switchbot.devices.device import (
    SwitchbotEncryptedDevice,
)

from .test_adv_parser import generate_ble_device

# 16-byte IV (the size used for AES CTR mode) shared by the tests below.
_TEST_IV = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
# A successful IV-initialization reply: a 4-byte header followed by the IV.
_IV_RESPONSE = b"\x01\x00\x00\x00" + _TEST_IV


class MockEncryptedDevice(SwitchbotEncryptedDevice):
    """Mock encrypted device for testing."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.update_count: int = 0

    async def update(self, interface: int | None = None) -> None:
        # Only count invocations; no device I/O happens in tests.
        self.update_count += 1


def create_encrypted_device(
    model: SwitchbotModel = SwitchbotModel.LOCK,
) -> MockEncryptedDevice:
    """Create an encrypted device for testing.

    The device uses key id ``"01"`` and a fixed 16-byte encryption key.
    """
    ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
    return MockEncryptedDevice(
        ble_device, "01", "0123456789abcdef0123456789abcdef", model=model
    )


@pytest.mark.asyncio
async def test_encrypted_device_init() -> None:
    """Test encrypted device initialization."""
    device = create_encrypted_device()

    assert device._key_id == "01"
    assert device._encryption_key == bytearray.fromhex(
        "0123456789abcdef0123456789abcdef"
    )
    assert device._iv is None
    assert device._cipher is None


@pytest.mark.asyncio
async def test_encrypted_device_init_validation() -> None:
    """Test encrypted device initialization with invalid parameters."""
    ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")

    # Test empty key_id
    with pytest.raises(ValueError, match="key_id is missing"):
        MockEncryptedDevice(
            ble_device, "", "0123456789abcdef0123456789abcdef", SwitchbotModel.LOCK
        )

    # Test invalid key_id length
    with pytest.raises(ValueError, match="key_id is invalid"):
        MockEncryptedDevice(
            ble_device, "1", "0123456789abcdef0123456789abcdef", SwitchbotModel.LOCK
        )

    # Test empty encryption_key
    with pytest.raises(ValueError, match="encryption_key is missing"):
        MockEncryptedDevice(ble_device, "01", "", SwitchbotModel.LOCK)

    # Test invalid encryption_key length
    with pytest.raises(ValueError, match="encryption_key is invalid"):
        MockEncryptedDevice(ble_device, "01", "0123456789abcdef", SwitchbotModel.LOCK)


@pytest.mark.asyncio
async def test_send_command_unencrypted() -> None:
    """Test sending unencrypted command."""
    device = create_encrypted_device()

    with patch.object(device, "_send_command_locked_with_retry") as mock_send:
        mock_send.return_value = b"\x01\x00\x00\x00"

        result = await device._send_command("570200", encrypt=False)

        assert result == b"\x01\x00\x00\x00"
        mock_send.assert_called_once()

        # The command key is sent with zeros inserted after its first byte
        # when encryption is disabled.
        call_args = mock_send.call_args[0]
        assert call_args[0] == "570000000200"


@pytest.mark.asyncio
async def test_send_command_encrypted_success() -> None:
    """Test successful encrypted command."""
    device = create_encrypted_device()

    # Mock the connection and command execution
    with (
        patch.object(device, "_send_command_locked_with_retry") as mock_send,
        patch.object(device, "_decrypt") as mock_decrypt,
    ):
        mock_decrypt.return_value = b"decrypted_response"

        # First call is for IV initialization, second is for the actual command
        mock_send.side_effect = [
            _IV_RESPONSE,  # IV response (4-byte header + 16-byte IV)
            b"\x01\x00\x00\x00encrypted_response",  # Command response
        ]

        result = await device._send_command("570200", encrypt=True)

        # Same result shape as when the IV is already initialized.
        assert result == b"\x01decrypted_response"
        assert mock_send.call_count == 2
        # Only the command response is decrypted, not the IV handshake.
        mock_decrypt.assert_called_once()
        # The IV is taken from the handshake response (bytes after the header).
        assert device._iv == _TEST_IV


@pytest.mark.asyncio
async def test_send_command_iv_already_initialized() -> None:
    """Test sending encrypted command when IV is already initialized."""
    device = create_encrypted_device()

    # Pre-set the IV
    device._iv = _TEST_IV

    with (
        patch.object(device, "_send_command_locked_with_retry") as mock_send,
        patch.object(device, "_encrypt") as mock_encrypt,
        patch.object(device, "_decrypt") as mock_decrypt,
    ):
        mock_encrypt.return_value = (
            "656e637279707465645f64617461"  # "encrypted_data" in hex
        )
        mock_decrypt.return_value = b"decrypted_response"
        mock_send.return_value = b"\x01\x00\x00\x00encrypted_response"

        result = await device._send_command("570200", encrypt=True)

        assert result == b"\x01decrypted_response"
        # Should only call once since IV is already initialized
        mock_send.assert_called_once()
        mock_encrypt.assert_called_once()
        mock_decrypt.assert_called_once()


@pytest.mark.asyncio
async def test_iv_race_condition_during_disconnect() -> None:
    """Test that commands during disconnect are handled properly."""
    device = create_encrypted_device()

    # Pre-set a 16-byte IV, consistent with the other tests
    device._iv = _TEST_IV

    # Mock the connection
    mock_client = AsyncMock()
    mock_client.is_connected = True
    device._client = mock_client

    async def simulate_disconnect() -> None:
        """Simulate disconnect happening during command execution."""
        await asyncio.sleep(0.01)  # Small delay
        await device._execute_disconnect()

    with (
        patch.object(device, "_send_command_locked_with_retry") as mock_send,
        patch.object(device, "_ensure_connected"),
        patch.object(device, "_encrypt") as mock_encrypt,
        patch.object(device, "_decrypt") as mock_decrypt,
    ):
        mock_encrypt.return_value = "656e63727970746564"  # "encrypted" in hex
        mock_decrypt.return_value = b"response"
        mock_send.return_value = b"\x01\x00\x00\x00response"

        # Start command and disconnect concurrently
        command_task = asyncio.create_task(device._send_command("570200"))
        disconnect_task = asyncio.create_task(simulate_disconnect())

        # Collect exceptions as results so one task failing does not cancel
        # the other before both have run.
        result, _ = await asyncio.gather(
            command_task, disconnect_task, return_exceptions=True
        )

        # Command should have completed successfully (not an exception)
        assert isinstance(result, bytes) or result is None

        # IV should be cleared after disconnect
        assert device._iv is None


@pytest.mark.asyncio
async def test_ensure_encryption_initialized_with_lock_held() -> None:
    """Test that _ensure_encryption_initialized properly handles the operation lock."""
    device = create_encrypted_device()

    # Acquire the operation lock
    async with device._operation_lock:
        with patch.object(device, "_send_command_locked_with_retry") as mock_send:
            mock_send.return_value = _IV_RESPONSE

            result = await device._ensure_encryption_initialized()

            assert result is True
            assert device._iv == _TEST_IV
            assert device._cipher is None  # Should be reset when IV changes


@pytest.mark.asyncio
async def test_ensure_encryption_initialized_failure() -> None:
    """Test _ensure_encryption_initialized when IV initialization fails."""
    device = create_encrypted_device()

    async with device._operation_lock:
        with patch.object(device, "_send_command_locked_with_retry") as mock_send:
            # Return failure response
            mock_send.return_value = b"\x00"

            result = await device._ensure_encryption_initialized()

            assert result is False
            assert device._iv is None


@pytest.mark.asyncio
async def test_encrypt_decrypt_with_valid_iv() -> None:
    """Test encryption and decryption with valid IV."""
    device = create_encrypted_device()
    device._iv = b"\x00" * 16  # Use zeros for predictable test

    # Test encryption
    encrypted = device._encrypt("48656c6c6f")  # "Hello" in hex
    assert isinstance(encrypted, str)
    assert len(encrypted) > 0

    # Test decryption round-trips back to the original bytes
    decrypted = device._decrypt(bytearray.fromhex(encrypted))
    assert decrypted.hex() == "48656c6c6f"


@pytest.mark.asyncio
async def test_encrypt_with_none_iv() -> None:
    """Test that encryption raises error when IV is None."""
    device = create_encrypted_device()
    device._iv = None

    with pytest.raises(RuntimeError, match="Cannot encrypt: IV is None"):
        device._encrypt("48656c6c6f")


@pytest.mark.asyncio
async def test_decrypt_with_none_iv() -> None:
    """Test that decryption raises error when IV is None."""
    device = create_encrypted_device()
    device._iv = None

    with pytest.raises(RuntimeError, match="Cannot decrypt: IV is None"):
        device._decrypt(bytearray.fromhex("48656c6c6f"))


@pytest.mark.asyncio
async def test_get_cipher_with_none_iv() -> None:
    """Test that _get_cipher raises error when IV is None."""
    device = create_encrypted_device()
    device._iv = None

    with pytest.raises(RuntimeError, match="Cannot create cipher: IV is None"):
        device._get_cipher()


@pytest.mark.asyncio
async def test_execute_disconnect_clears_encryption_state() -> None:
    """Test that disconnect properly clears encryption state."""
    device = create_encrypted_device()
    device._iv = _TEST_IV
    # Seed a non-None placeholder: starting from None would make the
    # `_cipher is None` assertion below pass even if disconnect never
    # cleared it.
    device._cipher = object()  # type: ignore[assignment]

    # Mock client
    mock_client = AsyncMock()
    device._client = mock_client

    with patch.object(device, "_execute_disconnect_with_lock") as mock_disconnect:
        await device._execute_disconnect()

        assert device._iv is None
        assert device._cipher is None
        mock_disconnect.assert_called_once()


@pytest.mark.asyncio
async def test_concurrent_commands_with_same_device() -> None:
    """Test multiple concurrent commands on the same device."""
    device = create_encrypted_device()

    # Pre-initialize IV (16 bytes for AES CTR mode)
    device._iv = _TEST_IV

    with (
        patch.object(device, "_send_command_locked_with_retry") as mock_send,
        patch.object(device, "_encrypt") as mock_encrypt,
        patch.object(device, "_decrypt") as mock_decrypt,
    ):
        mock_encrypt.return_value = "656e63727970746564"  # "encrypted" in hex
        mock_decrypt.return_value = b"response"
        mock_send.return_value = b"\x01\x00\x00\x00data"

        # Send multiple commands concurrently
        tasks = [
            device._send_command("570200"),
            device._send_command("570201"),
            device._send_command("570202"),
        ]

        results = await asyncio.gather(*tasks)

        # All commands should succeed
        assert all(result == b"\x01response" for result in results)
        assert mock_send.call_count == 3
        # Each command is encrypted exactly once (IV already initialized)
        assert mock_encrypt.call_count == 3


@pytest.mark.asyncio
async def test_command_retry_with_encryption() -> None:
    """Test command retry logic with encrypted commands."""
    device = create_encrypted_device()
    device._retry_count = 2

    # Pre-initialize IV (16 bytes for AES CTR mode)
    device._iv = _TEST_IV

    with (
        patch.object(device, "_send_command_locked") as mock_send_locked,
        patch.object(device, "_ensure_connected"),
        patch.object(device, "_encrypt") as mock_encrypt,
        patch.object(device, "_decrypt") as mock_decrypt,
    ):
        mock_encrypt.return_value = "656e63727970746564"  # "encrypted" in hex
        mock_decrypt.return_value = b"response"

        # First attempt fails, second succeeds
        mock_send_locked.side_effect = [
            BleakDBusError("org.bluez.Error", []),
            b"\x01\x00\x00\x00data",
        ]

        result = await device._send_command("570200")

        assert result == b"\x01response"
        assert mock_send_locked.call_count == 2


@pytest.mark.asyncio
async def test_empty_data_encryption_decryption() -> None:
    """Test encryption/decryption of empty data."""
    device = create_encrypted_device()
    device._iv = b"\x00" * 16

    # Test empty encryption
    encrypted = device._encrypt("")
    assert encrypted == ""

    # Test empty decryption
    decrypted = device._decrypt(bytearray())
    assert decrypted == b""