Browse Source

Newly implemented GCM encryption method (#442)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: J. Nick Koston <nick@home-assistant.io>
Retha Runolfsson 5 hours ago
parent
commit
02d7a3828a
2 changed files with 273 additions and 24 deletions
  1. 99 12
      switchbot/devices/device.py
  2. 174 12
      tests/test_encrypted_device.py

+ 99 - 12
switchbot/devices/device.py

@@ -8,6 +8,7 @@ import logging
 import time
 from collections.abc import Callable
 from dataclasses import replace
+from enum import IntEnum
 from typing import Any, TypeVar, cast
 from uuid import UUID
 
@@ -142,6 +143,21 @@ class SwitchbotOperationError(Exception):
     """Raised when an operation fails."""
 
 
+class AESMode(IntEnum):
+    """Supported AES modes for encrypted devices."""
+
+    CTR = 0
+    GCM = 1
+
+
+def _normalize_encryption_mode(mode: int) -> AESMode:
+    """Normalize encryption mode to AESMode (only 0/1 allowed)."""
+    try:
+        return AESMode(mode)
+    except (TypeError, ValueError) as exc:
+        raise ValueError(f"Unsupported encryption mode: {mode}") from exc
+
+
 def _sb_uuid(comms_type: str = "service") -> UUID | str:
     """Return Switchbot UUID."""
     _uuid = {"tx": "002", "rx": "003", "service": "d00"}
@@ -982,7 +998,8 @@ class SwitchbotEncryptedDevice(SwitchbotDevice):
         self._key_id = key_id
         self._encryption_key = bytearray.fromhex(encryption_key)
         self._iv: bytes | None = None
-        self._cipher: bytes | None = None
+        self._cipher: Cipher | None = None
+        self._encryption_mode: AESMode | None = None
         super().__init__(device, None, interface, **kwargs)
         self._model = model
 
@@ -1081,9 +1098,8 @@ class SwitchbotEncryptedDevice(SwitchbotDevice):
                 _LOGGER.error("Failed to initialize encryption")
                 return None
 
-            encrypted = (
-                key[:2] + self._key_id + self._iv[0:2].hex() + self._encrypt(key[2:])
-            )
+            ciphertext_hex, header_hex = self._encrypt(key[2:])
+            encrypted = key[:2] + self._key_id + header_hex + ciphertext_hex
             command = bytearray.fromhex(self._commandkey(encrypted))
             _LOGGER.debug("%s: Scheduling command %s", self.name, command.hex())
             max_attempts = retry + 1
@@ -1093,7 +1109,10 @@ class SwitchbotEncryptedDevice(SwitchbotDevice):
             )
             if result is None:
                 return None
-            return result[:1] + self._decrypt(result[4:])
+            decrypted = self._decrypt(result[4:])
+            if self._encryption_mode == AESMode.GCM:
+                self._increment_gcm_iv()
+            return result[:1] + decrypted
 
     async def _ensure_encryption_initialized(self) -> bool:
         """Ensure encryption is initialized, must be called with operation lock held."""
@@ -1117,34 +1136,71 @@ class SwitchbotEncryptedDevice(SwitchbotDevice):
             return False
 
         if ok := self._check_command_result(result, 0, {1}):
-            self._iv = result[4:]
+            _LOGGER.debug("%s: Encryption init response: %s", self.name, result.hex())
+            mode_byte = result[2] if len(result) > 2 else None
+            self._resolve_encryption_mode(mode_byte)
+            if self._encryption_mode == AESMode.GCM:
+                iv = result[4:-4]
+                expected_iv_len = 12
+            else:
+                iv = result[4:]
+                expected_iv_len = 16
+            if len(iv) != expected_iv_len:
+                _LOGGER.error(
+                    "%s: Invalid IV length %d for mode %s (expected %d)",
+                    self.name,
+                    len(iv),
+                    self._encryption_mode.name,
+                    expected_iv_len,
+                )
+                return False
+            self._iv = iv
             self._cipher = None  # Reset cipher when IV changes
             _LOGGER.debug("%s: Encryption initialized successfully", self.name)
 
         return ok
 
     async def _execute_disconnect(self) -> None:
+        """
+        Reset encryption state and disconnect.
+
+        Clears IV, cipher, and encryption mode so they can be
+        re-detected on the next connection (e.g., after firmware update).
+        """
         async with self._connect_lock:
             self._iv = None
             self._cipher = None
+            self._encryption_mode = None
             await self._execute_disconnect_with_lock()
 
     def _get_cipher(self) -> Cipher:
         if self._cipher is None:
             if self._iv is None:
                 raise RuntimeError("Cannot create cipher: IV is None")
