test_encrypted_device.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. """Tests for SwitchbotEncryptedDevice base class."""
  2. from __future__ import annotations
  3. import asyncio
  4. from typing import Any
  5. from unittest.mock import AsyncMock, patch
  6. import pytest
  7. from bleak.exc import BleakDBusError
  8. from switchbot import SwitchbotModel
  9. from switchbot.devices.device import AESMode, SwitchbotEncryptedDevice
  10. from .test_adv_parser import generate_ble_device
  11. class MockEncryptedDevice(SwitchbotEncryptedDevice):
  12. """Mock encrypted device for testing."""
  13. def __init__(self, *args: Any, **kwargs: Any) -> None:
  14. super().__init__(*args, **kwargs)
  15. self.update_count: int = 0
  16. async def update(self, interface: int | None = None) -> None:
  17. self.update_count += 1
  18. def create_encrypted_device(
  19. model: SwitchbotModel = SwitchbotModel.LOCK,
  20. ) -> MockEncryptedDevice:
  21. """Create an encrypted device for testing."""
  22. ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
  23. return MockEncryptedDevice(
  24. ble_device, "01", "0123456789abcdef0123456789abcdef", model=model
  25. )
  26. @pytest.mark.asyncio
  27. async def test_encrypted_device_init() -> None:
  28. """Test encrypted device initialization."""
  29. device = create_encrypted_device()
  30. assert device._key_id == "01"
  31. assert device._encryption_key == bytearray.fromhex(
  32. "0123456789abcdef0123456789abcdef"
  33. )
  34. assert device._iv is None
  35. assert device._cipher is None
  36. @pytest.mark.asyncio
  37. async def test_encrypted_device_init_validation() -> None:
  38. """Test encrypted device initialization with invalid parameters."""
  39. ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
  40. # Test empty key_id
  41. with pytest.raises(ValueError, match="key_id is missing"):
  42. MockEncryptedDevice(
  43. ble_device, "", "0123456789abcdef0123456789abcdef", SwitchbotModel.LOCK
  44. )
  45. # Test invalid key_id length
  46. with pytest.raises(ValueError, match="key_id is invalid"):
  47. MockEncryptedDevice(
  48. ble_device, "1", "0123456789abcdef0123456789abcdef", SwitchbotModel.LOCK
  49. )
  50. # Test empty encryption_key
  51. with pytest.raises(ValueError, match="encryption_key is missing"):
  52. MockEncryptedDevice(ble_device, "01", "", SwitchbotModel.LOCK)
  53. # Test invalid encryption_key length
  54. with pytest.raises(ValueError, match="encryption_key is invalid"):
  55. MockEncryptedDevice(ble_device, "01", "0123456789abcdef", SwitchbotModel.LOCK)
  56. @pytest.mark.asyncio
  57. async def test_send_command_unencrypted() -> None:
  58. """Test sending unencrypted command."""
  59. device = create_encrypted_device()
  60. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  61. mock_send.return_value = b"\x01\x00\x00\x00"
  62. result = await device._send_command("570200", encrypt=False)
  63. assert result == b"\x01\x00\x00\x00"
  64. mock_send.assert_called_once()
  65. # Verify the key was padded with zeros for unencrypted command
  66. call_args = mock_send.call_args[0]
  67. assert call_args[0] == "570000000200" # Original key with zeros inserted
  68. @pytest.mark.asyncio
  69. async def test_send_command_encrypted_success() -> None:
  70. """Test successful encrypted command."""
  71. device = create_encrypted_device()
  72. # Mock the connection and command execution
  73. with (
  74. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  75. patch.object(device, "_decrypt") as mock_decrypt,
  76. ):
  77. mock_decrypt.return_value = b"decrypted_response"
  78. # First call is for IV initialization, second is for the actual command
  79. mock_send.side_effect = [
  80. b"\x01\x00\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0", # IV response (16 bytes)
  81. b"\x01\x00\x00\x00encrypted_response", # Command response
  82. ]
  83. result = await device._send_command("570200", encrypt=True)
  84. assert result is not None
  85. assert mock_send.call_count == 2
  86. # Verify IV was initialized
  87. assert device._iv is not None
  88. @pytest.mark.asyncio
  89. async def test_send_command_iv_already_initialized() -> None:
  90. """Test sending encrypted command when IV is already initialized."""
  91. device = create_encrypted_device()
  92. # Pre-set the IV
  93. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  94. with (
  95. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  96. patch.object(device, "_encrypt") as mock_encrypt,
  97. patch.object(device, "_decrypt") as mock_decrypt,
  98. ):
  99. mock_encrypt.return_value = (
  100. "656e637279707465645f64617461", # "encrypted_data" in hex
  101. "abcd",
  102. )
  103. mock_decrypt.return_value = b"decrypted_response"
  104. mock_send.return_value = b"\x01\x00\x00\x00encrypted_response"
  105. result = await device._send_command("570200", encrypt=True)
  106. assert result == b"\x01decrypted_response"
  107. # Should only call once since IV is already initialized
  108. mock_send.assert_called_once()
  109. mock_encrypt.assert_called_once()
  110. mock_decrypt.assert_called_once()
  111. @pytest.mark.asyncio
  112. async def test_iv_race_condition_during_disconnect() -> None:
  113. """Test that commands during disconnect are handled properly."""
  114. device = create_encrypted_device()
  115. # Pre-set the IV
  116. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78"
  117. # Mock the connection
  118. mock_client = AsyncMock()
  119. mock_client.is_connected = True
  120. device._client = mock_client
  121. async def simulate_disconnect() -> None:
  122. """Simulate disconnect happening during command execution."""
  123. await asyncio.sleep(0.01) # Small delay
  124. await device._execute_disconnect()
  125. with (
  126. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  127. patch.object(device, "_ensure_connected"),
  128. patch.object(device, "_encrypt") as mock_encrypt,
  129. patch.object(device, "_decrypt") as mock_decrypt,
  130. ):
  131. mock_encrypt.return_value = ("656e63727970746564", "abcd")
  132. mock_decrypt.return_value = b"response"
  133. mock_send.return_value = b"\x01\x00\x00\x00response"
  134. # Start command and disconnect concurrently
  135. command_task = asyncio.create_task(device._send_command("570200"))
  136. disconnect_task = asyncio.create_task(simulate_disconnect())
  137. # Both should complete without error
  138. result, _ = await asyncio.gather(
  139. command_task, disconnect_task, return_exceptions=True
  140. )
  141. # Command should have completed successfully
  142. assert isinstance(result, bytes) or result is None
  143. # IV should be cleared after disconnect
  144. assert device._iv is None
  145. @pytest.mark.asyncio
  146. async def test_ensure_encryption_initialized_with_lock_held() -> None:
  147. """Test that _ensure_encryption_initialized properly handles the operation lock."""
  148. device = create_encrypted_device()
  149. # Acquire the operation lock
  150. async with device._operation_lock:
  151. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  152. mock_send.return_value = b"\x01\x00\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  153. result = await device._ensure_encryption_initialized()
  154. assert result is True
  155. assert (
  156. device._iv
  157. == b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  158. )
  159. assert device._cipher is None # Should be reset when IV changes
  160. @pytest.mark.asyncio
  161. async def test_ensure_encryption_initialized_sets_gcm_mode() -> None:
  162. """Test that GCM mode is detected from device response."""
  163. device = create_encrypted_device()
  164. gcm_iv = b"\x01" * 12
  165. response = b"\x01\x00\x01\x00" + gcm_iv + b"\x00\x00\x00\x00"
  166. async with device._operation_lock:
  167. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  168. mock_send.return_value = response
  169. result = await device._ensure_encryption_initialized()
  170. assert result is True
  171. assert device._encryption_mode == AESMode.GCM
  172. assert device._iv == gcm_iv
  173. @pytest.mark.asyncio
  174. async def test_ensure_encryption_initialized_invalid_iv_length_gcm() -> None:
  175. """Test that invalid IV length for GCM mode returns False."""
  176. device = create_encrypted_device()
  177. # GCM expects 12 bytes IV, but response has wrong length (only 8 bytes after trimming)
  178. response = b"\x01\x00\x01\x00" + b"\x01" * 8 + b"\x00\x00\x00\x00"
  179. async with device._operation_lock:
  180. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  181. mock_send.return_value = response
  182. result = await device._ensure_encryption_initialized()
  183. assert result is False
  184. assert device._iv is None
  185. @pytest.mark.asyncio
  186. async def test_ensure_encryption_initialized_invalid_iv_length_ctr() -> None:
  187. """Test that invalid IV length for CTR mode returns False."""
  188. device = create_encrypted_device()
  189. # CTR expects 16 bytes IV, but response has only 8 bytes
  190. response = b"\x01\x00\x00\x00" + b"\x01" * 8
  191. async with device._operation_lock:
  192. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  193. mock_send.return_value = response
  194. result = await device._ensure_encryption_initialized()
  195. assert result is False
  196. assert device._iv is None
  197. @pytest.mark.asyncio
  198. async def test_device_with_gcm_mode() -> None:
  199. """Test that device initializes correctly in GCM mode and increments GCM IV."""
  200. device = create_encrypted_device()
  201. device._encryption_mode = AESMode.GCM
  202. device._iv = b"\x01" * 12
  203. with (
  204. patch.object(device, "_ensure_encryption_initialized") as mock_ensure,
  205. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  206. patch.object(device, "_decrypt") as mock_decrypt,
  207. patch.object(device, "_encrypt") as mock_encrypt,
  208. patch.object(device, "_increment_gcm_iv") as mock_inc_iv,
  209. ):
  210. mock_ensure.return_value = True
  211. mock_encrypt.return_value = ("10203040", "abcd")
  212. mock_send.return_value = b"\x01\x00\x00\x00\x10\x20\x30\x40"
  213. mock_decrypt.return_value = b"\x10\x20\x30\x40"
  214. await device._send_command("570200")
  215. mock_inc_iv.assert_called_once()
  216. @pytest.mark.asyncio
  217. async def test_resolve_encryption_mode_invalid() -> None:
  218. """Test that invalid mode byte raises error."""
  219. device = create_encrypted_device()
  220. with pytest.raises(ValueError, match="Unsupported encryption mode"):
  221. device._resolve_encryption_mode(2)
  222. @pytest.mark.asyncio
  223. async def test_resolve_encryption_mode_missing() -> None:
  224. """Test that missing mode byte raises error."""
  225. device = create_encrypted_device()
  226. with pytest.raises(ValueError, match="Encryption mode byte is missing"):
  227. device._resolve_encryption_mode(None)
  228. @pytest.mark.asyncio
  229. async def test_resolve_encryption_mode_conflict() -> None:
  230. """Test that conflicting encryption modes raise error."""
  231. device = create_encrypted_device()
  232. device._encryption_mode = AESMode.CTR
  233. with pytest.raises(
  234. ValueError,
  235. match="Conflicting encryption modes detected: CTR vs GCM",
  236. ):
  237. device._resolve_encryption_mode(1)
  238. @pytest.mark.asyncio
  239. async def test_increment_gcm_iv() -> None:
  240. """Test GCM IV increment logic."""
  241. device = create_encrypted_device()
  242. device._encryption_mode = AESMode.GCM
  243. device._iv = b"\x00" * 11 + b"\x01"
  244. device._increment_gcm_iv()
  245. assert device._iv == b"\x00" * 11 + b"\x02"
  246. assert device._cipher is None
  247. @pytest.mark.asyncio
  248. @pytest.mark.parametrize(
  249. ("initial_iv", "expected_exception", "expected_message"),
  250. [
  251. (None, RuntimeError, "Cannot increment GCM IV: IV is None"),
  252. (
  253. b"\x00" * 10,
  254. RuntimeError,
  255. "Cannot increment GCM IV: IV length is not 12 bytes",
  256. ),
  257. ],
  258. )
  259. async def test_increment_gcm_iv_invalid(
  260. initial_iv, expected_exception, expected_message
  261. ) -> None:
  262. """Test GCM IV increment with invalid IV states."""
  263. device = create_encrypted_device()
  264. device._encryption_mode = AESMode.GCM
  265. device._iv = initial_iv
  266. with pytest.raises(expected_exception, match=expected_message):
  267. device._increment_gcm_iv()
  268. @pytest.mark.asyncio
  269. async def test_gcm_encrypt_decrypt_without_finalize() -> None:
  270. """Test GCM encrypt/decrypt works without finalize in decrypt."""
  271. device = create_encrypted_device()
  272. device._encryption_mode = AESMode.GCM
  273. device._iv = b"\x10" * 12
  274. ciphertext_hex, _ = device._encrypt("48656c6c6f")
  275. decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex))
  276. assert decrypted.hex() == "48656c6c6f"
  277. @pytest.mark.asyncio
  278. async def test_ensure_encryption_initialized_failure() -> None:
  279. """Test _ensure_encryption_initialized when IV initialization fails."""
  280. device = create_encrypted_device()
  281. async with device._operation_lock:
  282. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  283. # Return failure response
  284. mock_send.return_value = b"\x00"
  285. result = await device._ensure_encryption_initialized()
  286. assert result is False
  287. assert device._iv is None
  288. @pytest.mark.asyncio
  289. async def test_encrypt_decrypt_with_valid_iv() -> None:
  290. """Test encryption and decryption with valid IV."""
  291. device = create_encrypted_device()
  292. device._iv = b"\x00" * 16 # Use zeros for predictable test
  293. # Test encryption
  294. ciphertext_hex, header_hex = device._encrypt("48656c6c6f") # "Hello" in hex
  295. assert isinstance(ciphertext_hex, str)
  296. assert isinstance(header_hex, str)
  297. assert len(ciphertext_hex) > 0
  298. # Test decryption
  299. decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex))
  300. assert decrypted.hex() == "48656c6c6f"
  301. @pytest.mark.asyncio
  302. async def test_encrypt_with_none_iv() -> None:
  303. """Test that encryption raises error when IV is None."""
  304. device = create_encrypted_device()
  305. device._iv = None
  306. with pytest.raises(RuntimeError, match="Cannot encrypt: IV is None"):
  307. device._encrypt("48656c6c6f")
  308. @pytest.mark.asyncio
  309. async def test_decrypt_with_none_iv() -> None:
  310. """Test that decryption raises error when IV is None."""
  311. device = create_encrypted_device()
  312. device._iv = None
  313. with pytest.raises(RuntimeError, match="Cannot decrypt: IV is None"):
  314. device._decrypt(bytearray.fromhex("48656c6c6f"))
  315. @pytest.mark.asyncio
  316. async def test_get_cipher_with_none_iv() -> None:
  317. """Test that _get_cipher raises error when IV is None."""
  318. device = create_encrypted_device()
  319. device._iv = None
  320. with pytest.raises(RuntimeError, match="Cannot create cipher: IV is None"):
  321. device._get_cipher()
  322. @pytest.mark.asyncio
  323. async def test_execute_disconnect_clears_encryption_state() -> None:
  324. """Test that disconnect properly clears encryption state."""
  325. device = create_encrypted_device()
  326. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  327. device._cipher = None # type: ignore[assignment]
  328. device._encryption_mode = AESMode.CTR
  329. # Mock client
  330. mock_client = AsyncMock()
  331. device._client = mock_client
  332. with patch.object(device, "_execute_disconnect_with_lock") as mock_disconnect:
  333. await device._execute_disconnect()
  334. assert device._iv is None
  335. assert device._cipher is None
  336. assert device._encryption_mode is None
  337. mock_disconnect.assert_called_once()
  338. @pytest.mark.asyncio
  339. async def test_concurrent_commands_with_same_device() -> None:
  340. """Test multiple concurrent commands on the same device."""
  341. device = create_encrypted_device()
  342. # Pre-initialize IV (16 bytes for AES CTR mode)
  343. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  344. with (
  345. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  346. patch.object(device, "_encrypt") as mock_encrypt,
  347. patch.object(device, "_decrypt") as mock_decrypt,
  348. ):
  349. mock_encrypt.return_value = ("656e63727970746564", "abcd")
  350. mock_decrypt.return_value = b"response"
  351. mock_send.return_value = b"\x01\x00\x00\x00data"
  352. # Send multiple commands concurrently
  353. tasks = [
  354. device._send_command("570200"),
  355. device._send_command("570201"),
  356. device._send_command("570202"),
  357. ]
  358. results = await asyncio.gather(*tasks)
  359. # All commands should succeed
  360. assert all(result == b"\x01response" for result in results)
  361. assert mock_send.call_count == 3
  362. @pytest.mark.asyncio
  363. async def test_command_retry_with_encryption() -> None:
  364. """Test command retry logic with encrypted commands."""
  365. device = create_encrypted_device()
  366. device._retry_count = 2
  367. # Pre-initialize IV (16 bytes for AES CTR mode)
  368. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  369. with (
  370. patch.object(device, "_send_command_locked") as mock_send_locked,
  371. patch.object(device, "_ensure_connected"),
  372. patch.object(device, "_encrypt") as mock_encrypt,
  373. patch.object(device, "_decrypt") as mock_decrypt,
  374. ):
  375. mock_encrypt.return_value = ("656e63727970746564", "abcd")
  376. mock_decrypt.return_value = b"response"
  377. # First attempt fails, second succeeds
  378. mock_send_locked.side_effect = [
  379. BleakDBusError("org.bluez.Error", []),
  380. b"\x01\x00\x00\x00data",
  381. ]
  382. result = await device._send_command("570200")
  383. assert result == b"\x01response"
  384. assert mock_send_locked.call_count == 2
  385. @pytest.mark.asyncio
  386. async def test_empty_data_encryption_decryption() -> None:
  387. """Test encryption/decryption of empty data."""
  388. device = create_encrypted_device()
  389. device._iv = b"\x00" * 16
  390. # Test empty encryption
  391. encrypted = device._encrypt("")
  392. assert encrypted == ("", "")
  393. # Test empty decryption
  394. decrypted = device._decrypt(bytearray())
  395. assert decrypted == b""
  396. @pytest.mark.asyncio
  397. async def test_decrypt_with_none_iv_during_disconnect() -> None:
  398. """Test that decryption returns empty bytes when IV is None during expected disconnect."""
  399. device = create_encrypted_device()
  400. # Simulate disconnection in progress
  401. device._expected_disconnect = True
  402. device._iv = None
  403. # Should return empty bytes instead of raising
  404. result = device._decrypt(bytearray(b"encrypted_data"))
  405. assert result == b""
  406. # Verify it still raises when not disconnecting
  407. device._expected_disconnect = False
  408. with pytest.raises(RuntimeError, match="Cannot decrypt: IV is None"):
  409. device._decrypt(bytearray(b"encrypted_data"))