Browse Source

Hold a strong ref to the timed disconnect task to avoid GC (#189)

J. Nick Koston 1 year ago
parent
commit
979cdc9bc8
4 changed files with 21 additions and 4 deletions
  1. 1 0
      switchbot/adv_parsers/humidifier.py
  2. 0 1
      switchbot/const.py
  3. 3 3
      switchbot/devices/device.py
  4. 17 0
      switchbot/util.py

+ 1 - 0
switchbot/adv_parsers/humidifier.py

@@ -10,6 +10,7 @@ _LOGGER = logging.getLogger(__name__)
 # data: 650000cd802b6300
 # data: 658000c9802b6300
 
+
 # Low:  658000c5222b6300
 # Med:  658000c5432b6300
 # High: 658000c5642b6300

+ 0 - 1
switchbot/const.py

@@ -27,7 +27,6 @@ class SwitchbotAccountConnectionError(RuntimeError):
 
 
 class SwitchbotModel(StrEnum):
-
     BOT = "WoHand"
     CURTAIN = "WoCurtain"
     HUMIDIFIER = "WoHumi"

+ 3 - 3
switchbot/devices/device.py

@@ -26,6 +26,7 @@ from bleak_retry_connector import (
 from ..const import DEFAULT_RETRY_COUNT, DEFAULT_SCAN_TIMEOUT
 from ..discovery import GetSwitchbotDevices
 from ..models import SwitchBotAdvertisement
+from ..util import execute_task
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -48,7 +49,6 @@ DISCONNECT_DELAY = 8.5
 
 
 class ColorMode(Enum):
-
     OFF = 0
     COLOR_TEMP = 1
     RGB = 2
@@ -331,7 +331,7 @@ class SwitchbotBaseDevice:
             self._reset_disconnect_timer()
             return
         self._cancel_disconnect_timer()
-        asyncio.create_task(self._execute_timed_disconnect())
+        execute_task(self._execute_timed_disconnect())
 
     def _cancel_disconnect_timer(self):
         """Cancel disconnect timer."""
@@ -375,7 +375,7 @@ class SwitchbotBaseDevice:
         self._client = None
         self._read_char = None
         self._write_char = None
-        if client and client.is_connected:
+        if client:
             _LOGGER.debug("%s: Disconnecting", self.name)
             await client.disconnect()
             _LOGGER.debug("%s: Disconnect completed", self.name)

+ 17 - 0
switchbot/util.py

@@ -0,0 +1,17 @@
+"""Library to handle connection with Switchbot."""
+
+import asyncio
+from collections.abc import Awaitable
+from typing import Any
+
+
+def execute_task(fut: Awaitable[Any]) -> None:
+    """Execute task."""
+    task = asyncio.create_task(fut)
+    tasks = [task]
+
+    def _cleanup_task(task: asyncio.Task[Any]) -> None:
+        """Cleanup task."""
+        tasks.remove(task)
+
+    task.add_done_callback(_cleanup_task)