-            self._cipher = Cipher(
-                algorithms.AES128(self._encryption_key), modes.CTR(self._iv)
-            )
+            if self._encryption_mode == AESMode.GCM:
+                self._cipher = Cipher(
+                    algorithms.AES128(self._encryption_key), modes.GCM(self._iv)
+                )
+            else:
+                self._cipher = Cipher(
+                    algorithms.AES128(self._encryption_key), modes.CTR(self._iv)
+                )
         return self._cipher
 
-    def _encrypt(self, data: str) -> str:
+    def _encrypt(self, data: str) -> tuple[str, str]:
         if len(data) == 0:
-            return ""
+            return "", ""
         if self._iv is None:
             raise RuntimeError("Cannot encrypt: IV is None")
         encryptor = self._get_cipher().encryptor()
-        return (encryptor.update(bytearray.fromhex(data)) + encryptor.finalize()).hex()
+        ciphertext = encryptor.update(bytearray.fromhex(data)) + encryptor.finalize()
+        if self._encryption_mode == AESMode.GCM:
+            header_hex = encryptor.tag[:2].hex()
+            # GCM cipher is single-use; clear it so _get_cipher() creates a fresh one
+            self._cipher = None
+        else:
+            header_hex = self._iv[0:2].hex()
+        return ciphertext.hex(), header_hex
 
     def _decrypt(self, data: bytearray) -> bytes:
         if len(data) == 0:
@@ -1157,9 +1213,40 @@ class SwitchbotEncryptedDevice(SwitchbotDevice):
                 )
                 return b""
             raise RuntimeError("Cannot decrypt: IV is None")
+        if self._encryption_mode == AESMode.GCM:
+            # Firmware only returns a 2-byte partial tag which can't be used for
+            # verification. Use a dummy 16-byte tag and skip finalize() since
+            # authentication is handled by the firmware.
+            decryptor = Cipher(
+                algorithms.AES128(self._encryption_key),
+                modes.GCM(self._iv, b"\x00" * 16),
+            ).decryptor()
+            return decryptor.update(data)
         decryptor = self._get_cipher().decryptor()
         return decryptor.update(data) + decryptor.finalize()
 
