Browse Source

Implement a retry on disconnect during transaction (#75)

J. Nick Koston 1 year ago
parent
commit
b8e394191d

+ 1 - 0
switchbot/adv_parsers/bulb.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 def process_color_bulb(data: bytes, mfr_data: bytes | None) -> dict[str, bool | int]:
     """Process WoBulb services data."""
+    assert mfr_data is not None
     return {
         "sequence_number": mfr_data[6],
         "isOn": bool(mfr_data[7] & 0b10000000),

+ 3 - 1
switchbot/adv_parsers/meter.py

@@ -1,8 +1,10 @@
 """Meter parser."""
 from __future__ import annotations
 
+from typing import Any
 
-def process_wosensorth(data: bytes, mfr_data: bytes | None) -> dict[str, object]:
+
+def process_wosensorth(data: bytes, mfr_data: bytes | None) -> dict[str, Any]:
     """Process woSensorTH/Temp sensor services data."""
 
     _temp_sign = 1 if data[4] & 0b10000000 else -1

+ 1 - 0
switchbot/adv_parsers/plug.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 def process_woplugmini(data: bytes, mfr_data: bytes | None) -> dict[str, bool | int]:
     """Process plug mini."""
+    assert mfr_data is not None
     return {
         "switchMode": True,
         "isOn": mfr_data[7] == 0x80,

+ 9 - 5
switchbot/devices/bot.py

@@ -38,7 +38,9 @@ class Switchbot(SwitchbotDevice):
             return True
 
         if result[0] == 5:
-            _LOGGER.debug("Bot is in press mode and doesn't have on state")
+            _LOGGER.debug(
+                "%s: Bot is in press mode and doesn't have on state", self.name
+            )
             return True
 
         return False
@@ -50,7 +52,9 @@ class Switchbot(SwitchbotDevice):
             return True
 
         if result[0] == 5:
-            _LOGGER.debug("Bot is in press mode and doesn't have off state")
+            _LOGGER.debug(
+                "%s: Bot is in press mode and doesn't have off state", self.name
+            )
             return True
 
         return False
@@ -62,7 +66,7 @@ class Switchbot(SwitchbotDevice):
             return True
 
         if result[0] == 5:
-            _LOGGER.debug("Bot is in press mode")
+            _LOGGER.debug("%s: Bot is in press mode", self.name)
             return True
 
         return False
@@ -74,7 +78,7 @@ class Switchbot(SwitchbotDevice):
             return True
 
         if result[0] == 5:
-            _LOGGER.debug("Bot is in press mode")
+            _LOGGER.debug("%s: Bot is in press mode", self.name)
             return True
 
         return False
@@ -86,7 +90,7 @@ class Switchbot(SwitchbotDevice):
             return True
 
         if result[0] == 5:
-            _LOGGER.debug("Bot is in switch mode")
+            _LOGGER.debug("%s: Bot is in switch mode", self.name)
             return True
 
         return False

+ 2 - 2
switchbot/devices/curtain.py

@@ -112,7 +112,7 @@ class SwitchbotCurtain(SwitchbotDevice):
         )
 
         if _data in (b"\x07", b"\x00"):
-            _LOGGER.error("Unsuccessfull, please try again")
+            _LOGGER.error("%s: Unsuccessful, please try again", self.name)
             return None
 
         self.ext_info_sum["device0"] = {
@@ -145,7 +145,7 @@ class SwitchbotCurtain(SwitchbotDevice):
         )
 
         if _data in (b"\x07", b"\x00"):
-            _LOGGER.error("Unsuccessfull, please try again")
+            _LOGGER.error("%s: Unsuccessful, please try again", self.name)
             return None
 
         _state_of_charge = [

+ 105 - 39
switchbot/devices/device.py

@@ -4,13 +4,15 @@ from __future__ import annotations
 import asyncio
 import binascii
 import logging
-from typing import Any
+from ctypes import cast
+from typing import Any, Callable, TypeVar
 from uuid import UUID
 
 import async_timeout
 import bleak
+from bleak import BleakError
 from bleak.backends.device import BLEDevice
-from bleak.backends.service import BleakGATTServiceCollection
+from bleak.backends.service import BleakGATTCharacteristic, BleakGATTServiceCollection
 from bleak_retry_connector import (
     BleakClientWithServiceCache,
     ble_device_has_changed,
@@ -31,6 +33,13 @@ DEVICE_SET_EXTENDED_KEY = "570f"
 # Base key when encryption is set
 KEY_PASSWORD_PREFIX = "571"
 
+BLEAK_EXCEPTIONS = (AttributeError, BleakError, asyncio.exceptions.TimeoutError)
+
+# How long to hold the connection
+# to wait for additional commands for
+# disconnecting the device.
+DISCONNECT_DELAY = 59
+
 
 def _sb_uuid(comms_type: str = "service") -> UUID | str:
     """Return Switchbot UUID."""
@@ -60,13 +69,19 @@ class SwitchbotDevice:
         self._scan_timeout: int = kwargs.pop("scan_timeout", DEFAULT_SCAN_TIMEOUT)
         self._retry_count: int = kwargs.pop("retry_count", DEFAULT_RETRY_COUNT)
         self._connect_lock = asyncio.Lock()
+        self._operation_lock = asyncio.Lock()
         if password is None or password == "":
             self._password_encoded = None
         else:
             self._password_encoded = "%08x" % (
                 binascii.crc32(password.encode("ascii")) & 0xFFFFFFFF
             )
+        self._client: BleakClientWithServiceCache | None = None
         self._cached_services: BleakGATTServiceCollection | None = None
+        self._read_char: BleakGATTCharacteristic | None = None
+        self._write_char: BleakGATTCharacteristic | None = None
+        self._disconnect_timer: asyncio.TimerHandle | None = None
+        self.loop = asyncio.get_event_loop()
 
     def _commandkey(self, key: str) -> str:
         """Add password to key if set."""
@@ -79,21 +94,30 @@ class SwitchbotDevice:
     async def _sendcommand(self, key: str, retry: int) -> bytes:
         """Send command to device and read response."""
         command = bytearray.fromhex(self._commandkey(key))
-        _LOGGER.debug("Sending command to switchbot %s", command)
+        _LOGGER.debug("%s: Sending command %s", self.name, command)
+        if self._operation_lock.locked():
+            _LOGGER.debug(
+                "%s: Operation already in progress, waiting for it to complete.",
+                self.name,
+            )
+
         max_attempts = retry + 1
-        async with self._connect_lock:
+        async with self._operation_lock:
             for attempt in range(max_attempts):
                 try:
                     return await self._send_command_locked(key, command)
-                except (bleak.BleakError, asyncio.exceptions.TimeoutError):
+                except BLEAK_EXCEPTIONS:
                     if attempt == retry:
                         _LOGGER.error(
-                            "Switchbot communication failed. Stopping trying",
+                            "%s: communication failed. Stopping trying",
+                            self.name,
                             exc_info=True,
                         )
                         return b"\x00"
 
-                    _LOGGER.debug("Switchbot communication failed with:", exc_info=True)
+                    _LOGGER.debug(
+                        "%s: communication failed with:", self.name, exc_info=True
+                    )
 
         raise RuntimeError("Unreachable")
 
@@ -102,49 +126,91 @@ class SwitchbotDevice:
         """Return device name."""
         return f"{self._device.name} ({self._device.address})"
 
-    async def _send_command_locked(self, key: str, command: bytes) -> bytes:
-        """Send command to device and read response."""
-        client: BleakClientWithServiceCache | None = None
-        try:
-            _LOGGER.debug("%s: Connnecting to switchbot", self.name)
+    async def _ensure_connected(self):
+        """Ensure connection to device is established."""
+        if self._connect_lock.locked():
+            _LOGGER.debug(
+                "%s: Connection already in progress, waiting for it to complete.",
+                self.name,
+            )
+        if self._client and self._client.is_connected:
+            self._reset_disconnect_timer()
+            return
+        async with self._connect_lock:
+            # Check again while holding the lock
+            if self._client and self._client.is_connected:
+                self._reset_disconnect_timer()
+                return
             client = await establish_connection(
                 BleakClientWithServiceCache,
                 self._device,
                 self.name,
-                max_attempts=1,
                 cached_services=self._cached_services,
             )
             self._cached_services = client.services
-            _LOGGER.debug(
-                "%s: Connnected to switchbot: %s", self.name, client.is_connected
-            )
-            read_char = client.services.get_characteristic(_sb_uuid(comms_type="rx"))
-            write_char = client.services.get_characteristic(_sb_uuid(comms_type="tx"))
-            future: asyncio.Future[bytearray] = asyncio.Future()
+            _LOGGER.debug("%s: Connected", self.name)
+            services = client.services
+            self._read_char = services.get_characteristic(_sb_uuid(comms_type="rx"))
+            self._write_char = services.get_characteristic(_sb_uuid(comms_type="tx"))
+            self._client = client
+            self._reset_disconnect_timer()
+
+    def _reset_disconnect_timer(self):
+        """Reset disconnect timer."""
+        if self._disconnect_timer:
+            self._disconnect_timer.cancel()
+        self._disconnect_timer = self.loop.call_later(
+            DISCONNECT_DELAY, self._disconnect
+        )
 
-            def _notification_handler(_sender: int, data: bytearray) -> None:
-                """Handle notification responses."""
-                if future.done():
-                    _LOGGER.debug("%s: Notification handler already done", self.name)
-                    return
-                future.set_result(data)
+    def _disconnect(self):
+        """Disconnect from device."""
+        self._disconnect_timer = None
+        asyncio.create_task(self._execute_disconnect())
+
+    async def _execute_disconnect(self):
+        """Execute disconnection."""
+        _LOGGER.debug(
+            "%s: Disconnecting after timeout of %s",
+            self.name,
+            DISCONNECT_DELAY,
+        )
+        async with self._connect_lock:
+            if not self._client or not self._client.is_connected:
+                return
+            await self._client.disconnect()
+            self._client = None
+            self._read_char = None
+            self._write_char = None
+
+    async def _send_command_locked(self, key: str, command: bytes) -> bytes:
+        """Send command to device and read response."""
+        await self._ensure_connected()
+        assert self._client is not None
+        assert self._read_char is not None
+        assert self._write_char is not None
+        future: asyncio.Future[bytearray] = asyncio.Future()
+        client = self._client
 
-            _LOGGER.debug("%s: Subscribe to notifications", self.name)
-            await client.start_notify(read_char, _notification_handler)
+        def _notification_handler(_sender: int, data: bytearray) -> None:
+            """Handle notification responses."""
+            if future.done():
+                _LOGGER.debug("%s: Notification handler already done", self.name)
+                return
+            future.set_result(data)
 
-            _LOGGER.debug("%s: Sending command, %s", self.name, key)
-            await client.write_gatt_char(write_char, command, False)
+        _LOGGER.debug("%s: Subscribe to notifications", self.name)
+        await client.start_notify(self._read_char, _notification_handler)
 
-            async with async_timeout.timeout(5):
-                notify_msg = await future
-            _LOGGER.info("%s: Notification received: %s", self.name, notify_msg)
+        _LOGGER.debug("%s: Sending command: %s", self.name, key)
+        await client.write_gatt_char(self._write_char, command, False)
 
-            _LOGGER.debug("%s: UnSubscribe to notifications", self.name)
-            await client.stop_notify(read_char)
+        async with async_timeout.timeout(5):
+            notify_msg = await future
+        _LOGGER.debug("%s: Notification received: %s", self.name, notify_msg)
 
-        finally:
-            if client:
-                await client.disconnect()
+        _LOGGER.debug("%s: UnSubscribe to notifications", self.name)
+        await client.stop_notify(self._read_char)
 
         if notify_msg == b"\x07":
             _LOGGER.error("Password required")
@@ -175,7 +241,7 @@ class SwitchbotDevice:
 
     async def get_device_data(
         self, retry: int = DEFAULT_RETRY_COUNT, interface: int | None = None
-    ) -> dict | None:
+    ) -> SwitchBotAdvertisement | None:
         """Find switchbot devices and their advertisement data."""
         if interface:
             _interface: int = interface
@@ -191,7 +257,7 @@ class SwitchbotDevice:
 
         return self._sb_adv_data
 
-    async def _get_basic_info(self) -> dict | None:
+    async def _get_basic_info(self) -> bytes | None:
         """Return basic info of device."""
         _data = await self._sendcommand(
             key=DEVICE_GET_BASIC_SETTINGS_KEY, retry=self._retry_count

+ 1 - 9
switchbot/devices/plug.py

@@ -13,11 +13,6 @@ PLUG_OFF_KEY = "570f50010100"
 class SwitchbotPlugMini(SwitchbotDevice):
     """Representation of a Switchbot plug mini."""
 
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        """Switchbot plug mini constructor."""
-        super().__init__(*args, **kwargs)
-        self._settings: dict[str, Any] = {}
-
     async def update(self, interface: int | None = None) -> None:
         """Update state of device."""
         await self.get_device_data(retry=self._retry_count, interface=interface)
@@ -35,7 +30,4 @@ class SwitchbotPlugMini(SwitchbotDevice):
     def is_on(self) -> Any:
         """Return switch state from cache."""
         # To get actual position call update() first.
-        value = self._get_adv_value("isOn")
-        if value is None:
-            return None
-        return value
+        return self._get_adv_value("isOn")

+ 3 - 3
switchbot/discovery.py

@@ -109,8 +109,8 @@ class GetSwitchbotDevices:
             await self.discover()
 
         return {
-            device: data
-            for device, data in self._adv_data.items()
+            device: adv
+            for device, adv in self._adv_data.items()
             # MacOS uses UUIDs instead of MAC addresses
-            if data.get("address") == address
+            if adv.data.get("address") == address
         }