Browse Source

Refactor to support push updates (#49)

- Fixes the assumption that address were mac addresses
  on MacOS they are UUIDs

- Adds the ability to consume an update from the scanner
  that is always running
J. Nick Koston 1 year ago
parent
commit
6b68f96500
1 changed files with 109 additions and 103 deletions
  1. 109 103
      switchbot/__init__.py

+ 109 - 103
switchbot/__init__.py

@@ -4,10 +4,13 @@ from __future__ import annotations
 import asyncio
 import binascii
 import logging
+from dataclasses import dataclass
 from typing import Any
 from uuid import UUID
 
 import bleak
+from bleak.backends.device import BLEDevice
+from bleak.backends.scanner import AdvertisementData
 
 DEFAULT_RETRY_COUNT = 3
 DEFAULT_RETRY_TIMEOUT = 1
@@ -100,54 +103,71 @@ def _process_wosensorth(data: bytes) -> dict[str, object]:
     return _wosensorth_data
 
 
+@dataclass
+class SwitchBotAdvertisement:
+    """Switchbot advertisement."""
+
+    address: str
+    data: dict[str, Any]
+    device: BLEDevice
+
+
+def parse_advertisement_data(
+    device: BLEDevice, advertisement_data: AdvertisementData
+) -> SwitchBotAdvertisement | None:
+    """Parse advertisement data."""
+    _services = list(advertisement_data.service_data.values())
+    if not _services:
+        return
+    _service_data = _services[0]
+    _model = chr(_service_data[0] & 0b01111111)
+
+    supported_types: dict[str, dict[str, Any]] = {
+        "H": {"modelName": "WoHand", "func": _process_wohand},
+        "c": {"modelName": "WoCurtain", "func": _process_wocurtain},
+        "T": {"modelName": "WoSensorTH", "func": _process_wosensorth},
+    }
+
+    data = {
+        "address": device.address,  # MacOS uses UUIDs
+        "rawAdvData": list(advertisement_data.service_data.values())[0],
+        "data": {
+            "rssi": device.rssi,
+        },
+    }
+
+    if _model in supported_types:
+
+        data.update(
+            {
+                "isEncrypted": bool(_service_data[0] & 0b10000000),
+                "model": _model,
+                "modelName": supported_types[_model]["modelName"],
+                "data": supported_types[_model]["func"](_service_data),
+            }
+        )
+
+        data["data"]["rssi"] = device.rssi
+
+    return SwitchBotAdvertisement(device.address, data, device)
+
+
 class GetSwitchbotDevices:
     """Scan for all Switchbot devices and return by type."""
 
     def __init__(self, interface: int = 0) -> None:
         """Get switchbot devices class constructor."""
         self._interface = f"hci{interface}"
-        self._adv_data: dict[str, Any] = {}
+        self._adv_data: dict[str, SwitchBotAdvertisement] = {}
 
     def detection_callback(
         self,
-        device: bleak.backends.device.BLEDevice,
-        advertisement_data: bleak.backends.scanner.AdvertisementData,
+        device: BLEDevice,
+        advertisement_data: AdvertisementData,
     ) -> None:
-        """BTLE adv scan callback."""
-        _services = list(advertisement_data.service_data.values())
-        if not _services:
-            return
-        _service_data = _services[0]
-
-        _device = device.address.replace(":", "").lower()
-        _model = chr(_service_data[0] & 0b01111111)
-
-        supported_types: dict[str, dict[str, Any]] = {
-            "H": {"modelName": "WoHand", "func": _process_wohand},
-            "c": {"modelName": "WoCurtain", "func": _process_wocurtain},
-            "T": {"modelName": "WoSensorTH", "func": _process_wosensorth},
-        }
-
-        self._adv_data[_device] = {
-            "mac_address": device.address.lower(),
-            "rawAdvData": list(advertisement_data.service_data.values())[0],
-            "data": {
-                "rssi": device.rssi,
-            },
-        }
-
-        if _model in supported_types:
-
-            self._adv_data[_device].update(
-                {
-                    "isEncrypted": bool(_service_data[0] & 0b10000000),
-                    "model": _model,
-                    "modelName": supported_types[_model]["modelName"],
-                    "data": supported_types[_model]["func"](_service_data),
-                }
-            )
-
-            self._adv_data[_device]["data"]["rssi"] = device.rssi
+        discovery = parse_advertisement_data(device, advertisement_data)
+        if discovery:
+            self._adv_data[discovery.address] = discovery
 
     async def discover(
         self, retry: int = DEFAULT_RETRY_COUNT, scan_timeout: int = DEFAULT_SCAN_TIMEOUT
@@ -155,7 +175,6 @@ class GetSwitchbotDevices:
         """Find switchbot devices and their advertisement data."""
 
         devices = None
-
         devices = bleak.BleakScanner(
             # TODO: Find new UUIDs to filter on. For example, see
             # https://github.com/OpenWonderLabs/SwitchBotAPI-BLE/blob/4ad138bb09f0fbbfa41b152ca327a78c1d0b6ba9/devicetypes/meter.md
@@ -184,46 +203,35 @@ class GetSwitchbotDevices:
 
         return self._adv_data
 
-    async def get_curtains(self) -> dict:
-        """Return all WoCurtain/Curtains devices with services data."""
+    async def _get_devices_by_model(
+        self,
+        model: str,
+    ) -> dict:
+        """Get switchbot devices by type."""
         if not self._adv_data:
             await self.discover()
 
-        _curtain_devices = {
-            device: data
-            for device, data in self._adv_data.items()
-            if data.get("model") == "c"
+        return {
+            address: adv
+            for address, adv in self._adv_data.items()
+            if adv.data.get("model") == model
         }
 
-        return _curtain_devices
+    async def get_curtains(self) -> dict[str, SwitchBotAdvertisement]:
+        """Return all WoCurtain/Curtains devices with services data."""
+        return await self._get_devices_by_model("c")
 
-    async def get_bots(self) -> dict[str, Any] | None:
+    async def get_bots(self) -> dict[str, SwitchBotAdvertisement]:
         """Return all WoHand/Bot devices with services data."""
-        if not self._adv_data:
-            await self.discover()
-
-        _bot_devices = {
-            device: data
-            for device, data in self._adv_data.items()
-            if data.get("model") == "H"
-        }
-
-        return _bot_devices
+        return await self._get_devices_by_model("H")
 
-    async def get_tempsensors(self) -> dict[str, Any] | None:
+    async def get_tempsensors(self) -> dict[str, SwitchBotAdvertisement]:
         """Return all WoSensorTH/Temp sensor devices with services data."""
-        if not self._adv_data:
-            await self.discover()
-
-        _bot_temp = {
-            device: data
-            for device, data in self._adv_data.items()
-            if data.get("model") == "T"
-        }
-
-        return _bot_temp
+        return await self._get_devices_by_model("T")
 
-    async def get_device_data(self, mac: str) -> dict[str, Any] | None:
+    async def get_device_data(
+        self, address: str
+    ) -> dict[str, SwitchBotAdvertisement] | None:
         """Return data for specific device."""
         if not self._adv_data:
             await self.discover()
@@ -231,7 +239,8 @@ class GetSwitchbotDevices:
         _switchbot_data = {
             device: data
             for device, data in self._adv_data.items()
-            if data.get("mac_address") == mac
+            # MacOS uses UUIDs instead of MAC addresses
+            if data.get("address") == address
         }
 
         return _switchbot_data
@@ -242,15 +251,15 @@ class SwitchbotDevice:
 
     def __init__(
         self,
-        mac: str,
+        device: BLEDevice,
         password: str | None = None,
         interface: int = 0,
         **kwargs: Any,
     ) -> None:
         """Switchbot base class constructor."""
         self._interface = f"hci{interface}"
-        self._mac = mac.replace("-", ":").lower()
-        self._sb_adv_data: dict[str, Any] = {}
+        self._device = device
+        self._sb_adv_data: SwitchBotAdvertisement | None = None
         self._scan_timeout: int = kwargs.pop("scan_timeout", DEFAULT_SCAN_TIMEOUT)
         self._retry_count: int = kwargs.pop("retry_count", DEFAULT_RETRY_COUNT)
         if password is None or password == "":
@@ -279,13 +288,11 @@ class SwitchbotDevice:
         notify_msg = b""
         _LOGGER.debug("Sending command to switchbot %s", command)
 
-        if len(self._mac.split(":")) != 6:
-            raise ValueError("Expected MAC address, got %s" % repr(self._mac))
-
         async with CONNECT_LOCK:
             try:
                 async with bleak.BleakClient(
-                    address_or_ble_device=self._mac, timeout=float(self._scan_timeout)
+                    address_or_ble_device=self._device,
+                    timeout=float(self._scan_timeout),
                 ) as client:
                     _LOGGER.debug("Connnected to switchbot: %s", client.is_connected)
 
@@ -334,15 +341,24 @@ class SwitchbotDevice:
         await asyncio.sleep(DEFAULT_RETRY_TIMEOUT)
         return await self._sendcommand(key, retry - 1)
 
-    def get_mac(self) -> str:
-        """Return mac address of device."""
-        return self._mac
+    def get_address(self) -> str:
+        """Return address of device."""
+        return self._device.address
 
-    def get_battery_percent(self) -> Any:
-        """Return device battery level in percent."""
+    def _get_adv_value(self, key: str) -> Any:
+        """Return value from advertisement data."""
         if not self._sb_adv_data:
             return None
-        return self._sb_adv_data["data"]["battery"]
+        return self._sb_adv_data.data["data"][key]
+
+    def get_battery_percent(self) -> Any:
+        """Return device battery level in percent."""
+        return self._get_adv_value("battery")
+
+    def update_from_advertisement(self, advertisement: SwitchBotAdvertisement) -> None:
+        """Update device data from advertisement."""
+        self._sb_adv_data = advertisement
+        self._device = advertisement.device
 
     async def get_device_data(
         self, retry: int = DEFAULT_RETRY_COUNT, interface: int | None = None
@@ -353,14 +369,12 @@ class SwitchbotDevice:
         else:
             _interface = int(self._interface.replace("hci", ""))
 
-        dev_id = self._mac.replace(":", "")
-
         _data = await GetSwitchbotDevices(interface=_interface).discover(
             retry=retry, scan_timeout=self._scan_timeout
         )
 
-        if _data.get(dev_id):
-            self._sb_adv_data = _data[dev_id]
+        if self._device.address in _data:
+            self._sb_adv_data = _data[self._device.address]
 
         return self._sb_adv_data
 
@@ -493,20 +507,18 @@ class Switchbot(SwitchbotDevice):
     def switch_mode(self) -> Any:
         """Return true or false from cache."""
         # To get actual position call update() first.
-        if not self._sb_adv_data.get("data"):
-            return None
-        return self._sb_adv_data["data"].get("switchMode")
+        return self._get_adv_value("switchMode")
 
     def is_on(self) -> Any:
         """Return switch state from cache."""
         # To get actual position call update() first.
-        if not self._sb_adv_data.get("data"):
+        value = self._get_adv_value("isOn")
+        if value is None:
             return None
 
         if self._inverse:
-            return not self._sb_adv_data["data"].get("isOn")
-
-        return self._sb_adv_data["data"].get("isOn")
+            return not value
+        return value
 
 
 class SwitchbotCurtain(SwitchbotDevice):
@@ -570,9 +582,7 @@ class SwitchbotCurtain(SwitchbotDevice):
     def get_position(self) -> Any:
         """Return cached position (0-100) of Curtain."""
         # To get actual position call update() first.
-        if not self._sb_adv_data.get("data"):
-            return None
-        return self._sb_adv_data["data"].get("position")
+        return self._get_adv_value("position")
 
     async def get_basic_info(self) -> dict[str, Any] | None:
         """Get device basic settings."""
@@ -676,9 +686,7 @@ class SwitchbotCurtain(SwitchbotDevice):
     def get_light_level(self) -> Any:
         """Return cached light level."""
         # To get actual light level call update() first.
-        if not self._sb_adv_data.get("data"):
-            return None
-        return self._sb_adv_data["data"].get("lightLevel")
+        return self._get_adv_value("lightLevel")
 
     def is_reversed(self) -> bool:
         """Return True if curtain position is opposite from SB data."""
@@ -687,6 +695,4 @@ class SwitchbotCurtain(SwitchbotDevice):
     def is_calibrated(self) -> Any:
         """Return True curtain is calibrated."""
         # To get actual light level call update() first.
-        if not self._sb_adv_data.get("data"):
-            return None
-        return self._sb_adv_data["data"].get("calibration")
+        return self._get_adv_value("calibration")