Browse Source

Extract common methods for encrypted devices into the base class (#297)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: J. Nick Koston <nick@koston.org>
Damian Sypniewski 1 week ago
parent
commit
e1ac5a3ec8
3 changed files with 56 additions and 112 deletions
  1. 56 0
      switchbot/devices/device.py
  2. 0 57
      switchbot/devices/lock.py
  3. 0 55
      switchbot/devices/relay_switch.py

+ 56 - 0
switchbot/devices/device.py

@@ -10,6 +10,7 @@ from dataclasses import replace
 from enum import Enum
 from typing import Any, TypeVar, cast
 from collections.abc import Callable
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 from uuid import UUID
 
 import aiohttp
@@ -45,6 +46,7 @@ REQ_HEADER = "570f"
 DEVICE_GET_BASIC_SETTINGS_KEY = "5702"
 DEVICE_SET_MODE_KEY = "5703"
 DEVICE_SET_EXTENDED_KEY = REQ_HEADER
+COMMAND_GET_CK_IV = f"{REQ_HEADER}2103"
 
 # Base key when encryption is set
 KEY_PASSWORD_PREFIX = "571"
@@ -827,6 +829,60 @@ class SwitchbotEncryptedDevice(SwitchbotDevice):
 
         return info is not None
 
+    async def _send_command(
+        self, key: str, retry: int | None = None, encrypt: bool = True
+    ) -> bytes | None:
+        if not encrypt:
+            return await super()._send_command(key[:2] + "000000" + key[2:], retry)
+
+        result = await self._ensure_encryption_initialized()
+        if not result:
+            _LOGGER.error("Failed to initialize encryption")
+            return None
+
+        encrypted = (
+            key[:2] + self._key_id + self._iv[0:2].hex() + self._encrypt(key[2:])
+        )
+        result = await super()._send_command(encrypted, retry)
+        return result[:1] + self._decrypt(result[4:])
+
+    async def _ensure_encryption_initialized(self) -> bool:
+        if self._iv is not None:
+            return True
+
+        result = await self._send_command(
+            COMMAND_GET_CK_IV + self._key_id, encrypt=False
+        )
+        ok = self._check_command_result(result, 0, {1})
+        if ok:
+            self._iv = result[4:]
+
+        return ok
+
+    async def _execute_disconnect(self) -> None:
+        await super()._execute_disconnect()
+        self._iv = None
+        self._cipher = None
+
+    def _get_cipher(self) -> Cipher:
+        if self._cipher is None:
+            self._cipher = Cipher(
+                algorithms.AES128(self._encryption_key), modes.CTR(self._iv)
+            )
+        return self._cipher
+
+    def _encrypt(self, data: str) -> str:
+        if len(data) == 0:
+            return ""
+        encryptor = self._get_cipher().encryptor()
+        return (encryptor.update(bytearray.fromhex(data)) + encryptor.finalize()).hex()
+
+    def _decrypt(self, data: bytearray) -> bytes:
+        if len(data) == 0:
+            return b""
+        decryptor = self._get_cipher().decryptor()
+        return decryptor.update(data) + decryptor.finalize()
+
 
 class SwitchbotDeviceOverrideStateDuringConnection(SwitchbotBaseDevice):
     """Base Representation of a Switchbot Device.

+ 0 - 57
switchbot/devices/lock.py

@@ -7,13 +7,11 @@ import time
 from typing import Any
 
 from bleak.backends.device import BLEDevice
-from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 
 from ..const import LockStatus, SwitchbotModel
 from .device import SwitchbotEncryptedDevice
 
 COMMAND_HEADER = "57"
-COMMAND_GET_CK_IV = f"{COMMAND_HEADER}0f2103"
 COMMAND_LOCK_INFO = {
     SwitchbotModel.LOCK: f"{COMMAND_HEADER}0f4f8101",
     SwitchbotModel.LOCK_PRO: f"{COMMAND_HEADER}0f4f8102",
@@ -220,58 +218,3 @@ class SwitchbotLock(SwitchbotEncryptedDevice):
             "unclosed_alarm": bool(data[1] & 0b00100000),
             "unlocked_alarm": bool(data[1] & 0b00010000),
         }
-
-    async def _send_command(
-        self, key: str, retry: int | None = None, encrypt: bool = True
-    ) -> bytes | None:
-        if not encrypt:
-            return await super()._send_command(key[:2] + "000000" + key[2:], retry)
-
-        result = await self._ensure_encryption_initialized()
-        if not result:
-            _LOGGER.error("Failed to initialize encryption")
-            return None
-
-        encrypted = (
-            key[:2] + self._key_id + self._iv[0:2].hex() + self._encrypt(key[2:])
-        )
-        result = await super()._send_command(encrypted, retry)
-        return result[:1] + self._decrypt(result[4:])
-
-    async def _ensure_encryption_initialized(self) -> bool:
-        if self._iv is not None:
-            return True
-
-        result = await self._send_command(
-            COMMAND_GET_CK_IV + self._key_id, encrypt=False
-        )
-        ok = self._check_command_result(result, 0, COMMAND_RESULT_EXPECTED_VALUES)
-        if ok:
-            self._iv = result[4:]
-
-        return ok
-
-    async def _execute_disconnect(self) -> None:
-        await super()._execute_disconnect()
-        self._iv = None
-        self._cipher = None
-        self._notifications_enabled = False
-
-    def _get_cipher(self) -> Cipher:
-        if self._cipher is None:
-            self._cipher = Cipher(
-                algorithms.AES128(self._encryption_key), modes.CTR(self._iv)
-            )
-        return self._cipher
-
-    def _encrypt(self, data: str) -> str:
-        if len(data) == 0:
-            return ""
-        encryptor = self._get_cipher().encryptor()
-        return (encryptor.update(bytearray.fromhex(data)) + encryptor.finalize()).hex()
-
-    def _decrypt(self, data: bytearray) -> bytes:
-        if len(data) == 0:
-            return b""
-        decryptor = self._get_cipher().decryptor()
-        return decryptor.update(data) + decryptor.finalize()

+ 0 - 55
switchbot/devices/relay_switch.py

@@ -3,7 +3,6 @@ import time
 from typing import Any
 
 from bleak.backends.device import BLEDevice
-from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 
 from ..const import SwitchbotModel
 from ..models import SwitchBotAdvertisement
@@ -12,7 +11,6 @@ from .device import SwitchbotEncryptedDevice
 _LOGGER = logging.getLogger(__name__)
 
 COMMAND_HEADER = "57"
-COMMAND_GET_CK_IV = f"{COMMAND_HEADER}0f2103"
 COMMAND_TURN_OFF = f"{COMMAND_HEADER}0f70010000"
 COMMAND_TURN_ON = f"{COMMAND_HEADER}0f70010100"
 COMMAND_TOGGLE = f"{COMMAND_HEADER}0f70010200"
@@ -139,56 +137,3 @@ class SwitchbotRelaySwitch(SwitchbotEncryptedDevice):
     def is_on(self) -> bool | None:
         """Return switch state from cache."""
         return self._get_adv_value("isOn")
-
-    async def _send_command(
-        self, key: str, retry: int | None = None, encrypt: bool = True
-    ) -> bytes | None:
-        if not encrypt:
-            return await super()._send_command(key[:2] + "000000" + key[2:], retry)
-
-        result = await self._ensure_encryption_initialized()
-        if not result:
-            return None
-
-        encrypted = (
-            key[:2] + self._key_id + self._iv[0:2].hex() + self._encrypt(key[2:])
-        )
-        result = await super()._send_command(encrypted, retry)
-        return result[:1] + self._decrypt(result[4:])
-
-    async def _ensure_encryption_initialized(self) -> bool:
-        if self._iv is not None:
-            return True
-
-        result = await self._send_command(
-            COMMAND_GET_CK_IV + self._key_id, encrypt=False
-        )
-        ok = self._check_command_result(result, 0, {1})
-        if ok:
-            self._iv = result[4:]
-
-        return ok
-
-    async def _execute_disconnect(self) -> None:
-        await super()._execute_disconnect()
-        self._iv = None
-        self._cipher = None
-
-    def _get_cipher(self) -> Cipher:
-        if self._cipher is None:
-            self._cipher = Cipher(
-                algorithms.AES128(self._encryption_key), modes.CTR(self._iv)
-            )
-        return self._cipher
-
-    def _encrypt(self, data: str) -> str:
-        if len(data) == 0:
-            return ""
-        encryptor = self._get_cipher().encryptor()
-        return (encryptor.update(bytearray.fromhex(data)) + encryptor.finalize()).hex()
-
-    def _decrypt(self, data: bytearray) -> bytes:
-        if len(data) == 0:
-            return b""
-        decryptor = self._get_cipher().decryptor()
-        return decryptor.update(data) + decryptor.finalize()