test_encrypted_device.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. """Tests for SwitchbotEncryptedDevice base class."""
  2. from __future__ import annotations
  3. import asyncio
  4. from typing import Any
  5. from unittest.mock import AsyncMock, patch
  6. import pytest
  7. from bleak.exc import BleakDBusError
  8. from switchbot import SwitchbotModel
  9. from switchbot.devices.device import (
  10. SwitchbotEncryptedDevice,
  11. )
  12. from .test_adv_parser import generate_ble_device
  13. class MockEncryptedDevice(SwitchbotEncryptedDevice):
  14. """Mock encrypted device for testing."""
  15. def __init__(self, *args: Any, **kwargs: Any) -> None:
  16. super().__init__(*args, **kwargs)
  17. self.update_count: int = 0
  18. async def update(self, interface: int | None = None) -> None:
  19. self.update_count += 1
  20. def create_encrypted_device(
  21. model: SwitchbotModel = SwitchbotModel.LOCK,
  22. ) -> MockEncryptedDevice:
  23. """Create an encrypted device for testing."""
  24. ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
  25. return MockEncryptedDevice(
  26. ble_device, "01", "0123456789abcdef0123456789abcdef", model=model
  27. )
  28. @pytest.mark.asyncio
  29. async def test_encrypted_device_init() -> None:
  30. """Test encrypted device initialization."""
  31. device = create_encrypted_device()
  32. assert device._key_id == "01"
  33. assert device._encryption_key == bytearray.fromhex(
  34. "0123456789abcdef0123456789abcdef"
  35. )
  36. assert device._iv is None
  37. assert device._cipher is None
  38. @pytest.mark.asyncio
  39. async def test_encrypted_device_init_validation() -> None:
  40. """Test encrypted device initialization with invalid parameters."""
  41. ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
  42. # Test empty key_id
  43. with pytest.raises(ValueError, match="key_id is missing"):
  44. MockEncryptedDevice(
  45. ble_device, "", "0123456789abcdef0123456789abcdef", SwitchbotModel.LOCK
  46. )
  47. # Test invalid key_id length
  48. with pytest.raises(ValueError, match="key_id is invalid"):
  49. MockEncryptedDevice(
  50. ble_device, "1", "0123456789abcdef0123456789abcdef", SwitchbotModel.LOCK
  51. )
  52. # Test empty encryption_key
  53. with pytest.raises(ValueError, match="encryption_key is missing"):
  54. MockEncryptedDevice(ble_device, "01", "", SwitchbotModel.LOCK)
  55. # Test invalid encryption_key length
  56. with pytest.raises(ValueError, match="encryption_key is invalid"):
  57. MockEncryptedDevice(ble_device, "01", "0123456789abcdef", SwitchbotModel.LOCK)
  58. @pytest.mark.asyncio
  59. async def test_send_command_unencrypted() -> None:
  60. """Test sending unencrypted command."""
  61. device = create_encrypted_device()
  62. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  63. mock_send.return_value = b"\x01\x00\x00\x00"
  64. result = await device._send_command("570200", encrypt=False)
  65. assert result == b"\x01\x00\x00\x00"
  66. mock_send.assert_called_once()
  67. # Verify the key was padded with zeros for unencrypted command
  68. call_args = mock_send.call_args[0]
  69. assert call_args[0] == "570000000200" # Original key with zeros inserted
  70. @pytest.mark.asyncio
  71. async def test_send_command_encrypted_success() -> None:
  72. """Test successful encrypted command."""
  73. device = create_encrypted_device()
  74. # Mock the connection and command execution
  75. with (
  76. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  77. patch.object(device, "_decrypt") as mock_decrypt,
  78. ):
  79. mock_decrypt.return_value = b"decrypted_response"
  80. # First call is for IV initialization, second is for the actual command
  81. mock_send.side_effect = [
  82. b"\x01\x00\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0", # IV response (16 bytes)
  83. b"\x01\x00\x00\x00encrypted_response", # Command response
  84. ]
  85. result = await device._send_command("570200", encrypt=True)
  86. assert result is not None
  87. assert mock_send.call_count == 2
  88. # Verify IV was initialized
  89. assert device._iv is not None
  90. @pytest.mark.asyncio
  91. async def test_send_command_iv_already_initialized() -> None:
  92. """Test sending encrypted command when IV is already initialized."""
  93. device = create_encrypted_device()
  94. # Pre-set the IV
  95. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  96. with (
  97. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  98. patch.object(device, "_encrypt") as mock_encrypt,
  99. patch.object(device, "_decrypt") as mock_decrypt,
  100. ):
  101. mock_encrypt.return_value = (
  102. "656e637279707465645f64617461" # "encrypted_data" in hex
  103. )
  104. mock_decrypt.return_value = b"decrypted_response"
  105. mock_send.return_value = b"\x01\x00\x00\x00encrypted_response"
  106. result = await device._send_command("570200", encrypt=True)
  107. assert result == b"\x01decrypted_response"
  108. # Should only call once since IV is already initialized
  109. mock_send.assert_called_once()
  110. mock_encrypt.assert_called_once()
  111. mock_decrypt.assert_called_once()
  112. @pytest.mark.asyncio
  113. async def test_iv_race_condition_during_disconnect() -> None:
  114. """Test that commands during disconnect are handled properly."""
  115. device = create_encrypted_device()
  116. # Pre-set the IV
  117. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78"
  118. # Mock the connection
  119. mock_client = AsyncMock()
  120. mock_client.is_connected = True
  121. device._client = mock_client
  122. async def simulate_disconnect() -> None:
  123. """Simulate disconnect happening during command execution."""
  124. await asyncio.sleep(0.01) # Small delay
  125. await device._execute_disconnect()
  126. with (
  127. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  128. patch.object(device, "_ensure_connected"),
  129. patch.object(device, "_encrypt") as mock_encrypt,
  130. patch.object(device, "_decrypt") as mock_decrypt,
  131. ):
  132. mock_encrypt.return_value = "656e63727970746564" # "encrypted" in hex
  133. mock_decrypt.return_value = b"response"
  134. mock_send.return_value = b"\x01\x00\x00\x00response"
  135. # Start command and disconnect concurrently
  136. command_task = asyncio.create_task(device._send_command("570200"))
  137. disconnect_task = asyncio.create_task(simulate_disconnect())
  138. # Both should complete without error
  139. result, _ = await asyncio.gather(
  140. command_task, disconnect_task, return_exceptions=True
  141. )
  142. # Command should have completed successfully
  143. assert isinstance(result, bytes) or result is None
  144. # IV should be cleared after disconnect
  145. assert device._iv is None
  146. @pytest.mark.asyncio
  147. async def test_ensure_encryption_initialized_with_lock_held() -> None:
  148. """Test that _ensure_encryption_initialized properly handles the operation lock."""
  149. device = create_encrypted_device()
  150. # Acquire the operation lock
  151. async with device._operation_lock:
  152. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  153. mock_send.return_value = b"\x01\x00\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  154. result = await device._ensure_encryption_initialized()
  155. assert result is True
  156. assert (
  157. device._iv
  158. == b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  159. )
  160. assert device._cipher is None # Should be reset when IV changes
  161. @pytest.mark.asyncio
  162. async def test_ensure_encryption_initialized_failure() -> None:
  163. """Test _ensure_encryption_initialized when IV initialization fails."""
  164. device = create_encrypted_device()
  165. async with device._operation_lock:
  166. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  167. # Return failure response
  168. mock_send.return_value = b"\x00"
  169. result = await device._ensure_encryption_initialized()
  170. assert result is False
  171. assert device._iv is None
  172. @pytest.mark.asyncio
  173. async def test_encrypt_decrypt_with_valid_iv() -> None:
  174. """Test encryption and decryption with valid IV."""
  175. device = create_encrypted_device()
  176. device._iv = b"\x00" * 16 # Use zeros for predictable test
  177. # Test encryption
  178. encrypted = device._encrypt("48656c6c6f") # "Hello" in hex
  179. assert isinstance(encrypted, str)
  180. assert len(encrypted) > 0
  181. # Test decryption
  182. decrypted = device._decrypt(bytearray.fromhex(encrypted))
  183. assert decrypted.hex() == "48656c6c6f"
  184. @pytest.mark.asyncio
  185. async def test_encrypt_with_none_iv() -> None:
  186. """Test that encryption raises error when IV is None."""
  187. device = create_encrypted_device()
  188. device._iv = None
  189. with pytest.raises(RuntimeError, match="Cannot encrypt: IV is None"):
  190. device._encrypt("48656c6c6f")
  191. @pytest.mark.asyncio
  192. async def test_decrypt_with_none_iv() -> None:
  193. """Test that decryption raises error when IV is None."""
  194. device = create_encrypted_device()
  195. device._iv = None
  196. with pytest.raises(RuntimeError, match="Cannot decrypt: IV is None"):
  197. device._decrypt(bytearray.fromhex("48656c6c6f"))
  198. @pytest.mark.asyncio
  199. async def test_get_cipher_with_none_iv() -> None:
  200. """Test that _get_cipher raises error when IV is None."""
  201. device = create_encrypted_device()
  202. device._iv = None
  203. with pytest.raises(RuntimeError, match="Cannot create cipher: IV is None"):
  204. device._get_cipher()
  205. @pytest.mark.asyncio
  206. async def test_execute_disconnect_clears_encryption_state() -> None:
  207. """Test that disconnect properly clears encryption state."""
  208. device = create_encrypted_device()
  209. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  210. device._cipher = None # type: ignore[assignment]
  211. # Mock client
  212. mock_client = AsyncMock()
  213. device._client = mock_client
  214. with patch.object(device, "_execute_disconnect_with_lock") as mock_disconnect:
  215. await device._execute_disconnect()
  216. assert device._iv is None
  217. assert device._cipher is None
  218. mock_disconnect.assert_called_once()
  219. @pytest.mark.asyncio
  220. async def test_concurrent_commands_with_same_device() -> None:
  221. """Test multiple concurrent commands on the same device."""
  222. device = create_encrypted_device()
  223. # Pre-initialize IV (16 bytes for AES CTR mode)
  224. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  225. with (
  226. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  227. patch.object(device, "_encrypt") as mock_encrypt,
  228. patch.object(device, "_decrypt") as mock_decrypt,
  229. ):
  230. mock_encrypt.return_value = "656e63727970746564" # "encrypted" in hex
  231. mock_decrypt.return_value = b"response"
  232. mock_send.return_value = b"\x01\x00\x00\x00data"
  233. # Send multiple commands concurrently
  234. tasks = [
  235. device._send_command("570200"),
  236. device._send_command("570201"),
  237. device._send_command("570202"),
  238. ]
  239. results = await asyncio.gather(*tasks)
  240. # All commands should succeed
  241. assert all(result == b"\x01response" for result in results)
  242. assert mock_send.call_count == 3
  243. @pytest.mark.asyncio
  244. async def test_command_retry_with_encryption() -> None:
  245. """Test command retry logic with encrypted commands."""
  246. device = create_encrypted_device()
  247. device._retry_count = 2
  248. # Pre-initialize IV (16 bytes for AES CTR mode)
  249. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  250. with (
  251. patch.object(device, "_send_command_locked") as mock_send_locked,
  252. patch.object(device, "_ensure_connected"),
  253. patch.object(device, "_encrypt") as mock_encrypt,
  254. patch.object(device, "_decrypt") as mock_decrypt,
  255. ):
  256. mock_encrypt.return_value = "656e63727970746564" # "encrypted" in hex
  257. mock_decrypt.return_value = b"response"
  258. # First attempt fails, second succeeds
  259. mock_send_locked.side_effect = [
  260. BleakDBusError("org.bluez.Error", []),
  261. b"\x01\x00\x00\x00data",
  262. ]
  263. result = await device._send_command("570200")
  264. assert result == b"\x01response"
  265. assert mock_send_locked.call_count == 2
  266. @pytest.mark.asyncio
  267. async def test_empty_data_encryption_decryption() -> None:
  268. """Test encryption/decryption of empty data."""
  269. device = create_encrypted_device()
  270. device._iv = b"\x00" * 16
  271. # Test empty encryption
  272. encrypted = device._encrypt("")
  273. assert encrypted == ""
  274. # Test empty decryption
  275. decrypted = device._decrypt(bytearray())
  276. assert decrypted == b""