test_encrypted_device.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593
  1. """Tests for SwitchbotEncryptedDevice base class."""
  2. from __future__ import annotations
  3. import asyncio
  4. from typing import Any
  5. from unittest.mock import AsyncMock, patch
  6. import pytest
  7. from bleak.exc import BleakDBusError
  8. from switchbot import SwitchbotModel
  9. from switchbot.devices.device import AESMode, SwitchbotEncryptedDevice
  10. from .test_adv_parser import generate_ble_device
  11. class MockEncryptedDevice(SwitchbotEncryptedDevice):
  12. """Mock encrypted device for testing."""
  13. def __init__(self, *args: Any, **kwargs: Any) -> None:
  14. super().__init__(*args, **kwargs)
  15. self.update_count: int = 0
  16. async def update(self, interface: int | None = None) -> None:
  17. self.update_count += 1
  18. def create_encrypted_device(
  19. model: SwitchbotModel = SwitchbotModel.LOCK,
  20. ) -> MockEncryptedDevice:
  21. """Create an encrypted device for testing."""
  22. ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
  23. return MockEncryptedDevice(
  24. ble_device, "01", "0123456789abcdef0123456789abcdef", model=model
  25. )
  26. @pytest.mark.asyncio
  27. async def test_encrypted_device_init() -> None:
  28. """Test encrypted device initialization."""
  29. device = create_encrypted_device()
  30. assert device._key_id == "01"
  31. assert device._encryption_key == bytearray.fromhex(
  32. "0123456789abcdef0123456789abcdef"
  33. )
  34. assert device._iv is None
  35. assert device._cipher is None
  36. @pytest.mark.asyncio
  37. async def test_encrypted_device_init_validation() -> None:
  38. """Test encrypted device initialization with invalid parameters."""
  39. ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
  40. # Test empty key_id
  41. with pytest.raises(ValueError, match="key_id is missing"):
  42. MockEncryptedDevice(
  43. ble_device,
  44. "",
  45. "0123456789abcdef0123456789abcdef",
  46. model=SwitchbotModel.LOCK,
  47. )
  48. # Test invalid key_id length
  49. with pytest.raises(ValueError, match="key_id is invalid"):
  50. MockEncryptedDevice(
  51. ble_device,
  52. "1",
  53. "0123456789abcdef0123456789abcdef",
  54. model=SwitchbotModel.LOCK,
  55. )
  56. # Test empty encryption_key
  57. with pytest.raises(ValueError, match="encryption_key is missing"):
  58. MockEncryptedDevice(ble_device, "01", "", model=SwitchbotModel.LOCK)
  59. # Test invalid encryption_key length
  60. with pytest.raises(ValueError, match="encryption_key is invalid"):
  61. MockEncryptedDevice(
  62. ble_device, "01", "0123456789abcdef", model=SwitchbotModel.LOCK
  63. )
  64. @pytest.mark.asyncio
  65. async def test_send_command_unencrypted() -> None:
  66. """Test sending unencrypted command."""
  67. device = create_encrypted_device()
  68. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  69. mock_send.return_value = b"\x01\x00\x00\x00"
  70. result = await device._send_command("570200", encrypt=False)
  71. assert result == b"\x01\x00\x00\x00"
  72. mock_send.assert_called_once()
  73. # Verify the key was padded with zeros for unencrypted command
  74. call_args = mock_send.call_args[0]
  75. assert call_args[0] == "570000000200" # Original key with zeros inserted
  76. @pytest.mark.asyncio
  77. async def test_send_command_encrypted_success() -> None:
  78. """Test successful encrypted command."""
  79. device = create_encrypted_device()
  80. # Mock the connection and command execution
  81. with (
  82. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  83. patch.object(device, "_decrypt") as mock_decrypt,
  84. ):
  85. mock_decrypt.return_value = b"decrypted_response"
  86. # First call is for IV initialization, second is for the actual command
  87. mock_send.side_effect = [
  88. b"\x01\x00\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0", # IV response (16 bytes)
  89. b"\x01\x00\x00\x00encrypted_response", # Command response
  90. ]
  91. result = await device._send_command("570200", encrypt=True)
  92. assert result is not None
  93. assert mock_send.call_count == 2
  94. # Verify IV was initialized
  95. assert device._iv is not None
  96. @pytest.mark.asyncio
  97. async def test_send_command_iv_already_initialized() -> None:
  98. """Test sending encrypted command when IV is already initialized."""
  99. device = create_encrypted_device()
  100. # Pre-set the IV
  101. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  102. with (
  103. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  104. patch.object(device, "_encrypt") as mock_encrypt,
  105. patch.object(device, "_decrypt") as mock_decrypt,
  106. ):
  107. mock_encrypt.return_value = (
  108. "656e637279707465645f64617461", # "encrypted_data" in hex
  109. "abcd",
  110. )
  111. mock_decrypt.return_value = b"decrypted_response"
  112. mock_send.return_value = b"\x01\x00\x00\x00encrypted_response"
  113. result = await device._send_command("570200", encrypt=True)
  114. assert result == b"\x01decrypted_response"
  115. # Should only call once since IV is already initialized
  116. mock_send.assert_called_once()
  117. mock_encrypt.assert_called_once()
  118. mock_decrypt.assert_called_once()
  119. @pytest.mark.asyncio
  120. async def test_iv_race_condition_during_disconnect() -> None:
  121. """Test that commands during disconnect are handled properly."""
  122. device = create_encrypted_device()
  123. # Pre-set the IV
  124. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78"
  125. # Mock the connection
  126. mock_client = AsyncMock()
  127. mock_client.is_connected = True
  128. device._client = mock_client
  129. async def simulate_disconnect() -> None:
  130. """Simulate disconnect happening during command execution."""
  131. await asyncio.sleep(0.01) # Small delay
  132. await device._execute_disconnect()
  133. with (
  134. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  135. patch.object(device, "_ensure_connected"),
  136. patch.object(device, "_encrypt") as mock_encrypt,
  137. patch.object(device, "_decrypt") as mock_decrypt,
  138. ):
  139. mock_encrypt.return_value = ("656e63727970746564", "abcd")
  140. mock_decrypt.return_value = b"response"
  141. mock_send.return_value = b"\x01\x00\x00\x00response"
  142. # Start command and disconnect concurrently
  143. command_task = asyncio.create_task(device._send_command("570200"))
  144. disconnect_task = asyncio.create_task(simulate_disconnect())
  145. # Both should complete without error
  146. result, _ = await asyncio.gather(
  147. command_task, disconnect_task, return_exceptions=True
  148. )
  149. # Command should have completed successfully
  150. assert isinstance(result, bytes) or result is None
  151. # IV should be cleared after disconnect
  152. assert device._iv is None
  153. @pytest.mark.asyncio
  154. async def test_ensure_encryption_initialized_with_lock_held() -> None:
  155. """Test that _ensure_encryption_initialized properly handles the operation lock."""
  156. device = create_encrypted_device()
  157. # Acquire the operation lock
  158. async with device._operation_lock:
  159. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  160. mock_send.return_value = b"\x01\x00\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  161. result = await device._ensure_encryption_initialized()
  162. assert result is True
  163. assert (
  164. device._iv
  165. == b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  166. )
  167. assert device._cipher is None # Should be reset when IV changes
  168. @pytest.mark.asyncio
  169. async def test_ensure_encryption_initialized_sets_gcm_mode() -> None:
  170. """Test that GCM mode is detected from device response."""
  171. device = create_encrypted_device()
  172. gcm_iv = b"\x01" * 12
  173. response = b"\x01\x00\x01\x00" + gcm_iv + b"\x00\x00\x00\x00"
  174. async with device._operation_lock:
  175. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  176. mock_send.return_value = response
  177. result = await device._ensure_encryption_initialized()
  178. assert result is True
  179. assert device._encryption_mode == AESMode.GCM
  180. assert device._iv == gcm_iv
  181. @pytest.mark.asyncio
  182. async def test_ensure_encryption_initialized_invalid_iv_length_gcm() -> None:
  183. """Test that invalid IV length for GCM mode returns False."""
  184. device = create_encrypted_device()
  185. # GCM expects 12 bytes IV, but response has wrong length (only 8 bytes after trimming)
  186. response = b"\x01\x00\x01\x00" + b"\x01" * 8 + b"\x00\x00\x00\x00"
  187. async with device._operation_lock:
  188. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  189. mock_send.return_value = response
  190. result = await device._ensure_encryption_initialized()
  191. assert result is False
  192. assert device._iv is None
  193. @pytest.mark.asyncio
  194. async def test_ensure_encryption_initialized_invalid_iv_length_ctr() -> None:
  195. """Test that invalid IV length for CTR mode returns False."""
  196. device = create_encrypted_device()
  197. # CTR expects 16 bytes IV, but response has only 8 bytes
  198. response = b"\x01\x00\x00\x00" + b"\x01" * 8
  199. async with device._operation_lock:
  200. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  201. mock_send.return_value = response
  202. result = await device._ensure_encryption_initialized()
  203. assert result is False
  204. assert device._iv is None
  205. @pytest.mark.asyncio
  206. async def test_device_with_gcm_mode() -> None:
  207. """Test that device initializes correctly in GCM mode and increments GCM IV."""
  208. device = create_encrypted_device()
  209. device._encryption_mode = AESMode.GCM
  210. device._iv = b"\x01" * 12
  211. with (
  212. patch.object(device, "_ensure_encryption_initialized") as mock_ensure,
  213. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  214. patch.object(device, "_decrypt") as mock_decrypt,
  215. patch.object(device, "_encrypt") as mock_encrypt,
  216. patch.object(device, "_increment_gcm_iv") as mock_inc_iv,
  217. ):
  218. mock_ensure.return_value = True
  219. mock_encrypt.return_value = ("10203040", "abcd")
  220. mock_send.return_value = b"\x01\x00\x00\x00\x10\x20\x30\x40"
  221. mock_decrypt.return_value = b"\x10\x20\x30\x40"
  222. await device._send_command("570200")
  223. mock_inc_iv.assert_called_once()
  224. @pytest.mark.asyncio
  225. async def test_resolve_encryption_mode_invalid() -> None:
  226. """Test that invalid mode byte raises error."""
  227. device = create_encrypted_device()
  228. with pytest.raises(ValueError, match="Unsupported encryption mode"):
  229. device._resolve_encryption_mode(2)
  230. @pytest.mark.asyncio
  231. async def test_resolve_encryption_mode_missing() -> None:
  232. """Test that missing mode byte raises error."""
  233. device = create_encrypted_device()
  234. with pytest.raises(ValueError, match="Encryption mode byte is missing"):
  235. device._resolve_encryption_mode(None)
  236. @pytest.mark.asyncio
  237. async def test_resolve_encryption_mode_conflict() -> None:
  238. """Test that conflicting encryption modes raise error."""
  239. device = create_encrypted_device()
  240. device._encryption_mode = AESMode.CTR
  241. with pytest.raises(
  242. ValueError,
  243. match="Conflicting encryption modes detected: CTR vs GCM",
  244. ):
  245. device._resolve_encryption_mode(1)
  246. @pytest.mark.asyncio
  247. async def test_increment_gcm_iv() -> None:
  248. """Test GCM IV increment logic."""
  249. device = create_encrypted_device()
  250. device._encryption_mode = AESMode.GCM
  251. device._iv = b"\x00" * 11 + b"\x01"
  252. device._increment_gcm_iv()
  253. assert device._iv == b"\x00" * 11 + b"\x02"
  254. assert device._cipher is None
  255. @pytest.mark.asyncio
  256. @pytest.mark.parametrize(
  257. ("initial_iv", "expected_exception", "expected_message"),
  258. [
  259. (None, RuntimeError, "Cannot increment GCM IV: IV is None"),
  260. (
  261. b"\x00" * 10,
  262. RuntimeError,
  263. "Cannot increment GCM IV: IV length is not 12 bytes",
  264. ),
  265. ],
  266. )
  267. async def test_increment_gcm_iv_invalid(
  268. initial_iv, expected_exception, expected_message
  269. ) -> None:
  270. """Test GCM IV increment with invalid IV states."""
  271. device = create_encrypted_device()
  272. device._encryption_mode = AESMode.GCM
  273. device._iv = initial_iv
  274. with pytest.raises(expected_exception, match=expected_message):
  275. device._increment_gcm_iv()
  276. @pytest.mark.asyncio
  277. async def test_gcm_encrypt_decrypt_without_finalize() -> None:
  278. """Test GCM encrypt/decrypt works without finalize in decrypt."""
  279. device = create_encrypted_device()
  280. device._encryption_mode = AESMode.GCM
  281. device._iv = b"\x10" * 12
  282. ciphertext_hex, _ = device._encrypt("48656c6c6f")
  283. decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex))
  284. assert decrypted.hex() == "48656c6c6f"
  285. @pytest.mark.asyncio
  286. async def test_ensure_encryption_initialized_failure() -> None:
  287. """Test _ensure_encryption_initialized when IV initialization fails."""
  288. device = create_encrypted_device()
  289. async with device._operation_lock:
  290. with patch.object(device, "_send_command_locked_with_retry") as mock_send:
  291. # Return failure response
  292. mock_send.return_value = b"\x00"
  293. result = await device._ensure_encryption_initialized()
  294. assert result is False
  295. assert device._iv is None
  296. @pytest.mark.asyncio
  297. async def test_encrypt_decrypt_with_valid_iv() -> None:
  298. """Test encryption and decryption with valid IV."""
  299. device = create_encrypted_device()
  300. device._iv = b"\x00" * 16 # Use zeros for predictable test
  301. # Test encryption
  302. ciphertext_hex, header_hex = device._encrypt("48656c6c6f") # "Hello" in hex
  303. assert isinstance(ciphertext_hex, str)
  304. assert isinstance(header_hex, str)
  305. assert len(ciphertext_hex) > 0
  306. # Test decryption
  307. decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex))
  308. assert decrypted.hex() == "48656c6c6f"
  309. @pytest.mark.asyncio
  310. async def test_encrypt_with_none_iv() -> None:
  311. """Test that encryption raises error when IV is None."""
  312. device = create_encrypted_device()
  313. device._iv = None
  314. with pytest.raises(RuntimeError, match="Cannot encrypt: IV is None"):
  315. device._encrypt("48656c6c6f")
  316. @pytest.mark.asyncio
  317. async def test_decrypt_with_none_iv() -> None:
  318. """Test that decryption raises error when IV is None."""
  319. device = create_encrypted_device()
  320. device._iv = None
  321. with pytest.raises(RuntimeError, match="Cannot decrypt: IV is None"):
  322. device._decrypt(bytearray.fromhex("48656c6c6f"))
  323. @pytest.mark.asyncio
  324. async def test_get_cipher_with_none_iv() -> None:
  325. """Test that _get_cipher raises error when IV is None."""
  326. device = create_encrypted_device()
  327. device._iv = None
  328. with pytest.raises(RuntimeError, match="Cannot create cipher: IV is None"):
  329. device._get_cipher()
  330. @pytest.mark.asyncio
  331. async def test_execute_disconnect_clears_encryption_state() -> None:
  332. """Test that disconnect properly clears encryption state."""
  333. device = create_encrypted_device()
  334. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  335. device._cipher = None # type: ignore[assignment]
  336. device._encryption_mode = AESMode.CTR
  337. # Mock client
  338. mock_client = AsyncMock()
  339. device._client = mock_client
  340. with patch.object(device, "_execute_disconnect_with_lock") as mock_disconnect:
  341. await device._execute_disconnect()
  342. assert device._iv is None
  343. assert device._cipher is None
  344. assert device._encryption_mode is None
  345. mock_disconnect.assert_called_once()
  346. @pytest.mark.asyncio
  347. async def test_concurrent_commands_with_same_device() -> None:
  348. """Test multiple concurrent commands on the same device."""
  349. device = create_encrypted_device()
  350. # Pre-initialize IV (16 bytes for AES CTR mode)
  351. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  352. with (
  353. patch.object(device, "_send_command_locked_with_retry") as mock_send,
  354. patch.object(device, "_encrypt") as mock_encrypt,
  355. patch.object(device, "_decrypt") as mock_decrypt,
  356. ):
  357. mock_encrypt.return_value = ("656e63727970746564", "abcd")
  358. mock_decrypt.return_value = b"response"
  359. mock_send.return_value = b"\x01\x00\x00\x00data"
  360. # Send multiple commands concurrently
  361. tasks = [
  362. device._send_command("570200"),
  363. device._send_command("570201"),
  364. device._send_command("570202"),
  365. ]
  366. results = await asyncio.gather(*tasks)
  367. # All commands should succeed
  368. assert all(result == b"\x01response" for result in results)
  369. assert mock_send.call_count == 3
  370. @pytest.mark.asyncio
  371. async def test_command_retry_with_encryption() -> None:
  372. """Test command retry logic with encrypted commands."""
  373. device = create_encrypted_device()
  374. device._retry_count = 2
  375. # Pre-initialize IV (16 bytes for AES CTR mode)
  376. device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
  377. with (
  378. patch.object(device, "_send_command_locked") as mock_send_locked,
  379. patch.object(device, "_ensure_connected"),
  380. patch.object(device, "_encrypt") as mock_encrypt,
  381. patch.object(device, "_decrypt") as mock_decrypt,
  382. ):
  383. mock_encrypt.return_value = ("656e63727970746564", "abcd")
  384. mock_decrypt.return_value = b"response"
  385. # First attempt fails, second succeeds
  386. mock_send_locked.side_effect = [
  387. BleakDBusError("org.bluez.Error", []),
  388. b"\x01\x00\x00\x00data",
  389. ]
  390. result = await device._send_command("570200")
  391. assert result == b"\x01response"
  392. assert mock_send_locked.call_count == 2
  393. @pytest.mark.asyncio
  394. async def test_empty_data_encryption_decryption() -> None:
  395. """Test encryption/decryption of empty data."""
  396. device = create_encrypted_device()
  397. device._iv = b"\x00" * 16
  398. # Test empty encryption
  399. encrypted = device._encrypt("")
  400. assert encrypted == ("", "")
  401. # Test empty decryption
  402. decrypted = device._decrypt(bytearray())
  403. assert decrypted == b""
  404. @pytest.mark.asyncio
  405. async def test_verify_encryption_key_falls_back_to_classvar() -> None:
  406. """verify_encryption_key resolves `model` from `cls._model` when omitted."""
  407. class ClassvarModelDevice(MockEncryptedDevice):
  408. _model = SwitchbotModel.LOCK
  409. ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
  410. with patch.object(
  411. ClassvarModelDevice,
  412. "get_basic_info",
  413. new_callable=AsyncMock,
  414. return_value={"ok": True},
  415. ):
  416. result = await ClassvarModelDevice.verify_encryption_key(
  417. ble_device, "01", "0123456789abcdef0123456789abcdef"
  418. )
  419. assert result is True
  420. @pytest.mark.asyncio
  421. async def test_verify_encryption_key_without_model_or_classvar_raises() -> None:
  422. """verify_encryption_key raises when neither `model=` nor `cls._model` is set."""
  423. ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
  424. with pytest.raises(ValueError, match="model must be provided"):
  425. await MockEncryptedDevice.verify_encryption_key(
  426. ble_device, "01", "0123456789abcdef0123456789abcdef"
  427. )
  428. def test_init_without_model_or_classvar_raises() -> None:
  429. """__init__ raises when neither `model=` nor `cls._model` is set."""
  430. ble_device = generate_ble_device("aa:bb:cc:dd:ee:ff", "Test Device")
  431. with pytest.raises(ValueError, match="model must be provided"):
  432. MockEncryptedDevice(ble_device, "01", "0123456789abcdef0123456789abcdef")
  433. @pytest.mark.asyncio
  434. async def test_decrypt_with_none_iv_during_disconnect() -> None:
  435. """Test that decryption returns empty bytes when IV is None during expected disconnect."""
  436. device = create_encrypted_device()
  437. # Simulate disconnection in progress
  438. device._expected_disconnect = True
  439. device._iv = None
  440. # Should return empty bytes instead of raising
  441. result = device._decrypt(bytearray(b"encrypted_data"))
  442. assert result == b""
  443. # Verify it still raises when not disconnecting
  444. device._expected_disconnect = False
  445. with pytest.raises(RuntimeError, match="Cannot decrypt: IV is None"):
  446. device._decrypt(bytearray(b"encrypted_data"))