Browse Source

subscribe to topic `homeassistant/cover/switchbot-curtain/+/position/set-percent`

Fabian Peter Hammerle 2 years ago
parent
commit
1a9a7d4722

+ 27 - 2
switchbot_mqtt/_actors/__init__.py

@@ -23,7 +23,7 @@ import bluepy.btle
 import paho.mqtt.client
 import switchbot
 
-from switchbot_mqtt._actors._base import _MQTTControlledActor
+from switchbot_mqtt._actors._base import _MQTTCallbackUserdata, _MQTTControlledActor
 from switchbot_mqtt._utils import (
     _join_mqtt_topic_levels,
     _MQTTTopicLevel,
@@ -119,9 +119,13 @@ class _ButtonAutomator(_MQTTControlledActor):
 
 
 class _CurtainMotor(_MQTTControlledActor):
-    # https://www.home-assistant.io/integrations/cover.mqtt/
 
+    # https://www.home-assistant.io/integrations/cover.mqtt/
     MQTT_COMMAND_TOPIC_LEVELS = _CURTAIN_TOPIC_LEVELS_PREFIX + ["set"]
+    _MQTT_SET_POSITION_TOPIC_LEVELS = tuple(_CURTAIN_TOPIC_LEVELS_PREFIX) + (
+        "position",
+        "set-percent",
+    )
     _MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS = _CURTAIN_TOPIC_LEVELS_PREFIX + [
         "request-device-info"
     ]
@@ -221,3 +225,24 @@ class _CurtainMotor(_MQTTControlledActor):
             self._update_and_report_device_info(
                 mqtt_client=mqtt_client, report_position=report_position
             )
+
+    @classmethod
+    def _mqtt_set_position_callback(
+        cls,
+        mqtt_client: paho.mqtt.client.Client,
+        userdata: _MQTTCallbackUserdata,
+        message: paho.mqtt.client.MQTTMessage,
+    ) -> None:
+        raise NotImplementedError()
+
+    @classmethod
+    def _get_mqtt_message_callbacks(
+        cls,
+        *,
+        enable_device_info_update_topic: bool,
+    ) -> typing.Dict[typing.Tuple[_MQTTTopicLevel, ...], typing.Callable]:
+        callbacks = super()._get_mqtt_message_callbacks(
+            enable_device_info_update_topic=enable_device_info_update_topic
+        )
+        callbacks[cls._MQTT_SET_POSITION_TOPIC_LEVELS] = cls._mqtt_set_position_callback
+        return callbacks

+ 20 - 9
switchbot_mqtt/_actors/_base.py

@@ -211,6 +211,23 @@ class _MQTTControlledActor(abc.ABC):
                 update_device_info=userdata.fetch_device_info,
             )
 
+    @classmethod
+    def _get_mqtt_message_callbacks(
+        cls,
+        *,
+        enable_device_info_update_topic: bool,
+    ) -> typing.Dict[typing.Tuple[_MQTTTopicLevel, ...], typing.Callable]:
+        # returning dict because `paho.mqtt.client.Client.message_callback_add` overwrites
+        # callbacks with same topic pattern
+        # https://github.com/eclipse/paho.mqtt.python/blob/v1.6.1/src/paho/mqtt/client.py#L2304
+        # https://github.com/eclipse/paho.mqtt.python/blob/v1.6.1/src/paho/mqtt/matcher.py#L19
+        callbacks = {tuple(cls.MQTT_COMMAND_TOPIC_LEVELS): cls._mqtt_command_callback}
+        if enable_device_info_update_topic:
+            callbacks[
+                tuple(cls._MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS)
+            ] = cls._mqtt_update_device_info_callback
+        return callbacks
+
     @classmethod
     def mqtt_subscribe(
         cls,
@@ -218,15 +235,9 @@ class _MQTTControlledActor(abc.ABC):
         *,
         enable_device_info_update_topic: bool,
     ) -> None:
