Browse Source

replace paho-mqtt with its async wrapper aiomqtt (to prepare for upgrading PySwitchbot)

https://github.com/fphammerle/switchbot-mqtt/issues/103
https://github.com/fphammerle/switchbot-mqtt/issues/180#issuecomment-1741108146
https://github.com/fphammerle/switchbot-mqtt/issues/127#issuecomment-1349244614
Fabian Peter Hammerle 6 months ago
parent
commit
52764c2695

+ 4 - 0
CHANGELOG.md

@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - declare compatibility with `python3.11`
 
+### Changed
+- replaced [paho-mqtt](https://github.com/eclipse/paho.mqtt.python)
+  with its async wrapper [aiomqtt](https://github.com/sbtinstruments/aiomqtt)
+
 ### Removed
 - compatibility with `python3.7`
 

+ 1 - 0
Pipfile

@@ -11,6 +11,7 @@ black = "*"
 mypy = "*"
 pylint = "*"
 pytest = "*"
+pytest-asyncio = "*"
 pytest-cov = "*"
 
 # python3.10 compatibility

+ 18 - 1
Pipfile.lock

@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "cdf039b4e2e188227f3d34852c1dfbb5449901c9b4bc10a8379b46eebe84fb64"
+            "sha256": "94ad3eac5fb437c0e4a9fe45f316b813bcbc809b0cfc901ba1d885ae2c44fc67"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -16,6 +16,14 @@
         ]
     },
     "default": {
+        "aiomqtt": {
+            "hashes": [
+                "sha256:3925b40b2b95b1905753d53ef3a9162e903cfab35ebe9647ab4d52e45ffb727f",
+                "sha256:7582f4341f08ef7110dd9ab3a559454dc28ccda1eac502ff8f08a73b238ecede"
+            ],
+            "markers": "python_version >= '3.8' and python_version < '4.0'",
+            "version": "==1.2.1"
+        },
         "bluepy": {
             "hashes": [
                 "sha256:2a71edafe103565fb990256ff3624c1653036a837dfc90e1e32b839f83971cec"
@@ -275,6 +283,15 @@
             "markers": "python_version >= '3.7'",
             "version": "==7.4.3"
         },
+        "pytest-asyncio": {
+            "hashes": [
+                "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d",
+                "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"
+            ],
+            "index": "pypi",
+            "markers": "python_version >= '3.7'",
+            "version": "==0.21.1"
+        },
         "pytest-cov": {
             "hashes": [
                 "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6",

+ 4 - 3
setup.py

@@ -73,8 +73,9 @@ setuptools.setup(
     ],
     entry_points={"console_scripts": ["switchbot-mqtt = switchbot_mqtt._cli:_main"]},
     # >=3.6 variable type hints, f-strings, typing.Collection & * to force keyword-only arguments
-    # >=3.7 postponed evaluation of type annotations (PEP563) & dataclass
-    python_requires=">=3.8",  # python<3.8 untested
+    # >=3.7 postponed evaluation of type annotations (PEP563) & asyncio.run
+    # >=3.8 unittest.mock.AsyncMock
+    python_requires=">=3.8",
     install_requires=[
         # >=1.3.0 for btle.BTLEManagementError (could be replaced with BTLEException)
         # >=0.1.0 for btle.helperExe
@@ -83,7 +84,7 @@ setuptools.setup(
         # >=0.10.0 for SwitchbotCurtain.{update,get_position}
         # >=0.9.0 for SwitchbotCurtain.set_position
         "PySwitchbot>=0.10.0,<0.13",
-        "paho-mqtt<2",
+        "aiomqtt<2",
     ],
     setup_requires=["setuptools_scm"],
     tests_require=["pytest"],

+ 80 - 50
switchbot_mqtt/__init__.py

@@ -18,12 +18,12 @@
 
 import logging
 import socket
+import ssl
 import typing
 
-import paho.mqtt.client
+import aiomqtt
 
 from switchbot_mqtt._actors import _ButtonAutomator, _CurtainMotor
-from switchbot_mqtt._actors.base import _MQTTCallbackUserdata
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -34,34 +34,54 @@ _MQTT_BIRTH_PAYLOAD = "online"
 _MQTT_LAST_WILL_PAYLOAD = "offline"
 
 
-def _mqtt_on_connect(
-    mqtt_client: paho.mqtt.client.Client,
-    userdata: _MQTTCallbackUserdata,
-    flags: typing.Dict[str, int],
-    return_code: int,
+async def _listen(
+    *,
+    mqtt_client: aiomqtt.Client,
+    topic_callbacks: typing.Iterable[typing.Tuple[str, typing.Callable]],
+    mqtt_topic_prefix: str,
+    retry_count: int,
+    device_passwords: typing.Dict[str, str],
+    fetch_device_info: bool,
 ) -> None:
-    # pylint: disable=unused-argument; callback
-    # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L441
-    assert return_code == 0, return_code  # connection accepted
-    mqtt_broker_host, mqtt_broker_port, *_ = mqtt_client.socket().getpeername()
-    # https://www.rfc-editor.org/rfc/rfc5952#section-6
-    _LOGGER.debug(
-        "connected to MQTT broker %s:%d",
-        f"[{mqtt_broker_host}]"
-        if mqtt_client.socket().family == socket.AF_INET6
-        else mqtt_broker_host,
-        mqtt_broker_port,
-    )
-    mqtt_client.publish(
-        topic=userdata.mqtt_topic_prefix + _MQTT_AVAILABILITY_TOPIC,
-        payload=_MQTT_BIRTH_PAYLOAD,
-        retain=True,
-    )
-    _ButtonAutomator.mqtt_subscribe(mqtt_client=mqtt_client, settings=userdata)
-    _CurtainMotor.mqtt_subscribe(mqtt_client=mqtt_client, settings=userdata)
+    async with mqtt_client.messages() as messages:
+        await mqtt_client.publish(
+            topic=mqtt_topic_prefix + _MQTT_AVAILABILITY_TOPIC,
+            payload=_MQTT_BIRTH_PAYLOAD,
+            retain=True,
+        )
+        async for message in messages:
+            for topic, callback in topic_callbacks:
+                if message.topic.matches(topic):
+                    await callback(
+                        mqtt_client=mqtt_client,
+                        message=message,
+                        mqtt_topic_prefix=mqtt_topic_prefix,
+                        retry_count=retry_count,
+                        device_passwords=device_passwords,
+                        fetch_device_info=fetch_device_info,
+                    )
+
+
+def _log_mqtt_connected(mqtt_client: aiomqtt.Client) -> None:
+    if _LOGGER.getEffectiveLevel() <= logging.DEBUG:
+        mqtt_socket = (
+            # aiomqtt neither exposes instance of paho.mqtt.client.Client nor socket publicly.
+            # level condition to avoid accessing protected `mqtt_client._client` in production.
+            # pylint: disable=protected-access
+            mqtt_client._client.socket()
+        )
+        (mqtt_broker_host, mqtt_broker_port, *_) = mqtt_socket.getpeername()
+        # https://github.com/sbtinstruments/aiomqtt/blob/v1.2.1/aiomqtt/client.py#L1089
+        _LOGGER.debug(
+            "connected to MQTT broker %s:%d",
+            f"[{mqtt_broker_host}]"
+            if mqtt_socket.family == socket.AF_INET6
+            else mqtt_broker_host,
+            mqtt_broker_port,
+        )
 
 
-def _run(  # pylint: disable=too-many-arguments
+async def _run(  # pylint: disable=too-many-arguments
     *,
     mqtt_host: str,
     mqtt_port: int,
@@ -73,33 +93,43 @@ def _run(  # pylint: disable=too-many-arguments
     device_passwords: typing.Dict[str, str],
     fetch_device_info: bool,
 ) -> None:
-    # https://pypi.org/project/paho-mqtt/
-    mqtt_client = paho.mqtt.client.Client(
-        userdata=_MQTTCallbackUserdata(
-            retry_count=retry_count,
-            device_passwords=device_passwords,
-            fetch_device_info=fetch_device_info,
-            mqtt_topic_prefix=mqtt_topic_prefix,
-        )
-    )
-    mqtt_client.on_connect = _mqtt_on_connect
     _LOGGER.info(
         "connecting to MQTT broker %s:%d (TLS %s)",
         mqtt_host,
         mqtt_port,
         "disabled" if mqtt_disable_tls else "enabled",
     )
-    if not mqtt_disable_tls:
-        mqtt_client.tls_set(ca_certs=None)  # enable tls trusting default system certs
-    if mqtt_username:
-        mqtt_client.username_pw_set(username=mqtt_username, password=mqtt_password)
-    elif mqtt_password:
+    if mqtt_password is not None and mqtt_username is None:
         raise ValueError("Missing MQTT username")
-    mqtt_client.will_set(
-        topic=mqtt_topic_prefix + _MQTT_AVAILABILITY_TOPIC,
-        payload=_MQTT_LAST_WILL_PAYLOAD,
-        retain=True,
-    )
-    mqtt_client.connect(host=mqtt_host, port=mqtt_port)
-    # https://github.com/eclipse/paho.mqtt.python/blob/master/src/paho/mqtt/client.py#L1740
-    mqtt_client.loop_forever()
+    async with aiomqtt.Client(  # raises aiomqtt.MqttError
+        hostname=mqtt_host,
+        port=mqtt_port,
+        # > The settings [...] usually represent a higher security level than
+        # > when calling the SSLContext constructor directly.
+        # https://web.archive.org/web/20230714183106/https://docs.python.org/3/library/ssl.html
+        tls_context=None if mqtt_disable_tls else ssl.create_default_context(),
+        username=None if mqtt_username is None else mqtt_username,
+        password=None if mqtt_password is None else mqtt_password,
+        will=aiomqtt.Will(
+            topic=mqtt_topic_prefix + _MQTT_AVAILABILITY_TOPIC,
+            payload=_MQTT_LAST_WILL_PAYLOAD,
+            retain=True,
+        ),
+    ) as mqtt_client:
+        _log_mqtt_connected(mqtt_client=mqtt_client)
+        topic_callbacks: typing.List[typing.Tuple[str, typing.Callable]] = []
+        for actor_class in (_ButtonAutomator, _CurtainMotor):
+            async for topic, callback in actor_class.mqtt_subscribe(
+                mqtt_client=mqtt_client,
+                mqtt_topic_prefix=mqtt_topic_prefix,
+                fetch_device_info=fetch_device_info,
+            ):
+                topic_callbacks.append((topic, callback))
+        await _listen(
+            mqtt_client=mqtt_client,
+            topic_callbacks=topic_callbacks,
+            mqtt_topic_prefix=mqtt_topic_prefix,
+            retry_count=retry_count,
+            device_passwords=device_passwords,
+            fetch_device_info=fetch_device_info,
+        )

+ 37 - 26
switchbot_mqtt/_actors/__init__.py

@@ -20,10 +20,10 @@ import logging
 import typing
 
 import bluepy.btle
-import paho.mqtt.client
+import aiomqtt
 import switchbot
 
-from switchbot_mqtt._actors.base import _MQTTCallbackUserdata, _MQTTControlledActor
+from switchbot_mqtt._actors.base import _MQTTControlledActor
 from switchbot_mqtt._utils import (
     _join_mqtt_topic_levels,
     _MQTTTopicLevel,
@@ -69,11 +69,11 @@ class _ButtonAutomator(_MQTTControlledActor):
     def _get_device(self) -> switchbot.SwitchbotDevice:
         return self.__device
 
-    def execute_command(
+    async def execute_command(
         self,
         *,
         mqtt_message_payload: bytes,
-        mqtt_client: paho.mqtt.client.Client,
+        mqtt_client: aiomqtt.Client,
         update_device_info: bool,
         mqtt_topic_prefix: str,
     ) -> None:
@@ -84,26 +84,30 @@ class _ButtonAutomator(_MQTTControlledActor):
             else:
                 _LOGGER.info("switchbot %s turned on", self._mac_address)
                 # https://www.home-assistant.io/integrations/switch.mqtt/#state_on
-                self.report_state(
+                await self.report_state(
                     mqtt_client=mqtt_client,
                     mqtt_topic_prefix=mqtt_topic_prefix,
                     state=b"ON",
                 )
                 if update_device_info:
-                    self._update_and_report_device_info(mqtt_client, mqtt_topic_prefix)
+                    await self._update_and_report_device_info(
+                        mqtt_client, mqtt_topic_prefix
+                    )
         # https://www.home-assistant.io/integrations/switch.mqtt/#payload_off
         elif mqtt_message_payload.lower() == b"off":
             if not self.__device.turn_off():
                 _LOGGER.error("failed to turn off switchbot %s", self._mac_address)
             else:
                 _LOGGER.info("switchbot %s turned off", self._mac_address)
-                self.report_state(
+                await self.report_state(
                     mqtt_client=mqtt_client,
                     mqtt_topic_prefix=mqtt_topic_prefix,
                     state=b"OFF",
                 )
                 if update_device_info:
-                    self._update_and_report_device_info(mqtt_client, mqtt_topic_prefix)
+                    await self._update_and_report_device_info(
+                        mqtt_client, mqtt_topic_prefix
+                    )
         else:
             _LOGGER.warning(
                 "unexpected payload %r (expected 'ON' or 'OFF')", mqtt_message_payload
@@ -154,9 +158,9 @@ class _CurtainMotor(_MQTTControlledActor):
     def _get_device(self) -> switchbot.SwitchbotDevice:
         return self.__device
 
-    def _report_position(
+    async def _report_position(
         self,
-        mqtt_client: paho.mqtt.client.Client,  # pylint: disable=duplicate-code; similar param list
+        mqtt_client: aiomqtt.Client,  # pylint: disable=duplicate-code; similar param list
         mqtt_topic_prefix: str,
     ) -> None:
         # > position_closed integer (Optional, default: 0)
@@ -166,31 +170,31 @@ class _CurtainMotor(_MQTTControlledActor):
         # SwitchbotCurtain.open() and .close() update the position optimistically,
         # SwitchbotCurtain.update() fetches the real position via bluetooth.
         # https://github.com/Danielhiversen/pySwitchbot/blob/0.10.0/switchbot/__init__.py#L202
-        self._mqtt_publish(
+        await self._mqtt_publish(
             topic_prefix=mqtt_topic_prefix,
             topic_levels=self._MQTT_POSITION_TOPIC_LEVELS,
             payload=str(int(self.__device.get_position())).encode(),
             mqtt_client=mqtt_client,
         )
 
-    def _update_and_report_device_info(  # pylint: disable=arguments-differ; report_position is optional
+    async def _update_and_report_device_info(  # pylint: disable=arguments-differ; report_position is optional
         self,
-        mqtt_client: paho.mqtt.client.Client,
+        mqtt_client: aiomqtt.Client,
         mqtt_topic_prefix: str,
         *,
         report_position: bool = True,
     ) -> None:
-        super()._update_and_report_device_info(mqtt_client, mqtt_topic_prefix)
+        await super()._update_and_report_device_info(mqtt_client, mqtt_topic_prefix)
         if report_position:
-            self._report_position(
+            await self._report_position(
                 mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix
             )
 
-    def execute_command(
+    async def execute_command(
         self,
         *,
         mqtt_message_payload: bytes,
-        mqtt_client: paho.mqtt.client.Client,
+        mqtt_client: aiomqtt.Client,
         update_device_info: bool,
         mqtt_topic_prefix: str,
     ) -> None:
@@ -203,7 +207,7 @@ class _CurtainMotor(_MQTTControlledActor):
                 _LOGGER.info("switchbot curtain %s opening", self._mac_address)
                 # > state_opening string (Optional, default: opening)
                 # https://www.home-assistant.io/integrations/cover.mqtt/#state_opening
-                self.report_state(
+                await self.report_state(
                     mqtt_client=mqtt_client,
                     mqtt_topic_prefix=mqtt_topic_prefix,
                     state=b"opening",
@@ -215,7 +219,7 @@ class _CurtainMotor(_MQTTControlledActor):
             else:
                 _LOGGER.info("switchbot curtain %s closing", self._mac_address)
                 # https://www.home-assistant.io/integrations/cover.mqtt/#state_closing
-                self.report_state(
+                await self.report_state(
                     mqtt_client=mqtt_client,
                     mqtt_topic_prefix=mqtt_topic_prefix,
                     state=b"closing",
@@ -229,7 +233,7 @@ class _CurtainMotor(_MQTTControlledActor):
                 # no "stopped" state mentioned at
                 # https://www.home-assistant.io/integrations/cover.mqtt/#configuration-variables
                 # https://community.home-assistant.io/t/mqtt-how-to-remove-retained-messages/79029/2
-                self.report_state(
+                await self.report_state(
                     mqtt_client=mqtt_client,
                     mqtt_topic_prefix=mqtt_topic_prefix,
                     state=b"",
@@ -242,18 +246,22 @@ class _CurtainMotor(_MQTTControlledActor):
                 mqtt_message_payload,
             )
         if report_device_info:
-            self._update_and_report_device_info(
+            await self._update_and_report_device_info(
                 mqtt_client=mqtt_client,
                 mqtt_topic_prefix=mqtt_topic_prefix,
                 report_position=report_position,
             )
 
     @classmethod
-    def _mqtt_set_position_callback(
+    async def _mqtt_set_position_callback(
         cls,
-        mqtt_client: paho.mqtt.client.Client,
-        userdata: _MQTTCallbackUserdata,
-        message: paho.mqtt.client.MQTTMessage,
+        *,
+        mqtt_client: aiomqtt.Client,
+        message: aiomqtt.Message,
+        mqtt_topic_prefix: str,
+        retry_count: int,
+        device_passwords: typing.Dict[str, str],
+        fetch_device_info: bool,
     ) -> None:
         # pylint: disable=unused-argument; callback
         # https://github.com/eclipse/paho.mqtt.python/blob/v1.6.1/src/paho/mqtt/client.py#L3556
@@ -263,11 +271,14 @@ class _CurtainMotor(_MQTTControlledActor):
             return
         actor = cls._init_from_topic(
             topic=message.topic,
+            mqtt_topic_prefix=mqtt_topic_prefix,
             expected_topic_levels=cls._MQTT_SET_POSITION_TOPIC_LEVELS,
-            settings=userdata,
+            retry_count=retry_count,
+            device_passwords=device_passwords,
         )
         if not actor:
             return  # warning in _init_from_topic
+        assert isinstance(message.payload, bytes), message.payload
         position_percent = int(message.payload.decode(), 10)
         if position_percent < 0 or position_percent > 100:
             _LOGGER.warning("invalid position %u%%, ignoring message", position_percent)

+ 67 - 58
switchbot_mqtt/_actors/base.py

@@ -26,14 +26,13 @@
 from __future__ import annotations  # PEP563 (default in python>=3.10)
 
 import abc
-import dataclasses
 import logging
 import queue
 import shlex
 import typing
 
+import aiomqtt
 import bluepy.btle
-import paho.mqtt.client
 import switchbot
 from switchbot_mqtt._utils import (
     _join_mqtt_topic_levels,
@@ -47,14 +46,6 @@ from switchbot_mqtt._utils import (
 _LOGGER = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class _MQTTCallbackUserdata:
-    retry_count: int
-    device_passwords: typing.Dict[str, str]
-    fetch_device_info: bool
-    mqtt_topic_prefix: str
-
-
 class _MQTTControlledActor(abc.ABC):
     MQTT_COMMAND_TOPIC_LEVELS: typing.Tuple[_MQTTTopicLevel, ...] = NotImplemented
     _MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS: typing.Tuple[
@@ -131,37 +122,40 @@ class _MQTTControlledActor(abc.ABC):
                 ) from exc
             raise
 
-    def _report_battery_level(
-        self, mqtt_client: paho.mqtt.client.Client, mqtt_topic_prefix: str
+    async def _report_battery_level(
+        self, mqtt_client: aiomqtt.Client, mqtt_topic_prefix: str
     ) -> None:
         # > battery: Percentage of battery that is left.
         # https://www.home-assistant.io/integrations/sensor/#device-class
-        self._mqtt_publish(
+        await self._mqtt_publish(
             topic_prefix=mqtt_topic_prefix,
             topic_levels=self._MQTT_BATTERY_PERCENTAGE_TOPIC_LEVELS,
             payload=str(self._get_device().get_battery_percent()).encode(),
             mqtt_client=mqtt_client,
         )
 
-    def _update_and_report_device_info(
-        self, mqtt_client: paho.mqtt.client.Client, mqtt_topic_prefix: str
+    async def _update_and_report_device_info(
+        self, mqtt_client: aiomqtt.Client, mqtt_topic_prefix: str
     ) -> None:
         self._update_device_info()
-        self._report_battery_level(
+        await self._report_battery_level(
             mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix
         )
 
     @classmethod
     def _init_from_topic(
         cls,
-        topic: str,
+        *,
+        topic: aiomqtt.Topic,
+        mqtt_topic_prefix: str,
         expected_topic_levels: typing.Collection[_MQTTTopicLevel],
-        settings: _MQTTCallbackUserdata,
+        retry_count: int,
+        device_passwords: typing.Dict[str, str],
     ) -> typing.Optional[_MQTTControlledActor]:
         try:
             mac_address = _parse_mqtt_topic(
-                topic=topic,
-                expected_prefix=settings.mqtt_topic_prefix,
+                topic=topic.value,
+                expected_prefix=mqtt_topic_prefix,
                 expected_levels=expected_topic_levels,
             )[_MQTTTopicPlaceholder.MAC_ADDRESS]
         except ValueError as exc:
@@ -172,17 +166,21 @@ class _MQTTControlledActor(abc.ABC):
             return None
         return cls(
             mac_address=mac_address,
-            retry_count=settings.retry_count,
-            password=settings.device_passwords.get(mac_address, None),
+            retry_count=retry_count,
+            password=device_passwords.get(mac_address, None),
         )
 
     @classmethod
-    def _mqtt_update_device_info_callback(
+    async def _mqtt_update_device_info_callback(
         # pylint: disable=duplicate-code; other callbacks with same params
         cls,
-        mqtt_client: paho.mqtt.client.Client,
-        userdata: _MQTTCallbackUserdata,
-        message: paho.mqtt.client.MQTTMessage,
+        *,
+        mqtt_client: aiomqtt.Client,
+        message: aiomqtt.Message,
+        mqtt_topic_prefix: str,
+        retry_count: int,
+        device_passwords: typing.Dict[str, str],
+        fetch_device_info: bool,
     ) -> None:
         # pylint: disable=unused-argument; callback
         # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L469
@@ -192,33 +190,39 @@ class _MQTTControlledActor(abc.ABC):
             return
         actor = cls._init_from_topic(
             topic=message.topic,
+            mqtt_topic_prefix=mqtt_topic_prefix,
             expected_topic_levels=cls._MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS,
-            settings=userdata,
+            retry_count=retry_count,
+            device_passwords=device_passwords,
         )
         if actor:
             # pylint: disable=protected-access; own instance
-            actor._update_and_report_device_info(
-                mqtt_client=mqtt_client, mqtt_topic_prefix=userdata.mqtt_topic_prefix
+            await actor._update_and_report_device_info(
+                mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix
             )
 
     @abc.abstractmethod
-    def execute_command(  # pylint: disable=duplicate-code; implementations
+    async def execute_command(  # pylint: disable=duplicate-code; implementations
         self,
         *,
         mqtt_message_payload: bytes,
-        mqtt_client: paho.mqtt.client.Client,
+        mqtt_client: aiomqtt.Client,
         update_device_info: bool,
         mqtt_topic_prefix: str,
     ) -> None:
         raise NotImplementedError()
 
     @classmethod
-    def _mqtt_command_callback(
+    async def _mqtt_command_callback(
         # pylint: disable=duplicate-code; other callbacks with same params
         cls,
-        mqtt_client: paho.mqtt.client.Client,
-        userdata: _MQTTCallbackUserdata,
-        message: paho.mqtt.client.MQTTMessage,
+        *,
+        mqtt_client: aiomqtt.Client,
+        message: aiomqtt.Message,
+        mqtt_topic_prefix: str,
+        retry_count: int,
+        device_passwords: typing.Dict[str, str],
+        fetch_device_info: bool,
     ) -> None:
         # pylint: disable=unused-argument; callback
         # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L469
@@ -228,15 +232,18 @@ class _MQTTControlledActor(abc.ABC):
             return
         actor = cls._init_from_topic(
             topic=message.topic,
+            mqtt_topic_prefix=mqtt_topic_prefix,
             expected_topic_levels=cls.MQTT_COMMAND_TOPIC_LEVELS,
-            settings=userdata,
+            retry_count=retry_count,
+            device_passwords=device_passwords,
         )
         if actor:
-            actor.execute_command(
+            assert isinstance(message.payload, bytes), message.payload
+            await actor.execute_command(
                 mqtt_message_payload=message.payload,
                 mqtt_client=mqtt_client,
-                update_device_info=userdata.fetch_device_info,
-                mqtt_topic_prefix=userdata.mqtt_topic_prefix,
+                update_device_info=fetch_device_info,
+                mqtt_topic_prefix=mqtt_topic_prefix,
             )
 
     @classmethod
@@ -257,28 +264,32 @@ class _MQTTControlledActor(abc.ABC):
         return callbacks
 
     @classmethod
-    def mqtt_subscribe(
-        cls, *, mqtt_client: paho.mqtt.client.Client, settings: _MQTTCallbackUserdata
-    ) -> None:
+    async def mqtt_subscribe(
+        cls,
+        *,
+        mqtt_client: aiomqtt.Client,
+        mqtt_topic_prefix: str,
+        fetch_device_info: bool,
+    ) -> typing.AsyncIterator[typing.Tuple[str, typing.Callable]]:
         for topic_levels, callback in cls._get_mqtt_message_callbacks(
-            enable_device_info_update_topic=settings.fetch_device_info
+            enable_device_info_update_topic=fetch_device_info
         ).items():
             topic = _join_mqtt_topic_levels(
-                topic_prefix=settings.mqtt_topic_prefix,
+                topic_prefix=mqtt_topic_prefix,
                 topic_levels=topic_levels,
                 mac_address="+",
             )
             _LOGGER.info("subscribing to MQTT topic %r", topic)
-            mqtt_client.subscribe(topic)
-            mqtt_client.message_callback_add(sub=topic, callback=callback)
+            await mqtt_client.subscribe(topic)
+            yield (topic, callback)
 
-    def _mqtt_publish(
+    async def _mqtt_publish(
         self,
         *,
         topic_prefix: str,
         topic_levels: typing.Iterable[_MQTTTopicLevel],
         payload: bytes,
-        mqtt_client: paho.mqtt.client.Client,
+        mqtt_client: aiomqtt.Client,
     ) -> None:
         topic = _join_mqtt_topic_levels(
             topic_prefix=topic_prefix,
@@ -287,24 +298,22 @@ class _MQTTControlledActor(abc.ABC):
         )
         # https://pypi.org/project/paho-mqtt/#publishing
         _LOGGER.debug("publishing topic=%s payload=%r", topic, payload)
-        message_info: paho.mqtt.client.MQTTMessageInfo = mqtt_client.publish(
-            topic=topic, payload=payload, retain=True
-        )
-        # wait before checking status?
-        if message_info.rc != paho.mqtt.client.MQTT_ERR_SUCCESS:
+        try:
+            await mqtt_client.publish(topic=topic, payload=payload, retain=True)
+        except aiomqtt.MqttCodeError as exc:
             _LOGGER.error(
-                "Failed to publish MQTT message on topic %s (rc=%d)",
+                "Failed to publish MQTT message on topic %s: aiomqtt.MqttCodeError %s",
                 topic,
-                message_info.rc,
+                exc,
             )
 
-    def report_state(
+    async def report_state(
         self,
         state: bytes,
-        mqtt_client: paho.mqtt.client.Client,
+        mqtt_client: aiomqtt.Client,
         mqtt_topic_prefix: str,
     ) -> None:
-        self._mqtt_publish(
+        await self._mqtt_publish(
             topic_prefix=mqtt_topic_prefix,
             topic_levels=self.MQTT_STATE_TOPIC_LEVELS,
             payload=state,

+ 17 - 13
switchbot_mqtt/_cli.py

@@ -17,6 +17,7 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
 import argparse
+import asyncio
 import json
 import logging
 import os
@@ -154,17 +155,20 @@ def _main() -> None:
         device_passwords = json.loads(args.device_password_path.read_text())
     else:
         device_passwords = {}
-    switchbot_mqtt._run(  # pylint: disable=protected-access; internal
-        mqtt_host=args.mqtt_host,
-        mqtt_port=mqtt_port,
-        mqtt_disable_tls=not args.mqtt_enable_tls,
-        mqtt_username=args.mqtt_username,
-        mqtt_password=mqtt_password,
-        mqtt_topic_prefix=args.mqtt_topic_prefix,
-        retry_count=args.retry_count,
-        device_passwords=device_passwords,
-        fetch_device_info=args.fetch_device_info
-        # > In formal language theory, the empty string, [...], is the unique string of length zero.
-        # https://en.wikipedia.org/wiki/Empty_string
-        or bool(os.environ.get("FETCH_DEVICE_INFO")),
+    asyncio.run(
+        switchbot_mqtt._run(  # pylint: disable=protected-access; internal
+            mqtt_host=args.mqtt_host,
+            mqtt_port=mqtt_port,
+            mqtt_disable_tls=not args.mqtt_enable_tls,
+            mqtt_username=args.mqtt_username,
+            mqtt_password=mqtt_password,
+            mqtt_topic_prefix=args.mqtt_topic_prefix,
+            retry_count=args.retry_count,
+            device_passwords=device_passwords,
+            fetch_device_info=args.fetch_device_info
+            # > In formal language theory, the empty string, [...],
+            # > is the unique string of length zero.
+            # https://en.wikipedia.org/wiki/Empty_string
+            or bool(os.environ.get("FETCH_DEVICE_INFO")),
+        )
     )

+ 14 - 4
tests/test_actor_base.py

@@ -35,7 +35,8 @@ def test_abstract() -> None:
         )
 
 
-def test_execute_command_abstract() -> None:
+@pytest.mark.asyncio
+async def test_execute_command_abstract() -> None:
     class _ActorMock(switchbot_mqtt._actors.base._MQTTControlledActor):
         # pylint: disable=duplicate-code
         def __init__(
@@ -45,7 +46,7 @@ def test_execute_command_abstract() -> None:
                 mac_address=mac_address, retry_count=retry_count, password=password
             )
 
-        def execute_command(
+        async def execute_command(
             self,
             *,
             mqtt_message_payload: bytes,
@@ -54,7 +55,7 @@ def test_execute_command_abstract() -> None:
             mqtt_topic_prefix: str,
         ) -> None:
             assert 21
-            super().execute_command(
+            await super().execute_command(  # type: ignore
                 mqtt_message_payload=mqtt_message_payload,
                 mqtt_client=mqtt_client,
                 update_device_info=update_device_info,
@@ -65,9 +66,18 @@ def test_execute_command_abstract() -> None:
             assert 42
             return super()._get_device()
 
+    with pytest.raises(TypeError) as exc_info:
+        # pylint: disable=abstract-class-instantiated
+        switchbot_mqtt._actors.base._MQTTControlledActor(  # type: ignore
+            mac_address="aa:bb:cc:dd:ee:ff", retry_count=42, password=None
+        )
+    exc_info.match(
+        r"^Can't instantiate abstract class _MQTTControlledActor"
+        r" with abstract methods __init__, _get_device, execute_command$"
+    )
     actor = _ActorMock(mac_address="aa:bb:cc:dd:ee:ff", retry_count=42, password=None)
     with pytest.raises(NotImplementedError):
-        actor.execute_command(
+        await actor.execute_command(
             mqtt_message_payload=b"dummy",
             mqtt_client="dummy",
             update_device_info=True,

File diff suppressed because it is too large
+ 344 - 264
tests/test_mqtt.py


+ 19 - 14
tests/test_switchbot_button_automator.py

@@ -43,27 +43,29 @@ def test_get_mqtt_battery_percentage_topic(prefix: str, mac_address: str) -> Non
     )
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("topic_prefix", ["homeassistant/", "prefix-", ""])
 @pytest.mark.parametrize(("battery_percent", "battery_percent_encoded"), [(42, b"42")])
-def test__update_and_report_device_info(
+async def test__update_and_report_device_info(
     topic_prefix: str, battery_percent: int, battery_percent_encoded: bytes
 ) -> None:
     with unittest.mock.patch("switchbot.SwitchbotCurtain.__init__", return_value=None):
         actor = _ButtonAutomator(mac_address="dummy", retry_count=21, password=None)
     actor._get_device()._switchbot_device_data = {"data": {"battery": battery_percent}}
-    mqtt_client_mock = unittest.mock.MagicMock()
+    mqtt_client_mock = unittest.mock.AsyncMock()
     with unittest.mock.patch("switchbot.Switchbot.update") as update_mock:
-        actor._update_and_report_device_info(
+        await actor._update_and_report_device_info(
             mqtt_client=mqtt_client_mock, mqtt_topic_prefix=topic_prefix
         )
     update_mock.assert_called_once_with()
-    mqtt_client_mock.publish.assert_called_once_with(
+    mqtt_client_mock.publish.assert_awaited_once_with(
         topic=f"{topic_prefix}switch/switchbot/dummy/battery-percentage",
         payload=battery_percent_encoded,
         retain=True,
     )
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("topic_prefix", ["homeassistant/"])
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"])
 @pytest.mark.parametrize("password", (None, "secret"))
@@ -81,7 +83,7 @@ def test__update_and_report_device_info(
 )
 @pytest.mark.parametrize("update_device_info", [True, False])
 @pytest.mark.parametrize("command_successful", [True, False])
-def test_execute_command(
+async def test_execute_command(
     caplog: _pytest.logging.LogCaptureFixture,
     topic_prefix: str,
     mac_address: str,
@@ -98,6 +100,7 @@ def test_execute_command(
         actor = _ButtonAutomator(
             mac_address=mac_address, retry_count=retry_count, password=password
         )
+        mqtt_client = unittest.mock.Mock()
         with unittest.mock.patch.object(
             actor, "report_state"
         ) as report_mock, unittest.mock.patch(
@@ -105,8 +108,8 @@ def test_execute_command(
         ) as action_mock, unittest.mock.patch.object(
             actor, "_update_and_report_device_info"
         ) as update_device_info_mock:
-            actor.execute_command(
-                mqtt_client="dummy",
+            await actor.execute_command(
+                mqtt_client=mqtt_client,
                 mqtt_message_payload=message_payload,
                 update_device_info=update_device_info,
                 mqtt_topic_prefix=topic_prefix,
@@ -124,7 +127,7 @@ def test_execute_command(
             )
         ]
         report_mock.assert_called_once_with(
-            mqtt_client="dummy",
+            mqtt_client=mqtt_client,
             mqtt_topic_prefix=topic_prefix,
             state=message_payload.upper(),
         )
@@ -141,9 +144,10 @@ def test_execute_command(
         update_device_info_mock.assert_not_called()
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff"])
 @pytest.mark.parametrize("message_payload", [b"EIN", b""])
-def test_execute_command_invalid_payload(
+async def test_execute_command_invalid_payload(
     caplog: _pytest.logging.LogCaptureFixture, mac_address: str, message_payload: bytes
 ) -> None:
     with unittest.mock.patch("switchbot.Switchbot") as device_mock, caplog.at_level(
@@ -151,8 +155,8 @@ def test_execute_command_invalid_payload(
     ):
         actor = _ButtonAutomator(mac_address=mac_address, retry_count=21, password=None)
         with unittest.mock.patch.object(actor, "report_state") as report_mock:
-            actor.execute_command(
-                mqtt_client="dummy",
+            await actor.execute_command(
+                mqtt_client=unittest.mock.Mock(),
                 mqtt_message_payload=message_payload,
                 update_device_info=True,
                 mqtt_topic_prefix="dummy",
@@ -169,9 +173,10 @@ def test_execute_command_invalid_payload(
     ]
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff"])
 @pytest.mark.parametrize("message_payload", [b"ON", b"OFF"])
-def test_execute_command_bluetooth_error(
+async def test_execute_command_bluetooth_error(
     caplog: _pytest.logging.LogCaptureFixture, mac_address: str, message_payload: bytes
 ) -> None:
     """
@@ -186,10 +191,10 @@ def test_execute_command_bluetooth_error(
             f"Failed to connect to peripheral {mac_address}, addr type: random"
         ),
     ), caplog.at_level(logging.ERROR):
-        _ButtonAutomator(
+        await _ButtonAutomator(
             mac_address=mac_address, retry_count=0, password=None
         ).execute_command(
-            mqtt_client="dummy",
+            mqtt_client=unittest.mock.Mock(),
             mqtt_message_payload=message_payload,
             update_device_info=True,
             mqtt_topic_prefix="dummy",

+ 43 - 27
tests/test_switchbot_curtain_motor.py

@@ -50,6 +50,7 @@ def test_get_mqtt_position_topic(mac_address: str) -> None:
     )
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "mac_address",
     ("aa:bb:cc:dd:ee:ff", "aa:bb:cc:dd:ee:gg"),
@@ -57,7 +58,7 @@ def test_get_mqtt_position_topic(mac_address: str) -> None:
 @pytest.mark.parametrize(
     ("position", "expected_payload"), [(0, b"0"), (100, b"100"), (42, b"42")]
 )
-def test__report_position(
+async def test__report_position(
     caplog: _pytest.logging.LogCaptureFixture,
     mac_address: str,
     position: int,
@@ -78,13 +79,16 @@ def test__report_position(
         # https://github.com/Danielhiversen/pySwitchbot/blob/0.10.0/switchbot/__init__.py#L150
         reverse_mode=True,
     )
+    mqtt_client = unittest.mock.Mock()
     with unittest.mock.patch.object(
         actor, "_mqtt_publish"
     ) as publish_mock, unittest.mock.patch(
         "switchbot.SwitchbotCurtain.get_position", return_value=position
     ):
-        actor._report_position(mqtt_client="dummy", mqtt_topic_prefix="topic-prefix")
-    publish_mock.assert_called_once_with(
+        await actor._report_position(
+            mqtt_client=mqtt_client, mqtt_topic_prefix="topic-prefix"
+        )
+    publish_mock.assert_awaited_once_with(
         topic_prefix="topic-prefix",
         topic_levels=(
             "cover",
@@ -93,13 +97,14 @@ def test__report_position(
             "position",
         ),
         payload=expected_payload,
-        mqtt_client="dummy",
+        mqtt_client=mqtt_client,
     )
     assert not caplog.record_tuples
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("position", ("", 'lambda: print("")'))
-def test__report_position_invalid(
+async def test__report_position_invalid(
     caplog: _pytest.logging.LogCaptureFixture, position: str
 ) -> None:
     with unittest.mock.patch(
@@ -115,15 +120,18 @@ def test__report_position_invalid(
     ), pytest.raises(
         ValueError
     ):
-        actor._report_position(mqtt_client="dummy", mqtt_topic_prefix="dummy2")
+        await actor._report_position(
+            mqtt_client=unittest.mock.Mock(), mqtt_topic_prefix="dummy2"
+        )
     publish_mock.assert_not_called()
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("topic_prefix", ["", "homeassistant/"])
 @pytest.mark.parametrize(("battery_percent", "battery_percent_encoded"), [(42, b"42")])
 @pytest.mark.parametrize("report_position", [True, False])
 @pytest.mark.parametrize(("position", "position_encoded"), [(21, b"21")])
-def test__update_and_report_device_info(
+async def test__update_and_report_device_info(
     topic_prefix: str,
     report_position: bool,
     battery_percent: int,
@@ -136,22 +144,22 @@ def test__update_and_report_device_info(
     actor._get_device()._switchbot_device_data = {
         "data": {"battery": battery_percent, "position": position}
     }
-    mqtt_client_mock = unittest.mock.MagicMock()
+    mqtt_client_mock = unittest.mock.AsyncMock()
     with unittest.mock.patch("switchbot.SwitchbotCurtain.update") as update_mock:
-        actor._update_and_report_device_info(
+        await actor._update_and_report_device_info(
             mqtt_client=mqtt_client_mock,
             mqtt_topic_prefix=topic_prefix,
             report_position=report_position,
         )
     update_mock.assert_called_once_with()
-    assert mqtt_client_mock.publish.call_count == (1 + report_position)
+    assert mqtt_client_mock.publish.await_count == (1 + report_position)
     assert (
         unittest.mock.call(
             topic=topic_prefix + "cover/switchbot-curtain/dummy/battery-percentage",
             payload=battery_percent_encoded,
             retain=True,
         )
-        in mqtt_client_mock.publish.call_args_list
+        in mqtt_client_mock.publish.await_args_list
     )
     if report_position:
         assert (
@@ -160,10 +168,11 @@ def test__update_and_report_device_info(
                 payload=position_encoded,
                 retain=True,
             )
-            in mqtt_client_mock.publish.call_args_list
+            in mqtt_client_mock.publish.await_args_list
         )
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "exception",
     [
@@ -171,18 +180,21 @@ def test__update_and_report_device_info(
         bluepy.btle.BTLEManagementError("test"),
     ],
 )
-def test__update_and_report_device_info_update_error(exception: Exception) -> None:
+async def test__update_and_report_device_info_update_error(
+    exception: Exception,
+) -> None:
     actor = _CurtainMotor(mac_address="dummy", retry_count=21, password=None)
     mqtt_client_mock = unittest.mock.MagicMock()
     with unittest.mock.patch.object(
         actor._get_device(), "update", side_effect=exception
     ), pytest.raises(type(exception)):
-        actor._update_and_report_device_info(
+        await actor._update_and_report_device_info(
             mqtt_client_mock, mqtt_topic_prefix="dummy", report_position=True
         )
     mqtt_client_mock.publish.assert_not_called()
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("topic_prefix", ["topic-prfx"])
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"])
 @pytest.mark.parametrize("password", ["pa$$word", None])
@@ -203,7 +215,7 @@ def test__update_and_report_device_info_update_error(exception: Exception) -> No
 )
 @pytest.mark.parametrize("update_device_info", [True, False])
 @pytest.mark.parametrize("command_successful", [True, False])
-def test_execute_command(
+async def test_execute_command(
     caplog: _pytest.logging.LogCaptureFixture,
     topic_prefix: str,
     mac_address: str,
@@ -214,12 +226,14 @@ def test_execute_command(
     update_device_info: bool,
     command_successful: bool,
 ) -> None:
+    # pylint: disable=too-many-locals
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain.__init__", return_value=None
     ) as device_init_mock, caplog.at_level(logging.INFO):
         actor = _CurtainMotor(
             mac_address=mac_address, retry_count=retry_count, password=password
         )
+        mqtt_client = unittest.mock.Mock()
         with unittest.mock.patch.object(
             actor, "report_state"
         ) as report_mock, unittest.mock.patch(
@@ -227,8 +241,8 @@ def test_execute_command(
         ) as action_mock, unittest.mock.patch.object(
             actor, "_update_and_report_device_info"
         ) as update_device_info_mock:
-            actor.execute_command(
-                mqtt_client="dummy",
+            await actor.execute_command(
+                mqtt_client=mqtt_client,
                 mqtt_message_payload=message_payload,
                 update_device_info=update_device_info,
                 mqtt_topic_prefix=topic_prefix,
@@ -248,8 +262,8 @@ def test_execute_command(
                 f"switchbot curtain {mac_address} {state_str}",
             )
         ]
-        report_mock.assert_called_once_with(
-            mqtt_client="dummy",
+        report_mock.assert_awaited_once_with(
+            mqtt_client=mqtt_client,
             mqtt_topic_prefix=topic_prefix,
             # https://www.home-assistant.io/integrations/cover.mqtt/#state_opening
             state={b"open": b"opening", b"close": b"closing", b"stop": b""}[
@@ -266,8 +280,8 @@ def test_execute_command(
         ]
         report_mock.assert_not_called()
     if update_device_info and command_successful:
-        update_device_info_mock.assert_called_once_with(
-            mqtt_client="dummy",
+        update_device_info_mock.assert_awaited_once_with(
+            mqtt_client=mqtt_client,
             report_position=(action_name == "switchbot.SwitchbotCurtain.stop"),
             mqtt_topic_prefix=topic_prefix,
         )
@@ -275,10 +289,11 @@ def test_execute_command(
         update_device_info_mock.assert_not_called()
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff"])
 @pytest.mark.parametrize("password", ["secret"])
 @pytest.mark.parametrize("message_payload", [b"OEFFNEN", b""])
-def test_execute_command_invalid_payload(
+async def test_execute_command_invalid_payload(
     caplog: _pytest.logging.LogCaptureFixture,
     mac_address: str,
     password: str,
@@ -289,8 +304,8 @@ def test_execute_command_invalid_payload(
     ) as device_mock, caplog.at_level(logging.INFO):
         actor = _CurtainMotor(mac_address=mac_address, retry_count=7, password=password)
         with unittest.mock.patch.object(actor, "report_state") as report_mock:
-            actor.execute_command(
-                mqtt_client="dummy",
+            await actor.execute_command(
+                mqtt_client=unittest.mock.Mock(),
                 mqtt_message_payload=message_payload,
                 update_device_info=True,
                 mqtt_topic_prefix="dummy",
@@ -309,9 +324,10 @@ def test_execute_command_invalid_payload(
     ]
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff"])
 @pytest.mark.parametrize("message_payload", [b"OPEN", b"CLOSE", b"STOP"])
-def test_execute_command_bluetooth_error(
+async def test_execute_command_bluetooth_error(
     caplog: _pytest.logging.LogCaptureFixture, mac_address: str, message_payload: bytes
 ) -> None:
     """
@@ -326,10 +342,10 @@ def test_execute_command_bluetooth_error(
             f"Failed to connect to peripheral {mac_address}, addr type: random"
         ),
     ), caplog.at_level(logging.ERROR):
-        _CurtainMotor(
+        await _CurtainMotor(
             mac_address=mac_address, retry_count=0, password="secret"
         ).execute_command(
-            mqtt_client="dummy",
+            mqtt_client=unittest.mock.Mock(),
             mqtt_message_payload=message_payload,
             update_device_info=True,
             mqtt_topic_prefix="dummy",

+ 96 - 75
tests/test_switchbot_curtain_motor_position.py

@@ -19,34 +19,34 @@
 import logging
 import unittest.mock
 
+import aiomqtt
 import _pytest.logging  # pylint: disable=import-private-name; typing
 import pytest
-from paho.mqtt.client import MQTTMessage
 
 # pylint: disable=import-private-name; internal
 from switchbot_mqtt._actors import _CurtainMotor
-from switchbot_mqtt._actors.base import _MQTTCallbackUserdata
 
 # pylint: disable=protected-access
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     ("topic", "payload", "expected_mac_address", "expected_position_percent"),
     [
         (
-            b"home/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent",
+            "home/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent",
             b"42",
             "aa:bb:cc:dd:ee:ff",
             42,
         ),
         (
-            b"home/cover/switchbot-curtain/11:22:33:44:55:66/position/set-percent",
+            "home/cover/switchbot-curtain/11:22:33:44:55:66/position/set-percent",
             b"0",
             "11:22:33:44:55:66",
             0,
         ),
         (
-            b"home/cover/switchbot-curtain/11:22:33:44:55:66/position/set-percent",
+            "home/cover/switchbot-curtain/11:22:33:44:55:66/position/set-percent",
             b"100",
             "11:22:33:44:55:66",
             100,
@@ -54,27 +54,27 @@ from switchbot_mqtt._actors.base import _MQTTCallbackUserdata
     ],
 )
 @pytest.mark.parametrize("retry_count", (3, 42))
-def test__mqtt_set_position_callback(
+async def test__mqtt_set_position_callback(
     caplog: _pytest.logging.LogCaptureFixture,
-    topic: bytes,
+    topic: str,
     payload: bytes,
     expected_mac_address: str,
     retry_count: int,
     expected_position_percent: int,
 ) -> None:
-    callback_userdata = _MQTTCallbackUserdata(
-        retry_count=retry_count,
-        device_passwords={},
-        fetch_device_info=False,
-        mqtt_topic_prefix="home/",
+    message = aiomqtt.Message(
+        topic=topic, payload=payload, qos=0, retain=False, mid=0, properties=None
     )
-    message = MQTTMessage(topic=topic)
-    message.payload = payload
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain"
     ) as device_init_mock, caplog.at_level(logging.DEBUG):
-        _CurtainMotor._mqtt_set_position_callback(
-            mqtt_client="client dummy", userdata=callback_userdata, message=message
+        await _CurtainMotor._mqtt_set_position_callback(
+            mqtt_client=unittest.mock.Mock(),
+            message=message,
+            retry_count=retry_count,
+            device_passwords={},
+            fetch_device_info=False,
+            mqtt_topic_prefix="home/",
         )
     device_init_mock.assert_called_once_with(
         mac=expected_mac_address,
@@ -99,25 +99,28 @@ def test__mqtt_set_position_callback(
     ]
 
 
-def test__mqtt_set_position_callback_ignore_retained(
+@pytest.mark.asyncio
+async def test__mqtt_set_position_callback_ignore_retained(
     caplog: _pytest.logging.LogCaptureFixture,
 ) -> None:
-    callback_userdata = _MQTTCallbackUserdata(
-        retry_count=3,
-        device_passwords={},
-        fetch_device_info=False,
-        mqtt_topic_prefix="whatever",
-    )
-    message = MQTTMessage(
-        topic=b"homeassistant/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent"
+    message = aiomqtt.Message(
+        topic="homeassistant/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent",
+        payload=b"42",
+        qos=0,
+        retain=True,
+        mid=0,
+        properties=None,
     )
-    message.payload = b"42"
-    message.retain = True
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain"
     ) as device_init_mock, caplog.at_level(logging.INFO):
-        _CurtainMotor._mqtt_set_position_callback(
-            mqtt_client="client dummy", userdata=callback_userdata, message=message
+        await _CurtainMotor._mqtt_set_position_callback(
+            mqtt_client=unittest.mock.Mock(),
+            message=message,
+            retry_count=3,
+            device_passwords={},
+            fetch_device_info=False,
+            mqtt_topic_prefix="whatever",
         )
     device_init_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -130,22 +133,28 @@ def test__mqtt_set_position_callback_ignore_retained(
     ]
 
 
-def test__mqtt_set_position_callback_unexpected_topic(
+@pytest.mark.asyncio
+async def test__mqtt_set_position_callback_unexpected_topic(
     caplog: _pytest.logging.LogCaptureFixture,
 ) -> None:
-    callback_userdata = _MQTTCallbackUserdata(
-        retry_count=3,
-        device_passwords={},
-        fetch_device_info=False,
-        mqtt_topic_prefix="",
+    message = aiomqtt.Message(
+        topic="switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set",
+        payload=b"42",
+        qos=0,
+        retain=False,
+        mid=0,
+        properties=None,
     )
-    message = MQTTMessage(topic=b"switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set")
-    message.payload = b"42"
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain"
     ) as device_init_mock, caplog.at_level(logging.INFO):
-        _CurtainMotor._mqtt_set_position_callback(
-            mqtt_client="client dummy", userdata=callback_userdata, message=message
+        await _CurtainMotor._mqtt_set_position_callback(
+            mqtt_client=unittest.mock.Mock(),
+            message=message,
+            retry_count=3,
+            device_passwords={},
+            fetch_device_info=False,
+            mqtt_topic_prefix="",
         )
     device_init_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -157,24 +166,28 @@ def test__mqtt_set_position_callback_unexpected_topic(
     ]
 
 
-def test__mqtt_set_position_callback_invalid_mac_address(
+@pytest.mark.asyncio
+async def test__mqtt_set_position_callback_invalid_mac_address(
     caplog: _pytest.logging.LogCaptureFixture,
 ) -> None:
-    callback_userdata = _MQTTCallbackUserdata(
-        retry_count=3,
-        device_passwords={},
-        fetch_device_info=False,
-        mqtt_topic_prefix="tnatsissaemoh/",
-    )
-    message = MQTTMessage(
-        topic=b"tnatsissaemoh/cover/switchbot-curtain/aa:bb:cc:dd:ee/position/set-percent"
+    message = aiomqtt.Message(
+        topic="tnatsissaemoh/cover/switchbot-curtain/aa:bb:cc:dd:ee/position/set-percent",
+        payload=b"42",
+        qos=0,
+        retain=False,
+        mid=0,
+        properties=None,
     )
-    message.payload = b"42"
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain"
     ) as device_init_mock, caplog.at_level(logging.INFO):
-        _CurtainMotor._mqtt_set_position_callback(
-            mqtt_client="client dummy", userdata=callback_userdata, message=message
+        await _CurtainMotor._mqtt_set_position_callback(
+            mqtt_client=unittest.mock.Mock(),
+            message=message,
+            retry_count=3,
+            device_passwords={},
+            fetch_device_info=False,
+            mqtt_topic_prefix="tnatsissaemoh/",
         )
     device_init_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -186,26 +199,30 @@ def test__mqtt_set_position_callback_invalid_mac_address(
     ]
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("payload", [b"-1", b"123"])
-def test__mqtt_set_position_callback_invalid_position(
+async def test__mqtt_set_position_callback_invalid_position(
     caplog: _pytest.logging.LogCaptureFixture,
     payload: bytes,
 ) -> None:
-    callback_userdata = _MQTTCallbackUserdata(
-        retry_count=3,
-        device_passwords={},
-        fetch_device_info=False,
-        mqtt_topic_prefix="homeassistant/",
+    message = aiomqtt.Message(
+        topic="homeassistant/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent",
+        payload=payload,
+        qos=0,
+        retain=False,
+        mid=0,
+        properties=None,
     )
-    message = MQTTMessage(
-        topic=b"homeassistant/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent"
-    )
-    message.payload = payload
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain"
     ) as device_init_mock, caplog.at_level(logging.INFO):
-        _CurtainMotor._mqtt_set_position_callback(
-            mqtt_client="client dummy", userdata=callback_userdata, message=message
+        await _CurtainMotor._mqtt_set_position_callback(
+            mqtt_client=unittest.mock.Mock(),
+            message=message,
+            retry_count=3,
+            device_passwords={},
+            fetch_device_info=False,
+            mqtt_topic_prefix="homeassistant/",
         )
     device_init_mock.assert_called_once()
     device_init_mock().set_position.assert_not_called()
@@ -218,26 +235,30 @@ def test__mqtt_set_position_callback_invalid_position(
     ]
 
 
-def test__mqtt_set_position_callback_command_failed(
+@pytest.mark.asyncio
+async def test__mqtt_set_position_callback_command_failed(
     caplog: _pytest.logging.LogCaptureFixture,
 ) -> None:
-    callback_userdata = _MQTTCallbackUserdata(
-        retry_count=3,
-        device_passwords={},
-        fetch_device_info=False,
-        mqtt_topic_prefix="",
-    )
-    message = MQTTMessage(
-        topic=b"cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent"
+    message = aiomqtt.Message(
+        topic="cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent",
+        payload=b"21",
+        qos=0,
+        retain=False,
+        mid=0,
+        properties=None,
     )
-    message.payload = b"21"
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain"
     ) as device_init_mock, caplog.at_level(logging.INFO):
         device_init_mock().set_position.return_value = False
         device_init_mock.reset_mock()
-        _CurtainMotor._mqtt_set_position_callback(
-            mqtt_client="client dummy", userdata=callback_userdata, message=message
+        await _CurtainMotor._mqtt_set_position_callback(
+            mqtt_client=unittest.mock.Mock(),
+            message=message,
+            retry_count=3,
+            device_passwords={},
+            fetch_device_info=False,
+            mqtt_topic_prefix="",
         )
     device_init_mock.assert_called_once()
     device_init_mock().set_position.assert_called_with(21)

Some files were not shown because too many files changed in this diff