+    def _increment_gcm_iv(self) -> None:
+        """Increment GCM IV by 1 (big-endian). Called after each encrypted command."""
+        if self._iv is None:
+            raise RuntimeError("Cannot increment GCM IV: IV is None")
+        if len(self._iv) != 12:
+            raise RuntimeError("Cannot increment GCM IV: IV length is not 12 bytes")
+        iv_int = int.from_bytes(self._iv, "big") + 1
+        self._iv = iv_int.to_bytes(12, "big")
+        self._cipher = None
+
+    def _resolve_encryption_mode(self, mode_byte: int | None) -> None:
+        """Resolve encryption mode from device response when available."""
+        if mode_byte is None:
+            raise ValueError("Encryption mode byte is missing")
+        detected_mode = _normalize_encryption_mode(mode_byte)
+        if self._encryption_mode is not None and self._encryption_mode != detected_mode:
+            raise ValueError(
+                f"Conflicting encryption modes detected: {self._encryption_mode.name} vs {detected_mode.name}"
+            )
+        self._encryption_mode = detected_mode
+        _LOGGER.debug("%s: Detected encryption mode: %s", self.name, detected_mode.name)
+
 
 class SwitchbotDeviceOverrideStateDuringConnection(SwitchbotBaseDevice):
     """

+ 174 - 12
tests/test_encrypted_device.py

@@ -10,9 +10,7 @@ import pytest
 from bleak.exc import BleakDBusError
 
 from switchbot import SwitchbotModel
-from switchbot.devices.device import (
-    SwitchbotEncryptedDevice,
-)
+from switchbot.devices.device import AESMode, SwitchbotEncryptedDevice
 
 from .test_adv_parser import generate_ble_device
 
@@ -133,7 +131,8 @@ async def test_send_command_iv_already_initialized() -> None:
         patch.object(device, "_decrypt") as mock_decrypt,
     ):
         mock_encrypt.return_value = (
-            "656e637279707465645f64617461"  # "encrypted_data" in hex
+            "656e637279707465645f64617461",  # "encrypted_data" in hex
+            "abcd",
         )
         mock_decrypt.return_value = b"decrypted_response"
         mock_send.return_value = b"\x01\x00\x00\x00encrypted_response"
@@ -171,7 +170,7 @@ async def test_iv_race_condition_during_disconnect() -> None:
         patch.object(device, "_encrypt") as mock_encrypt,
         patch.object(device, "_decrypt") as mock_decrypt,
     ):
-        mock_encrypt.return_value = "656e63727970746564"  # "encrypted" in hex
+        mock_encrypt.return_value = ("656e63727970746564", "abcd")
         mock_decrypt.return_value = b"response"
         mock_send.return_value = b"\x01\x00\x00\x00response"
 
@@ -210,6 +209,166 @@ async def test_ensure_encryption_initialized_with_lock_held() -> None:
             assert device._cipher is None  # Should be reset when IV changes
 
 
+@pytest.mark.asyncio
+async def test_ensure_encryption_initialized_sets_gcm_mode() -> None:
+    """Test that GCM mode is detected from device response."""
+    device = create_encrypted_device()
+
+    gcm_iv = b"\x01" * 12
+    response = b"\x01\x00\x01\x00" + gcm_iv + b"\x00\x00\x00\x00"
+
+    async with device._operation_lock:
+        with patch.object(device, "_send_command_locked_with_retry") as mock_send:
+            mock_send.return_value = response
+
+            result = await device._ensure_encryption_initialized()
+
+            assert result is True
+            assert device._encryption_mode == AESMode.GCM
+            assert device._iv == gcm_iv
+
+
+@pytest.mark.asyncio
+async def test_ensure_encryption_initialized_invalid_iv_length_gcm() -> None:
+    """Test that invalid IV length for GCM mode returns False."""
+    device = create_encrypted_device()
+
+    # GCM expects 12 bytes IV, but response has wrong length (only 8 bytes after trimming)
+    response = b"\x01\x00\x01\x00" + b"\x01" * 8 + b"\x00\x00\x00\x00"
+
+    async with device._operation_lock:
+        with patch.object(device, "_send_command_locked_with_retry") as mock_send:
+            mock_send.return_value = response
+
+            result = await device._ensure_encryption_initialized()
+
+            assert result is False
+            assert device._iv is None
+
+
+@pytest.mark.asyncio
+async def test_ensure_encryption_initialized_invalid_iv_length_ctr() -> None:
+    """Test that invalid IV length for CTR mode returns False."""
+    device = create_encrypted_device()
+
+    # CTR expects 16 bytes IV, but response has only 8 bytes
+    response = b"\x01\x00\x00\x00" + b"\x01" * 8
+
+    async with device._operation_lock:
+        with patch.object(device, "_send_command_locked_with_retry") as mock_send:
+            mock_send.return_value = response
+
+            result = await device._ensure_encryption_initialized()
+
+            assert result is False
+            assert device._iv is None
+
+
+@pytest.mark.asyncio
+async def test_device_with_gcm_mode() -> None:
+    """Test that device initializes correctly in GCM mode and increments GCM IV."""
+    device = create_encrypted_device()
+    device._encryption_mode = AESMode.GCM
+    device._iv = b"\x01" * 12
+
+    with (
+        patch.object(device, "_ensure_encryption_initialized") as mock_ensure,
+        patch.object(device, "_send_command_locked_with_retry") as mock_send,
+        patch.object(device, "_decrypt") as mock_decrypt,
+        patch.object(device, "_encrypt") as mock_encrypt,
+        patch.object(device, "_increment_gcm_iv") as mock_inc_iv,
+    ):
+        mock_ensure.return_value = True
+        mock_encrypt.return_value = ("10203040", "abcd")
+        mock_send.return_value = b"\x01\x00\x00\x00\x10\x20\x30\x40"
+        mock_decrypt.return_value = b"\x10\x20\x30\x40"
+
+        await device._send_command("570200")
+
+        mock_inc_iv.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_resolve_encryption_mode_invalid() -> None:
+    """Test that invalid mode byte raises error."""
+    device = create_encrypted_device()
+
+    with pytest.raises(ValueError, match="Unsupported encryption mode"):
+        device._resolve_encryption_mode(2)
+
+
+@pytest.mark.asyncio
+async def test_resolve_encryption_mode_missing() -> None:
+    """Test that missing mode byte raises error."""
+    device = create_encrypted_device()
+
+    with pytest.raises(ValueError, match="Encryption mode byte is missing"):
+        device._resolve_encryption_mode(None)
+
+
+@pytest.mark.asyncio
+async def test_resolve_encryption_mode_conflict() -> None:
+    """Test that conflicting encryption modes raise error."""
+    device = create_encrypted_device()
+    device._encryption_mode = AESMode.CTR
+
+    with pytest.raises(
+        ValueError,
+        match="Conflicting encryption modes detected: CTR vs GCM",
+    ):
+        device._resolve_encryption_mode(1)
+
+
+@pytest.mark.asyncio
+async def test_increment_gcm_iv() -> None:
+    """Test GCM IV increment logic."""
+    device = create_encrypted_device()
+    device._encryption_mode = AESMode.GCM
+    device._iv = b"\x00" * 11 + b"\x01"
+
+    device._increment_gcm_iv()
+
+    assert device._iv == b"\x00" * 11 + b"\x02"
+    assert device._cipher is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("initial_iv", "expected_exception", "expected_message"),
+    [
+        (None, RuntimeError, "Cannot increment GCM IV: IV is None"),
+        (
+            b"\x00" * 10,
+            RuntimeError,
+            "Cannot increment GCM IV: IV length is not 12 bytes",
+        ),
+    ],
+)
+async def test_increment_gcm_iv_invalid(
+    initial_iv, expected_exception, expected_message
+) -> None:
+    """Test GCM IV increment with invalid IV states."""
+    device = create_encrypted_device()
+    device._encryption_mode = AESMode.GCM
+    device._iv = initial_iv
+
+    with pytest.raises(expected_exception, match=expected_message):
+        device._increment_gcm_iv()
+
+
+@pytest.mark.asyncio
+async def test_gcm_encrypt_decrypt_without_finalize() -> None:
+    """Test GCM encrypt/decrypt works without finalize in decrypt."""
+    device = create_encrypted_device()
+    device._encryption_mode = AESMode.GCM
+    device._iv = b"\x10" * 12
+
+    ciphertext_hex, _ = device._encrypt("48656c6c6f")
+    decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex))
+
+    assert decrypted.hex() == "48656c6c6f"
+
+
 @pytest.mark.asyncio
 async def test_ensure_encryption_initialized_failure() -> None:
     """Test _ensure_encryption_initialized when IV initialization fails."""
@@ -233,12 +392,13 @@ async def test_encrypt_decrypt_with_valid_iv() -> None:
     device._iv = b"\x00" * 16  # Use zeros for predictable test
 
     # Test encryption
-    encrypted = device._encrypt("48656c6c6f")  # "Hello" in hex
-    assert isinstance(encrypted, str)
-    assert len(encrypted) > 0
+    ciphertext_hex, header_hex = device._encrypt("48656c6c6f")  # "Hello" in hex
+    assert isinstance(ciphertext_hex, str)
+    assert isinstance(header_hex, str)
+    assert len(ciphertext_hex) > 0
 
     # Test decryption
-    decrypted = device._decrypt(bytearray.fromhex(encrypted))
+    decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex))
     assert decrypted.hex() == "48656c6c6f"
 
 
@@ -278,6 +438,7 @@ async def test_execute_disconnect_clears_encryption_state() -> None:
     device = create_encrypted_device()
     device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0"
     device._cipher = None  # type: ignore[assignment]
+    device._encryption_mode = AESMode.CTR
 
     # Mock client
     mock_client = AsyncMock()
@@ -288,6 +449,7 @@ async def test_execute_disconnect_clears_encryption_state() -> None:
 
     assert device._iv is None
     assert device._cipher is None
+    assert device._encryption_mode is None
     mock_disconnect.assert_called_once()
 
 
@@ -304,7 +466,7 @@ async def test_concurrent_commands_with_same_device() -> None:
         patch.object(device, "_encrypt") as mock_encrypt,
         patch.object(device, "_decrypt") as mock_decrypt,
     ):
-        mock_encrypt.return_value = "656e63727970746564"  # "encrypted" in hex
+        mock_encrypt.return_value = ("656e63727970746564", "abcd")
         mock_decrypt.return_value = b"response"
         mock_send.return_value = b"\x01\x00\x00\x00data"
 
@@ -337,7 +499,7 @@ async def test_command_retry_with_encryption() -> None:
         patch.object(device, "_encrypt") as mock_encrypt,
         patch.object(device, "_decrypt") as mock_decrypt,
     ):
-        mock_encrypt.return_value = "656e63727970746564"  # "encrypted" in hex
+        mock_encrypt.return_value = ("656e63727970746564", "abcd")
         mock_decrypt.return_value = b"response"
 
         # First attempt fails, second succeeds
@@ -360,7 +522,7 @@ async def test_empty_data_encryption_decryption() -> None:
 
     # Test empty encryption
     encrypted = device._encrypt("")
-    assert encrypted == ""
+    assert encrypted == ("", "")
 
     # Test empty decryption
     decrypted = device._decrypt(bytearray())