|
|
@@ -10,9 +10,7 @@ import pytest
|
|
|
from bleak.exc import BleakDBusError
|
|
|
|
|
|
from switchbot import SwitchbotModel
|
|
|
-from switchbot.devices.device import (
|
|
|
- SwitchbotEncryptedDevice,
|
|
|
-)
|
|
|
+from switchbot.devices.device import AESMode, SwitchbotEncryptedDevice
|
|
|
|
|
|
from .test_adv_parser import generate_ble_device
|
|
|
|
|
|
@@ -133,7 +131,8 @@ async def test_send_command_iv_already_initialized() -> None:
|
|
|
patch.object(device, "_decrypt") as mock_decrypt,
|
|
|
):
|
|
|
mock_encrypt.return_value = (
|
|
|
- "656e637279707465645f64617461" # "encrypted_data" in hex
|
|
|
+ "656e637279707465645f64617461", # "encrypted_data" in hex
|
|
|
+ "abcd",
|
|
|
)
|
|
|
mock_decrypt.return_value = b"decrypted_response"
|
|
|
mock_send.return_value = b"\x01\x00\x00\x00encrypted_response"
|
|
|
@@ -171,7 +170,7 @@ async def test_iv_race_condition_during_disconnect() -> None:
|
|
|
patch.object(device, "_encrypt") as mock_encrypt,
|
|
|
patch.object(device, "_decrypt") as mock_decrypt,
|
|
|
):
|
|
|
- mock_encrypt.return_value = "656e63727970746564" # "encrypted" in hex
|
|
|
+ mock_encrypt.return_value = ("656e63727970746564", "abcd")
|
|
|
mock_decrypt.return_value = b"response"
|
|
|
mock_send.return_value = b"\x01\x00\x00\x00response"
|
|
|
|
|
|
@@ -210,6 +209,166 @@ async def test_ensure_encryption_initialized_with_lock_held() -> None:
|
|
|
assert device._cipher is None # Should be reset when IV changes
|
|
|
|
|
|
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_ensure_encryption_initialized_sets_gcm_mode() -> None:
|
|
|
+ """Test that GCM mode is detected from device response."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+
|
|
|
+ gcm_iv = b"\x01" * 12
|
|
|
+ response = b"\x01\x00\x01\x00" + gcm_iv + b"\x00\x00\x00\x00"
|
|
|
+
|
|
|
+ async with device._operation_lock:
|
|
|
+ with patch.object(device, "_send_command_locked_with_retry") as mock_send:
|
|
|
+ mock_send.return_value = response
|
|
|
+
|
|
|
+ result = await device._ensure_encryption_initialized()
|
|
|
+
|
|
|
+ assert result is True
|
|
|
+ assert device._encryption_mode == AESMode.GCM
|
|
|
+ assert device._iv == gcm_iv
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_ensure_encryption_initialized_invalid_iv_length_gcm() -> None:
|
|
|
+ """Test that invalid IV length for GCM mode returns False."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+
|
|
|
+ # GCM expects 12 bytes IV, but response has wrong length (only 8 bytes after trimming)
|
|
|
+ response = b"\x01\x00\x01\x00" + b"\x01" * 8 + b"\x00\x00\x00\x00"
|
|
|
+
|
|
|
+ async with device._operation_lock:
|
|
|
+ with patch.object(device, "_send_command_locked_with_retry") as mock_send:
|
|
|
+ mock_send.return_value = response
|
|
|
+
|
|
|
+ result = await device._ensure_encryption_initialized()
|
|
|
+
|
|
|
+ assert result is False
|
|
|
+ assert device._iv is None
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_ensure_encryption_initialized_invalid_iv_length_ctr() -> None:
|
|
|
+ """Test that invalid IV length for CTR mode returns False."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+
|
|
|
+ # CTR expects 16 bytes IV, but response has only 8 bytes
|
|
|
+ response = b"\x01\x00\x00\x00" + b"\x01" * 8
|
|
|
+
|
|
|
+ async with device._operation_lock:
|
|
|
+ with patch.object(device, "_send_command_locked_with_retry") as mock_send:
|
|
|
+ mock_send.return_value = response
|
|
|
+
|
|
|
+ result = await device._ensure_encryption_initialized()
|
|
|
+
|
|
|
+ assert result is False
|
|
|
+ assert device._iv is None
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_device_with_gcm_mode() -> None:
|
|
|
+ """Test that device initializes correctly in GCM mode and increments GCM IV."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+ device._encryption_mode = AESMode.GCM
|
|
|
+ device._iv = b"\x01" * 12
|
|
|
+
|
|
|
+ with (
|
|
|
+ patch.object(device, "_ensure_encryption_initialized") as mock_ensure,
|
|
|
+ patch.object(device, "_send_command_locked_with_retry") as mock_send,
|
|
|
+ patch.object(device, "_decrypt") as mock_decrypt,
|
|
|
+ patch.object(device, "_encrypt") as mock_encrypt,
|
|
|
+ patch.object(device, "_increment_gcm_iv") as mock_inc_iv,
|
|
|
+ ):
|
|
|
+ mock_ensure.return_value = True
|
|
|
+ mock_encrypt.return_value = ("10203040", "abcd")
|
|
|
+ mock_send.return_value = b"\x01\x00\x00\x00\x10\x20\x30\x40"
|
|
|
+ mock_decrypt.return_value = b"\x10\x20\x30\x40"
|
|
|
+
|
|
|
+ await device._send_command("570200")
|
|
|
+
|
|
|
+ mock_inc_iv.assert_called_once()
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_resolve_encryption_mode_invalid() -> None:
|
|
|
+ """Test that invalid mode byte raises error."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+
|
|
|
+ with pytest.raises(ValueError, match="Unsupported encryption mode"):
|
|
|
+ device._resolve_encryption_mode(2)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_resolve_encryption_mode_missing() -> None:
|
|
|
+ """Test that missing mode byte raises error."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+
|
|
|
+ with pytest.raises(ValueError, match="Encryption mode byte is missing"):
|
|
|
+ device._resolve_encryption_mode(None)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_resolve_encryption_mode_conflict() -> None:
|
|
|
+ """Test that conflicting encryption modes raise error."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+ device._encryption_mode = AESMode.CTR
|
|
|
+
|
|
|
+ with pytest.raises(
|
|
|
+ ValueError,
|
|
|
+ match="Conflicting encryption modes detected: CTR vs GCM",
|
|
|
+ ):
|
|
|
+ device._resolve_encryption_mode(1)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_increment_gcm_iv() -> None:
|
|
|
+ """Test GCM IV increment logic."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+ device._encryption_mode = AESMode.GCM
|
|
|
+ device._iv = b"\x00" * 11 + b"\x01"
|
|
|
+
|
|
|
+ device._increment_gcm_iv()
|
|
|
+
|
|
|
+ assert device._iv == b"\x00" * 11 + b"\x02"
|
|
|
+ assert device._cipher is None
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ ("initial_iv", "expected_exception", "expected_message"),
|
|
|
+ [
|
|
|
+ (None, RuntimeError, "Cannot increment GCM IV: IV is None"),
|
|
|
+ (
|
|
|
+ b"\x00" * 10,
|
|
|
+ RuntimeError,
|
|
|
+ "Cannot increment GCM IV: IV length is not 12 bytes",
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+)
|
|
|
+async def test_increment_gcm_iv_invalid(
|
|
|
+ initial_iv, expected_exception, expected_message
|
|
|
+) -> None:
|
|
|
+ """Test GCM IV increment with invalid IV states."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+ device._encryption_mode = AESMode.GCM
|
|
|
+ device._iv = initial_iv
|
|
|
+
|
|
|
+ with pytest.raises(expected_exception, match=expected_message):
|
|
|
+ device._increment_gcm_iv()
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_gcm_encrypt_decrypt_without_finalize() -> None:
|
|
|
+ """Test GCM encrypt/decrypt works without finalize in decrypt."""
|
|
|
+ device = create_encrypted_device()
|
|
|
+ device._encryption_mode = AESMode.GCM
|
|
|
+ device._iv = b"\x10" * 12
|
|
|
+
|
|
|
+ ciphertext_hex, _ = device._encrypt("48656c6c6f")
|
|
|
+ decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex))
|
|
|
+
|
|
|
+ assert decrypted.hex() == "48656c6c6f"
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.asyncio
|
|
|
async def test_ensure_encryption_initialized_failure() -> None:
|
|
|
"""Test _ensure_encryption_initialized when IV initialization fails."""
|
|
|
@@ -233,12 +392,13 @@ async def test_encrypt_decrypt_with_valid_iv() -> None:
|
|
|
device._iv = b"\x00" * 16 # Use zeros for predictable test
|
|
|
|
|
|
# Test encryption
|
|
|
- encrypted = device._encrypt("48656c6c6f") # "Hello" in hex
|
|
|
- assert isinstance(encrypted, str)
|
|
|
- assert len(encrypted) > 0
|
|
|
+ ciphertext_hex, header_hex = device._encrypt("48656c6c6f") # "Hello" in hex
|
|
|
+ assert isinstance(ciphertext_hex, str)
|
|
|
+ assert isinstance(header_hex, str)
|
|
|
+ assert len(ciphertext_hex) > 0
|
|
|
|
|
|
# Test decryption
|
|
|
- decrypted = device._decrypt(bytearray.fromhex(encrypted))
|
|
|
+ decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex))
|
|
|
assert decrypted.hex() == "48656c6c6f"
|
|
|
|
|
|
|
|
|
@@ -278,6 +438,7 @@ async def test_execute_disconnect_clears_encryption_state() -> None:
|
|
|
device = create_encrypted_device()
|
|
|
device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
|
|
|
device._cipher = None # type: ignore[assignment]
|
|
|
+ device._encryption_mode = AESMode.CTR
|
|
|
|
|
|
# Mock client
|
|
|
mock_client = AsyncMock()
|
|
|
@@ -288,6 +449,7 @@ async def test_execute_disconnect_clears_encryption_state() -> None:
|
|
|
|
|
|
assert device._iv is None
|
|
|
assert device._cipher is None
|
|
|
+ assert device._encryption_mode is None
|
|
|
mock_disconnect.assert_called_once()
|
|
|
|
|
|
|
|
|
@@ -304,7 +466,7 @@ async def test_concurrent_commands_with_same_device() -> None:
|
|
|
patch.object(device, "_encrypt") as mock_encrypt,
|
|
|
patch.object(device, "_decrypt") as mock_decrypt,
|
|
|
):
|
|
|
- mock_encrypt.return_value = "656e63727970746564" # "encrypted" in hex
|
|
|
+ mock_encrypt.return_value = ("656e63727970746564", "abcd")
|
|
|
mock_decrypt.return_value = b"response"
|
|
|
mock_send.return_value = b"\x01\x00\x00\x00data"
|
|
|
|
|
|
@@ -337,7 +499,7 @@ async def test_command_retry_with_encryption() -> None:
|
|
|
patch.object(device, "_encrypt") as mock_encrypt,
|
|
|
patch.object(device, "_decrypt") as mock_decrypt,
|
|
|
):
|
|
|
- mock_encrypt.return_value = "656e63727970746564" # "encrypted" in hex
|
|
|
+ mock_encrypt.return_value = ("656e63727970746564", "abcd")
|
|
|
mock_decrypt.return_value = b"response"
|
|
|
|
|
|
# First attempt fails, second succeeds
|
|
|
@@ -360,7 +522,7 @@ async def test_empty_data_encryption_decryption() -> None:
|
|
|
|
|
|
# Test empty encryption
|
|
|
encrypted = device._encrypt("")
|
|
|
- assert encrypted == ""
|
|
|
+ assert encrypted == ("", "")
|
|
|
|
|
|
# Test empty decryption
|
|
|
decrypted = device._decrypt(bytearray())
|