1
0
Эх сурвалжийг харах

Move encryption and api functions into the base class (#277)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
J. Nick Koston 3 сар өмнө
parent
commit
9f939ce74e

+ 163 - 1
switchbot/devices/device.py

@@ -12,6 +12,7 @@ from typing import Any, TypeVar, cast
 from collections.abc import Callable
 from uuid import UUID
 
+import aiohttp
 from bleak.backends.device import BLEDevice
 from bleak.backends.service import BleakGATTCharacteristic, BleakGATTServiceCollection
 from bleak.exc import BleakDBusError
@@ -23,7 +24,15 @@ from bleak_retry_connector import (
     establish_connection,
 )
 
-from ..const import DEFAULT_RETRY_COUNT, DEFAULT_SCAN_TIMEOUT
+from ..api_config import SWITCHBOT_APP_API_BASE_URL, SWITCHBOT_APP_CLIENT_ID
+from ..const import (
+    DEFAULT_RETRY_COUNT,
+    DEFAULT_SCAN_TIMEOUT,
+    SwitchbotAccountConnectionError,
+    SwitchbotApiError,
+    SwitchbotAuthenticationError,
+    SwitchbotModel,
+)
 from ..discovery import GetSwitchbotDevices
 from ..models import SwitchBotAdvertisement
 
@@ -152,6 +161,35 @@ class SwitchbotBaseDevice:
         self._last_full_update: float = -PASSIVE_POLL_INTERVAL
         self._timed_disconnect_task: asyncio.Task[None] | None = None
 
+    @classmethod
+    async def api_request(
+        cls,
+        session: aiohttp.ClientSession,
+        subdomain: str,
+        path: str,
+        data: dict = None,
+        headers: dict = None,
+    ) -> dict:
+        url = f"https://{subdomain}.{SWITCHBOT_APP_API_BASE_URL}/{path}"
+        async with session.post(
+            url,
+            json=data,
+            headers=headers,
+            timeout=aiohttp.ClientTimeout(total=10),
+        ) as result:
+            if result.status > 299:
+                raise SwitchbotApiError(
+                    f"Unexpected status code returned by SwitchBot API: {result.status}"
+                )
+
+            response = await result.json()
+            if response["statusCode"] != 100:
+                raise SwitchbotApiError(
+                    f"{response['message']}, status code: {response['statusCode']}"
+                )
+
+            return response["body"]
+
     def advertisement_changed(self, advertisement: SwitchBotAdvertisement) -> bool:
         """Check if the advertisement has changed."""
         return bool(
@@ -666,6 +704,130 @@ class SwitchbotDevice(SwitchbotBaseDevice):
         self._set_advertisement_data(advertisement)
 
 
+class SwitchbotEncryptedDevice(SwitchbotDevice):
+    """A Switchbot device that uses encryption."""
+
+    def __init__(
+        self,
+        device: BLEDevice,
+        key_id: str,
+        encryption_key: str,
+        model: SwitchbotModel,
+        interface: int = 0,
+        **kwargs: Any,
+    ) -> None:
+        """Switchbot base class constructor for encrypted devices."""
+        if len(key_id) == 0:
+            raise ValueError("key_id is missing")
+        elif len(key_id) != 2:
+            raise ValueError("key_id is invalid")
+        if len(encryption_key) == 0:
+            raise ValueError("encryption_key is missing")
+        elif len(encryption_key) != 32:
+            raise ValueError("encryption_key is invalid")
+        self._key_id = key_id
+        self._encryption_key = bytearray.fromhex(encryption_key)
+        self._iv: bytes | None = None
+        self._cipher: bytes | None = None
+        self._model = model
+        super().__init__(device, None, interface, **kwargs)
+
+    # Old non-async method preserved for backwards compatibility
+    @classmethod
+    def retrieve_encryption_key(cls, device_mac: str, username: str, password: str):
+        async def async_fn():
+            async with aiohttp.ClientSession() as session:
+                return await cls.async_retrieve_encryption_key(
+                    session, device_mac, username, password
+                )
+
+        return asyncio.run(async_fn())
+
+    @classmethod
+    async def async_retrieve_encryption_key(
+        cls,
+        session: aiohttp.ClientSession,
+        device_mac: str,
+        username: str,
+        password: str,
+    ) -> dict:
+        """Retrieve lock key from internal SwitchBot API."""
+        device_mac = device_mac.replace(":", "").replace("-", "").upper()
+
+        try:
+            auth_result = await cls.api_request(
+                session,
+                "account",
+                "account/api/v1/user/login",
+                {
+                    "clientId": SWITCHBOT_APP_CLIENT_ID,
+                    "username": username,
+                    "password": password,
+                    "grantType": "password",
+                    "verifyCode": "",
+                },
+            )
+            auth_headers = {"authorization": auth_result["access_token"]}
+        except Exception as err:
+            raise SwitchbotAuthenticationError(f"Authentication failed: {err}") from err
+
+        try:
+            userinfo = await cls.api_request(
+                session, "account", "account/api/v1/user/userinfo", {}, auth_headers
+            )
+            if "botRegion" in userinfo and userinfo["botRegion"] != "":
+                region = userinfo["botRegion"]
+            else:
+                region = "us"
+        except Exception as err:
+            raise SwitchbotAccountConnectionError(
+                f"Failed to retrieve SwitchBot Account user details: {err}"
+            ) from err
+
+        try:
+            device_info = await cls.api_request(
+                session,
+                f"wonderlabs.{region}",
+                "wonder/keys/v1/communicate",
+                {
+                    "device_mac": device_mac,
+                    "keyType": "user",
+                },
+                auth_headers,
+            )
+
+            return {
+                "key_id": device_info["communicationKey"]["keyId"],
+                "encryption_key": device_info["communicationKey"]["key"],
+            }
+        except Exception as err:
+            raise SwitchbotAccountConnectionError(
+                f"Failed to retrieve encryption key from SwitchBot Account: {err}"
+            ) from err
+
+    @classmethod
+    async def verify_encryption_key(
+        cls,
+        device: BLEDevice,
+        key_id: str,
+        encryption_key: str,
+        model: SwitchbotModel,
+        **kwargs: Any,
+    ) -> bool:
+        try:
+            device = cls(
+                device, key_id=key_id, encryption_key=encryption_key, model=model
+            )
+        except ValueError:
+            return False
+        try:
+            info = await device.get_basic_info()
+        except SwitchbotOperationError:
+            return False
+
+        return info is not None
+
+
 class SwitchbotDeviceOverrideStateDuringConnection(SwitchbotBaseDevice):
     """Base Representation of a Switchbot Device.
 

+ 9 - 136
switchbot/devices/lock.py

@@ -2,24 +2,15 @@
 
 from __future__ import annotations
 
-import asyncio
 import logging
 import time
 from typing import Any
 
-import aiohttp
 from bleak.backends.device import BLEDevice
 from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 
-from ..api_config import SWITCHBOT_APP_API_BASE_URL, SWITCHBOT_APP_CLIENT_ID
-from ..const import (
-    LockStatus,
-    SwitchbotAccountConnectionError,
-    SwitchbotApiError,
-    SwitchbotAuthenticationError,
-    SwitchbotModel,
-)
-from .device import SwitchbotDevice, SwitchbotOperationError
+from ..const import LockStatus, SwitchbotModel
+from .device import SwitchbotEncryptedDevice
 
 COMMAND_HEADER = "57"
 COMMAND_GET_CK_IV = f"{COMMAND_HEADER}0f2103"
@@ -54,7 +45,7 @@ COMMAND_RESULT_EXPECTED_VALUES = {1, 6}
 # The return value of the command is 6 when the command is successful but the battery is low.
 
 
-class SwitchbotLock(SwitchbotDevice):
+class SwitchbotLock(SwitchbotEncryptedDevice):
     """Representation of a Switchbot Lock."""
 
     def __init__(
@@ -66,141 +57,23 @@ class SwitchbotLock(SwitchbotDevice):
         model: SwitchbotModel = SwitchbotModel.LOCK,
         **kwargs: Any,
     ) -> None:
-        if len(key_id) == 0:
-            raise ValueError("key_id is missing")
-        elif len(key_id) != 2:
-            raise ValueError("key_id is invalid")
-        if len(encryption_key) == 0:
-            raise ValueError("encryption_key is missing")
-        elif len(encryption_key) != 32:
-            raise ValueError("encryption_key is invalid")
         if model not in (SwitchbotModel.LOCK, SwitchbotModel.LOCK_PRO):
             raise ValueError("initializing SwitchbotLock with a non-lock model")
-        self._iv = None
-        self._cipher = None
-        self._key_id = key_id
-        self._encryption_key = bytearray.fromhex(encryption_key)
         self._notifications_enabled: bool = False
-        self._model: SwitchbotModel = model
-        super().__init__(device, None, interface, **kwargs)
+        super().__init__(device, key_id, encryption_key, model, interface, **kwargs)
 
-    @staticmethod
+    @classmethod
     async def verify_encryption_key(
+        cls,
         device: BLEDevice,
         key_id: str,
         encryption_key: str,
         model: SwitchbotModel = SwitchbotModel.LOCK,
         **kwargs: Any,
     ) -> bool:
-        try:
-            lock = SwitchbotLock(
-                device, key_id=key_id, encryption_key=encryption_key, model=model
-            )
-        except ValueError:
-            return False
-        try:
-            lock_info = await lock.get_basic_info()
-        except SwitchbotOperationError:
-            return False
-
-        return lock_info is not None
-
-    @staticmethod
-    async def api_request(
-        session: aiohttp.ClientSession,
-        subdomain: str,
-        path: str,
-        data: dict = None,
-        headers: dict = None,
-    ) -> dict:
-        url = f"https://{subdomain}.{SWITCHBOT_APP_API_BASE_URL}/{path}"
-        async with session.post(
-            url,
-            json=data,
-            headers=headers,
-            timeout=aiohttp.ClientTimeout(total=10),
-        ) as result:
-            if result.status > 299:
-                raise SwitchbotApiError(
-                    f"Unexpected status code returned by SwitchBot API: {result.status}"
-                )
-
-            response = await result.json()
-            if response["statusCode"] != 100:
-                raise SwitchbotApiError(
-                    f"{response['message']}, status code: {response['statusCode']}"
-                )
-
-            return response["body"]
-
-    # Old non-async method preserved for backwards compatibility
-    @staticmethod
-    def retrieve_encryption_key(device_mac: str, username: str, password: str):
-        async def async_fn():
-            async with aiohttp.ClientSession() as session:
-                return await SwitchbotLock.async_retrieve_encryption_key(
-                    session, device_mac, username, password
-                )
-
-        return asyncio.run(async_fn())
-
-    @staticmethod
-    async def async_retrieve_encryption_key(
-        session: aiohttp.ClientSession, device_mac: str, username: str, password: str
-    ) -> dict:
-        """Retrieve lock key from internal SwitchBot API."""
-        device_mac = device_mac.replace(":", "").replace("-", "").upper()
-
-        try:
-            auth_result = await SwitchbotLock.api_request(
-                session,
-                "account",
-                "account/api/v1/user/login",
-                {
-                    "clientId": SWITCHBOT_APP_CLIENT_ID,
-                    "username": username,
-                    "password": password,
-                    "grantType": "password",
-                    "verifyCode": "",
-                },
-            )
-            auth_headers = {"authorization": auth_result["access_token"]}
-        except Exception as err:
-            raise SwitchbotAuthenticationError(f"Authentication failed: {err}") from err
-
-        try:
-            userinfo = await SwitchbotLock.api_request(
-                session, "account", "account/api/v1/user/userinfo", {}, auth_headers
-            )
-            if "botRegion" in userinfo and userinfo["botRegion"] != "":
-                region = userinfo["botRegion"]
-            else:
-                region = "us"
-        except Exception as err:
-            raise SwitchbotAccountConnectionError(
-                f"Failed to retrieve SwitchBot Account user details: {err}"
-            ) from err
-
-        try:
-            device_info = await SwitchbotLock.api_request(
-                session,
-                f"wonderlabs.{region}",
-                "wonder/keys/v1/communicate",
-                {
-                    "device_mac": device_mac,
-                    "keyType": "user",
-                },
-                auth_headers,
-            )
-
-            return {
-                "key_id": device_info["communicationKey"]["keyId"],
-                "encryption_key": device_info["communicationKey"]["key"],
-            }
-        except Exception as err:
-            raise SwitchbotAccountConnectionError(
-                f"Failed to retrieve encryption key from SwitchBot Account: {err}"
-            ) from err
+        return super().verify_encryption_key(
+            device, key_id, encryption_key, model, **kwargs
+        )
 
     async def lock(self) -> bool:
         """Send lock command."""

+ 16 - 16
switchbot/devices/relay_switch.py

@@ -7,7 +7,7 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 
 from ..const import SwitchbotModel
 from ..models import SwitchBotAdvertisement
-from .device import SwitchbotDevice
+from .device import SwitchbotEncryptedDevice
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -20,7 +20,7 @@ COMMAND_GET_VOLTAGE_AND_CURRENT = f"{COMMAND_HEADER}0f7106000000"
 PASSIVE_POLL_INTERVAL = 10 * 60
 
 
-class SwitchbotRelaySwitch(SwitchbotDevice):
+class SwitchbotRelaySwitch(SwitchbotEncryptedDevice):
     """Representation of a Switchbot relay switch 1pm."""
 
     def __init__(
@@ -32,21 +32,21 @@ class SwitchbotRelaySwitch(SwitchbotDevice):
         model: SwitchbotModel = SwitchbotModel.RELAY_SWITCH_1PM,
         **kwargs: Any,
     ) -> None:
-        if len(key_id) == 0:
-            raise ValueError("key_id is missing")
-        elif len(key_id) != 2:
-            raise ValueError("key_id is invalid")
-        if len(encryption_key) == 0:
-            raise ValueError("encryption_key is missing")
-        elif len(encryption_key) != 32:
-            raise ValueError("encryption_key is invalid")
-        self._iv = None
-        self._cipher = None
-        self._key_id = key_id
-        self._encryption_key = bytearray.fromhex(encryption_key)
-        self._model: SwitchbotModel = model
         self._force_next_update = False
-        super().__init__(device, None, interface, **kwargs)
+        super().__init__(device, key_id, encryption_key, model, interface, **kwargs)
+
+    @classmethod
+    async def verify_encryption_key(
+        cls,
+        device: BLEDevice,
+        key_id: str,
+        encryption_key: str,
+        model: SwitchbotModel = SwitchbotModel.RELAY_SWITCH_1PM,
+        **kwargs: Any,
+    ) -> bool:
+        return super().verify_encryption_key(
+            device, key_id, encryption_key, model, **kwargs
+        )
 
     def update_from_advertisement(self, advertisement: SwitchBotAdvertisement) -> None:
         """Update device data from advertisement."""