Forráskód Böngészése

Implement a retry on disconnect during transaction

J. Nick Koston 2 éve
szülő
commit
bfcff41bc7
1 módosított fájl, 48 hozzáadás és 34 törlés
  1. 48 34
      switchbot/devices/device.py

+ 48 - 34
switchbot/devices/device.py

@@ -4,13 +4,15 @@ from __future__ import annotations
 import asyncio
 import binascii
 import logging
-from typing import Any
+from ctypes import cast
+from typing import Any, Callable, TypeVar
 from uuid import UUID
 
 import async_timeout
 import bleak
+from bleak import BleakError
 from bleak.backends.device import BLEDevice
-from bleak.backends.service import BleakGATTServiceCollection
+from bleak.backends.service import BleakGATTCharacteristic, BleakGATTServiceCollection
 from bleak_retry_connector import (
     BleakClientWithServiceCache,
     ble_device_has_changed,
@@ -31,6 +33,8 @@ DEVICE_SET_EXTENDED_KEY = "570f"
 # Base key when encryption is set
 KEY_PASSWORD_PREFIX = "571"
 
+BLEAK_EXCEPTIONS = (AttributeError, BleakError, asyncio.exceptions.TimeoutError)
+
 
 def _sb_uuid(comms_type: str = "service") -> UUID | str:
     """Return Switchbot UUID."""
@@ -60,13 +64,17 @@ class SwitchbotDevice:
         self._scan_timeout: int = kwargs.pop("scan_timeout", DEFAULT_SCAN_TIMEOUT)
         self._retry_count: int = kwargs.pop("retry_count", DEFAULT_RETRY_COUNT)
         self._connect_lock = asyncio.Lock()
+        self._operation_lock = asyncio.Lock()
         if password is None or password == "":
             self._password_encoded = None
         else:
             self._password_encoded = "%08x" % (
                 binascii.crc32(password.encode("ascii")) & 0xFFFFFFFF
             )
+        self._client: BleakClientWithServiceCache | None = None
         self._cached_services: BleakGATTServiceCollection | None = None
+        self._read_char: BleakGATTCharacteristic | None = None
+        self._write_char: BleakGATTCharacteristic | None = None
 
     def _commandkey(self, key: str) -> str:
         """Add password to key if set."""
@@ -80,12 +88,15 @@ class SwitchbotDevice:
         """Send command to device and read response."""
         command = bytearray.fromhex(self._commandkey(key))
         _LOGGER.debug("Sending command to switchbot %s", command)
+
         max_attempts = retry + 1
-        async with self._connect_lock:
+        async with self._operation_lock:
             for attempt in range(max_attempts):
                 try:
                     return await self._send_command_locked(key, command)
-                except (bleak.BleakError, asyncio.exceptions.TimeoutError):
+                except BLEAK_EXCEPTIONS(
+                    bleak.BleakError, asyncio.exceptions.TimeoutError
+                ):
                     if attempt == retry:
                         _LOGGER.error(
                             "Switchbot communication failed. Stopping trying",
@@ -102,49 +113,52 @@ class SwitchbotDevice:
         """Return device name."""
         return f"{self._device.name} ({self._device.address})"
 
-    async def _send_command_locked(self, key: str, command: bytes) -> bytes:
-        """Send command to device and read response."""
-        client: BleakClientWithServiceCache | None = None
-        try:
-            _LOGGER.debug("%s: Connnecting to switchbot", self.name)
+    async def _ensure_connected(self):
+        if self._client and self._client.is_connected:
+            return
+        async with self._connect_lock:
+            # Check again while holding the lock
+            if self._client and self._client.is_connected:
+                return
             client = await establish_connection(
                 BleakClientWithServiceCache,
                 self._device,
                 self.name,
-                max_attempts=1,
                 cached_services=self._cached_services,
             )
             self._cached_services = client.services
-            _LOGGER.debug(
-                "%s: Connnected to switchbot: %s", self.name, client.is_connected
-            )
-            read_char = client.services.get_characteristic(_sb_uuid(comms_type="rx"))
-            write_char = client.services.get_characteristic(_sb_uuid(comms_type="tx"))
-            future: asyncio.Future[bytearray] = asyncio.Future()
+            _LOGGER.debug("%s: Connected to SwitchBot Device", self.name)
+            services = client.services
+            self._read_char = services.get_characteristic(_sb_uuid(comms_type="rx"))
+            self._write_char = services.get_characteristic(_sb_uuid(comms_type="tx"))
+            self._client = client
 
-            def _notification_handler(_sender: int, data: bytearray) -> None:
-                """Handle notification responses."""
-                if future.done():
-                    _LOGGER.debug("%s: Notification handler already done", self.name)
-                    return
-                future.set_result(data)
+    async def _send_command_locked(self, key: str, command: bytes) -> bytes:
+        """Send command to device and read response."""
+        client: BleakClientWithServiceCache | None = None
+        await self._ensure_connected()
+        client = self._client
+        future: asyncio.Future[bytearray] = asyncio.Future()
 
-            _LOGGER.debug("%s: Subscribe to notifications", self.name)
-            await client.start_notify(read_char, _notification_handler)
+        def _notification_handler(_sender: int, data: bytearray) -> None:
+            """Handle notification responses."""
+            if future.done():
+                _LOGGER.debug("%s: Notification handler already done", self.name)
+                return
+            future.set_result(data)
 
-            _LOGGER.debug("%s: Sending command, %s", self.name, key)
-            await client.write_gatt_char(write_char, command, False)
+        _LOGGER.debug("%s: Subscribe to notifications", self.name)
+        await client.start_notify(self._read_char, _notification_handler)
 
-            async with async_timeout.timeout(5):
-                notify_msg = await future
-            _LOGGER.info("%s: Notification received: %s", self.name, notify_msg)
+        _LOGGER.debug("%s: Sending command, %s", self.name, key)
+        await client.write_gatt_char(self._write_char, command, False)
 
-            _LOGGER.debug("%s: UnSubscribe to notifications", self.name)
-            await client.stop_notify(read_char)
+        async with async_timeout.timeout(5):
+            notify_msg = await future
+        _LOGGER.info("%s: Notification received: %s", self.name, notify_msg)
 
-        finally:
-            if client:
-                await client.disconnect()
+        _LOGGER.debug("%s: UnSubscribe to notifications", self.name)
+        await client.stop_notify(self._read_char)
 
         if notify_msg == b"\x07":
             _LOGGER.error("Password required")