|
@@ -4,13 +4,15 @@ from __future__ import annotations
|
|
|
import asyncio
|
|
|
import binascii
|
|
|
import logging
|
|
|
-from typing import Any
|
|
|
+from ctypes import cast
|
|
|
+from typing import Any, Callable, TypeVar
|
|
|
from uuid import UUID
|
|
|
|
|
|
import async_timeout
|
|
|
import bleak
|
|
|
+from bleak import BleakError
|
|
|
from bleak.backends.device import BLEDevice
|
|
|
-from bleak.backends.service import BleakGATTServiceCollection
|
|
|
+from bleak.backends.service import BleakGATTCharacteristic, BleakGATTServiceCollection
|
|
|
from bleak_retry_connector import (
|
|
|
BleakClientWithServiceCache,
|
|
|
ble_device_has_changed,
|
|
@@ -31,6 +33,8 @@ DEVICE_SET_EXTENDED_KEY = "570f"
|
|
|
# Base key when encryption is set
|
|
|
KEY_PASSWORD_PREFIX = "571"
|
|
|
|
|
|
+BLEAK_EXCEPTIONS = (AttributeError, BleakError, asyncio.exceptions.TimeoutError)
|
|
|
+
|
|
|
|
|
|
def _sb_uuid(comms_type: str = "service") -> UUID | str:
|
|
|
"""Return Switchbot UUID."""
|
|
@@ -60,13 +64,17 @@ class SwitchbotDevice:
|
|
|
self._scan_timeout: int = kwargs.pop("scan_timeout", DEFAULT_SCAN_TIMEOUT)
|
|
|
self._retry_count: int = kwargs.pop("retry_count", DEFAULT_RETRY_COUNT)
|
|
|
self._connect_lock = asyncio.Lock()
|
|
|
+ self._operation_lock = asyncio.Lock()
|
|
|
if password is None or password == "":
|
|
|
self._password_encoded = None
|
|
|
else:
|
|
|
self._password_encoded = "%08x" % (
|
|
|
binascii.crc32(password.encode("ascii")) & 0xFFFFFFFF
|
|
|
)
|
|
|
+ self._client: BleakClientWithServiceCache | None = None
|
|
|
self._cached_services: BleakGATTServiceCollection | None = None
|
|
|
+ self._read_char: BleakGATTCharacteristic | None = None
|
|
|
+ self._write_char: BleakGATTCharacteristic | None = None
|
|
|
|
|
|
def _commandkey(self, key: str) -> str:
|
|
|
"""Add password to key if set."""
|
|
@@ -80,12 +88,15 @@ class SwitchbotDevice:
|
|
|
"""Send command to device and read response."""
|
|
|
command = bytearray.fromhex(self._commandkey(key))
|
|
|
_LOGGER.debug("Sending command to switchbot %s", command)
|
|
|
+
|
|
|
max_attempts = retry + 1
|
|
|
- async with self._connect_lock:
|
|
|
+ async with self._operation_lock:
|
|
|
for attempt in range(max_attempts):
|
|
|
try:
|
|
|
return await self._send_command_locked(key, command)
|
|
|
- except (bleak.BleakError, asyncio.exceptions.TimeoutError):
|
|
|
+ except BLEAK_EXCEPTIONS(
|
|
|
+ bleak.BleakError, asyncio.exceptions.TimeoutError
|
|
|
+ ):
|
|
|
if attempt == retry:
|
|
|
_LOGGER.error(
|
|
|
"Switchbot communication failed. Stopping trying",
|
|
@@ -102,49 +113,52 @@ class SwitchbotDevice:
|
|
|
"""Return device name."""
|
|
|
return f"{self._device.name} ({self._device.address})"
|
|
|
|
|
|
- async def _send_command_locked(self, key: str, command: bytes) -> bytes:
|
|
|
- """Send command to device and read response."""
|
|
|
- client: BleakClientWithServiceCache | None = None
|
|
|
- try:
|
|
|
- _LOGGER.debug("%s: Connnecting to switchbot", self.name)
|
|
|
+ async def _ensure_connected(self):
|
|
|
+ if self._client and self._client.is_connected:
|
|
|
+ return
|
|
|
+ async with self._connect_lock:
|
|
|
+ # Check again while holding the lock
|
|
|
+ if self._client and self._client.is_connected:
|
|
|
+ return
|
|
|
client = await establish_connection(
|
|
|
BleakClientWithServiceCache,
|
|
|
self._device,
|
|
|
self.name,
|
|
|
- max_attempts=1,
|
|
|
cached_services=self._cached_services,
|
|
|
)
|
|
|
self._cached_services = client.services
|
|
|
- _LOGGER.debug(
|
|
|
- "%s: Connnected to switchbot: %s", self.name, client.is_connected
|
|
|
- )
|
|
|
- read_char = client.services.get_characteristic(_sb_uuid(comms_type="rx"))
|
|
|
- write_char = client.services.get_characteristic(_sb_uuid(comms_type="tx"))
|
|
|
- future: asyncio.Future[bytearray] = asyncio.Future()
|
|
|
+ _LOGGER.debug("%s: Connected to SwitchBot Device", self.name)
|
|
|
+ services = client.services
|
|
|
+ self._read_char = services.get_characteristic(_sb_uuid(comms_type="rx"))
|
|
|
+ self._write_char = services.get_characteristic(_sb_uuid(comms_type="tx"))
|
|
|
+ self._client = client
|
|
|
|
|
|
- def _notification_handler(_sender: int, data: bytearray) -> None:
|
|
|
- """Handle notification responses."""
|
|
|
- if future.done():
|
|
|
- _LOGGER.debug("%s: Notification handler already done", self.name)
|
|
|
- return
|
|
|
- future.set_result(data)
|
|
|
+ async def _send_command_locked(self, key: str, command: bytes) -> bytes:
|
|
|
+ """Send command to device and read response."""
|
|
|
+ client: BleakClientWithServiceCache | None = None
|
|
|
+ await self._ensure_connected()
|
|
|
+ client = self._client
|
|
|
+ future: asyncio.Future[bytearray] = asyncio.Future()
|
|
|
|
|
|
- _LOGGER.debug("%s: Subscribe to notifications", self.name)
|
|
|
- await client.start_notify(read_char, _notification_handler)
|
|
|
+ def _notification_handler(_sender: int, data: bytearray) -> None:
|
|
|
+ """Handle notification responses."""
|
|
|
+ if future.done():
|
|
|
+ _LOGGER.debug("%s: Notification handler already done", self.name)
|
|
|
+ return
|
|
|
+ future.set_result(data)
|
|
|
|
|
|
- _LOGGER.debug("%s: Sending command, %s", self.name, key)
|
|
|
- await client.write_gatt_char(write_char, command, False)
|
|
|
+ _LOGGER.debug("%s: Subscribe to notifications", self.name)
|
|
|
+ await client.start_notify(self._read_char, _notification_handler)
|
|
|
|
|
|
- async with async_timeout.timeout(5):
|
|
|
- notify_msg = await future
|
|
|
- _LOGGER.info("%s: Notification received: %s", self.name, notify_msg)
|
|
|
+ _LOGGER.debug("%s: Sending command, %s", self.name, key)
|
|
|
+ await client.write_gatt_char(self._write_char, command, False)
|
|
|
|
|
|
- _LOGGER.debug("%s: UnSubscribe to notifications", self.name)
|
|
|
- await client.stop_notify(read_char)
|
|
|
+ async with async_timeout.timeout(5):
|
|
|
+ notify_msg = await future
|
|
|
+ _LOGGER.info("%s: Notification received: %s", self.name, notify_msg)
|
|
|
|
|
|
- finally:
|
|
|
- if client:
|
|
|
- await client.disconnect()
|
|
|
+ _LOGGER.debug("%s: UnSubscribe to notifications", self.name)
|
|
|
+ await client.stop_notify(self._read_char)
|
|
|
|
|
|
if notify_msg == b"\x07":
|
|
|
_LOGGER.error("Password required")
|