-        topics = [(cls.MQTT_COMMAND_TOPIC_LEVELS, cls._mqtt_command_callback)]
-        if enable_device_info_update_topic:
-            topics.append(
-                (
-                    cls._MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS,
-                    cls._mqtt_update_device_info_callback,
-                )
-            )
-        for topic_levels, callback in topics:
+        for topic_levels, callback in cls._get_mqtt_message_callbacks(
+            enable_device_info_update_topic=enable_device_info_update_topic
+        ).items():
             topic = _join_mqtt_topic_levels(topic_levels, mac_address="+")
             _LOGGER.info("subscribing to MQTT topic %r", topic)
             mqtt_client.subscribe(topic)

+ 1 - 1
switchbot_mqtt/_utils.py

@@ -37,7 +37,7 @@ _MQTTTopicLevel = typing.Union[str, _MQTTTopicPlaceholder]
 
 
 def _join_mqtt_topic_levels(
-    topic_levels: typing.List[_MQTTTopicLevel], mac_address: str
+    topic_levels: typing.Iterable[_MQTTTopicLevel], mac_address: str
 ) -> str:
     return "/".join(
         mac_address if l == _MQTTTopicPlaceholder.MAC_ADDRESS else typing.cast(str, l)

+ 11 - 2
tests/test_mqtt.py

@@ -77,10 +77,11 @@ def test__run(
     with caplog.at_level(logging.DEBUG):
         mqtt_client_mock().on_connect(mqtt_client_mock(), userdata, {}, 0)
     subscribe_mock = mqtt_client_mock().subscribe
-    assert subscribe_mock.call_count == (4 if fetch_device_info else 2)
+    assert subscribe_mock.call_count == (5 if fetch_device_info else 3)
     for topic in [
         "homeassistant/switch/switchbot/+/set",
         "homeassistant/cover/switchbot-curtain/+/set",
+        "homeassistant/cover/switchbot-curtain/+/position/set-percent",
     ]:
         assert unittest.mock.call(topic) in subscribe_mock.call_args_list
     for topic in [
@@ -90,6 +91,14 @@ def test__run(
         assert (
             unittest.mock.call(topic) in subscribe_mock.call_args_list
         ) == fetch_device_info
+    callbacks = {
+        c[1]["sub"]: c[1]["callback"]
+        for c in mqtt_client_mock().message_callback_add.call_args_list
+    }
+    assert (  # pylint: disable=comparison-with-callable; intended
+        callbacks["homeassistant/cover/switchbot-curtain/+/position/set-percent"]
+        == _CurtainMotor._mqtt_set_position_callback
+    )
     mqtt_client_mock().loop_forever.assert_called_once_with()
     assert caplog.record_tuples[:2] == [
         (
@@ -103,7 +112,7 @@ def test__run(
             f"connected to MQTT broker {mqtt_host}:{mqtt_port}",
         ),
     ]
-    assert len(caplog.record_tuples) == (6 if fetch_device_info else 4)
+    assert len(caplog.record_tuples) == (7 if fetch_device_info else 5)
     assert (
         "switchbot_mqtt._actors._base",
         logging.INFO,

+ 12 - 0
tests/test_switchbot_curtain_motor.py

@@ -26,6 +26,7 @@ import pytest
 
 import switchbot_mqtt._utils
 from switchbot_mqtt._actors import _CurtainMotor
+from switchbot_mqtt._actors._base import _MQTTCallbackUserdata
 
 # pylint: disable=protected-access,
 # pylint: disable=too-many-arguments; these are tests, no API
@@ -330,3 +331,14 @@ def test_execute_command_bluetooth_error(
         logging.ERROR,
         f"failed to {message_payload.decode().lower()} switchbot curtain {mac_address}",
     )
+
+
+def test__mqtt_set_position_callback() -> None:
+    with pytest.raises(NotImplementedError):
+        _CurtainMotor._mqtt_set_position_callback(
+            mqtt_client="dummy",
+            userdata=_MQTTCallbackUserdata(
+                retry_count=3, device_passwords={}, fetch_device_info=False
+            ),
+            message=None,
+        )