瀏覽代碼

Extract aiohttp.ClientSession as method parameter

Damian Sypniewski 10 月之前
父節點
當前提交
ecf611809c
共有 1 個文件被更改,包括 29 次插入20 次删除
  1. 29 20
      switchbot/devices/lock.py

+ 29 - 20
switchbot/devices/lock.py

@@ -84,34 +84,42 @@ class SwitchbotLock(SwitchbotDevice):
         return lock_info is not None
 
     @staticmethod
-    async def api_request(subdomain: str, path: str, data: dict = None, headers: dict = None):
-        async with aiohttp.ClientSession(headers=headers, timeout=aiohttp.ClientTimeout(total=10)) as session:
-            url = f"https://{subdomain}.{SWITCHBOT_APP_API_BASE_URL}/{path}"
-            async with session.post(url, json=data) as result:
-                if result.status > 299:
-                    raise SwitchbotApiError(
-                        f"Unexpected status code returned by SwitchBot API: {result.status}"
-                    )
-
-                response = await result.json()
-                if response["statusCode"] != 100:
-                    raise SwitchbotApiError(
-                        f"{response['message']}, status code: {response['statusCode']}"
-                    )
-
-                return response["body"]
-
+    async def api_request(
+            session: aiohttp.ClientSession, subdomain: str, path: str, data: dict = None, headers: dict = None
+    ) -> dict:
+        url = f"https://{subdomain}.{SWITCHBOT_APP_API_BASE_URL}/{path}"
+        async with session.post(url, json=data, headers=headers) as result:
+            if result.status > 299:
+                raise SwitchbotApiError(
+                    f"Unexpected status code returned by SwitchBot API: {result.status}"
+                )
+
+            response = await result.json()
+            if response["statusCode"] != 100:
+                raise SwitchbotApiError(
+                    f"{response['message']}, status code: {response['statusCode']}"
+                )
+
+            return response["body"]
+
+    # Old non-async method preserved for backwards compatibility
     @staticmethod
     def retrieve_encryption_key(device_mac: str, username: str, password: str):
-        return asyncio.run(SwitchbotLock.async_retrieve_encryption_key(device_mac, username, password))
+        async def async_fn():
+            async with aiohttp.ClientSession() as session:
+                return await SwitchbotLock.async_retrieve_encryption_key(session, device_mac, username, password)
+        return asyncio.run(async_fn())
 
     @staticmethod
-    async def async_retrieve_encryption_key(device_mac: str, username: str, password: str):
+    async def async_retrieve_encryption_key(
+            session: aiohttp.ClientSession, device_mac: str, username: str, password: str
+    ) -> dict:
         """Retrieve lock key from internal SwitchBot API."""
         device_mac = device_mac.replace(":", "").replace("-", "").upper()
 
         try:
             auth_result = await SwitchbotLock.api_request(
+                session,
                 "account",
                 "account/api/v1/user/login",
                 {
@@ -128,7 +136,7 @@ class SwitchbotLock(SwitchbotDevice):
 
         try:
             userinfo = await SwitchbotLock.api_request(
-                "account", "account/api/v1/user/userinfo", {}, auth_headers
+                session, "account", "account/api/v1/user/userinfo", {}, auth_headers
             )
             if "botRegion" in userinfo and userinfo["botRegion"] != "":
                 region = userinfo["botRegion"]
@@ -141,6 +149,7 @@ class SwitchbotLock(SwitchbotDevice):
 
         try:
             device_info = await SwitchbotLock.api_request(
+                session,
                 f"wonderlabs.{region}",
                 "wonder/keys/v1/communicate",
                 {