Browse Source

Only setup the notify handler once (#116)

J. Nick Koston 1 year ago
parent
commit
78d2b116a2
1 changed files with 17 additions and 15 deletions
  1. 17 15
      switchbot/devices/device.py

+ 17 - 15
switchbot/devices/device.py

@@ -113,6 +113,7 @@ class SwitchbotBaseDevice:
         self._expected_disconnect = False
         self.loop = asyncio.get_event_loop()
         self._callbacks: list[Callable[[], None]] = []
+        self._notify_future: asyncio.Future[bytearray] | None = None
 
     def advertisement_changed(self, advertisement: SwitchBotAdvertisement) -> bool:
         """Check if the advertisement has changed."""
@@ -238,6 +239,7 @@ class SwitchbotBaseDevice:
                 resolved = self._resolve_characteristics(await client.get_services())
             self._client = client
             self._reset_disconnect_timer()
+            await self._start_notify()
 
     def _resolve_characteristics(self, services: BleakGATTServiceCollection) -> bool:
         """Resolve characteristics."""
@@ -335,6 +337,18 @@ class SwitchbotBaseDevice:
             await self._execute_forced_disconnect()
             raise
 
+    def _notification_handler(self, _sender: int, data: bytearray) -> None:
+        """Handle notification responses."""
+        if self._notify_future and not self._notify_future.done():
+            self._notify_future.set_result(data)
+            return
+        _LOGGER.debug("%s: Received unsolicited notification: %s", self.name, data)
+
+    async def _start_notify(self) -> None:
+        """Start notification."""
+        _LOGGER.debug("%s: Subscribe to notifications; RSSI: %s", self.name, self.rssi)
+        await self._client.start_notify(self._read_char, self._notification_handler)
+
     async def _execute_command_locked(self, key: str, command: bytes) -> bytes:
         """Execute command and read response."""
         assert self._client is not None
@@ -342,28 +356,16 @@ class SwitchbotBaseDevice:
             raise CharacteristicMissingError(READ_CHAR_UUID)
         if not self._write_char:
             raise CharacteristicMissingError(WRITE_CHAR_UUID)
-        future: asyncio.Future[bytearray] = asyncio.Future()
+        self._notify_future = asyncio.Future()
         client = self._client
 
-        def _notification_handler(_sender: int, data: bytearray) -> None:
-            """Handle notification responses."""
-            if future.done():
-                _LOGGER.debug("%s: Notification handler already done", self.name)
-                return
-            future.set_result(data)
-
-        _LOGGER.debug("%s: Subscribe to notifications; RSSI: %s", self.name, self.rssi)
-        await client.start_notify(self._read_char, _notification_handler)
-
         _LOGGER.debug("%s: Sending command: %s", self.name, key)
         await client.write_gatt_char(self._write_char, command, False)
 
         async with async_timeout.timeout(5):
-            notify_msg = await future
+            notify_msg = await self._notify_future
         _LOGGER.debug("%s: Notification received: %s", self.name, notify_msg)
-
-        _LOGGER.debug("%s: UnSubscribe to notifications", self.name)
-        await client.stop_notify(self._read_char)
+        self._notify_future = None
 
         if notify_msg == b"\x07":
             _LOGGER.error("Password required")