Browse Source

refactor: move `_TOPIC_LEVELS_PREFIX` to `_MQTTCallbackUserdata.mqtt_topic_prefix` (to prepare parametrization via commmand-line argument)

https://github.com/fphammerle/switchbot-mqtt/issues/70
Fabian Peter Hammerle 2 years ago
parent
commit
4416a97213

+ 2 - 8
switchbot_mqtt/__init__.py

@@ -38,14 +38,8 @@ def _mqtt_on_connect(
     assert return_code == 0, return_code  # connection accepted
     mqtt_broker_host, mqtt_broker_port = mqtt_client.socket().getpeername()
     _LOGGER.debug("connected to MQTT broker %s:%d", mqtt_broker_host, mqtt_broker_port)
-    _ButtonAutomator.mqtt_subscribe(
-        mqtt_client=mqtt_client,
-        enable_device_info_update_topic=userdata.fetch_device_info,
-    )
-    _CurtainMotor.mqtt_subscribe(
-        mqtt_client=mqtt_client,
-        enable_device_info_update_topic=userdata.fetch_device_info,
-    )
+    _ButtonAutomator.mqtt_subscribe(mqtt_client=mqtt_client, settings=userdata)
+    _CurtainMotor.mqtt_subscribe(mqtt_client=mqtt_client, settings=userdata)
 
 
 def _run(

+ 59 - 20
switchbot_mqtt/_actors/__init__.py

@@ -32,14 +32,12 @@ from switchbot_mqtt._utils import (
 
 _LOGGER = logging.getLogger(__name__)
 
-# "homeassistant" for historic reason, may be parametrized in future
-_TOPIC_LEVELS_PREFIX: typing.Tuple[_MQTTTopicLevel] = ("homeassistant",)
-_BUTTON_TOPIC_LEVELS_PREFIX = _TOPIC_LEVELS_PREFIX + (
+_BUTTON_TOPIC_LEVELS_PREFIX = (
     "switch",
     "switchbot",
     _MQTTTopicPlaceholder.MAC_ADDRESS,
 )
-_CURTAIN_TOPIC_LEVELS_PREFIX = _TOPIC_LEVELS_PREFIX + (
+_CURTAIN_TOPIC_LEVELS_PREFIX = (
     "cover",
     "switchbot-curtain",
     _MQTTTopicPlaceholder.MAC_ADDRESS,
@@ -73,9 +71,11 @@ class _ButtonAutomator(_MQTTControlledActor):
 
     def execute_command(
         self,
+        *,
         mqtt_message_payload: bytes,
         mqtt_client: paho.mqtt.client.Client,
         update_device_info: bool,
+        mqtt_topic_prefix: str,
     ) -> None:
         # https://www.home-assistant.io/integrations/switch.mqtt/#payload_on
         if mqtt_message_payload.lower() == b"on":
@@ -84,18 +84,26 @@ class _ButtonAutomator(_MQTTControlledActor):
             else:
                 _LOGGER.info("switchbot %s turned on", self._mac_address)
                 # https://www.home-assistant.io/integrations/switch.mqtt/#state_on
-                self.report_state(mqtt_client=mqtt_client, state=b"ON")
+                self.report_state(
+                    mqtt_client=mqtt_client,
+                    mqtt_topic_prefix=mqtt_topic_prefix,
+                    state=b"ON",
+                )
                 if update_device_info:
-                    self._update_and_report_device_info(mqtt_client)
+                    self._update_and_report_device_info(mqtt_client, mqtt_topic_prefix)
         # https://www.home-assistant.io/integrations/switch.mqtt/#payload_off
         elif mqtt_message_payload.lower() == b"off":
             if not self.__device.turn_off():
                 _LOGGER.error("failed to turn off switchbot %s", self._mac_address)
             else:
                 _LOGGER.info("switchbot %s turned off", self._mac_address)
-                self.report_state(mqtt_client=mqtt_client, state=b"OFF")
+                self.report_state(
+                    mqtt_client=mqtt_client,
+                    mqtt_topic_prefix=mqtt_topic_prefix,
+                    state=b"OFF",
+                )
                 if update_device_info:
-                    self._update_and_report_device_info(mqtt_client)
+                    self._update_and_report_device_info(mqtt_client, mqtt_topic_prefix)
         else:
             _LOGGER.warning(
                 "unexpected payload %r (expected 'ON' or 'OFF')", mqtt_message_payload
@@ -106,7 +114,9 @@ class _CurtainMotor(_MQTTControlledActor):
 
     # https://www.home-assistant.io/integrations/cover.mqtt/
     MQTT_COMMAND_TOPIC_LEVELS = _CURTAIN_TOPIC_LEVELS_PREFIX + ("set",)
-    _MQTT_SET_POSITION_TOPIC_LEVELS = tuple(_CURTAIN_TOPIC_LEVELS_PREFIX) + (
+    _MQTT_SET_POSITION_TOPIC_LEVELS: typing.Tuple[
+        _MQTTTopicLevel, ...
+    ] = _CURTAIN_TOPIC_LEVELS_PREFIX + (
         "position",
         "set-percent",
     )
@@ -120,9 +130,11 @@ class _CurtainMotor(_MQTTControlledActor):
     _MQTT_POSITION_TOPIC_LEVELS = _CURTAIN_TOPIC_LEVELS_PREFIX + ("position",)
 
     @classmethod
-    def get_mqtt_position_topic(cls, mac_address: str) -> str:
+    def get_mqtt_position_topic(cls, prefix: str, mac_address: str) -> str:
         return _join_mqtt_topic_levels(
-            topic_levels=cls._MQTT_POSITION_TOPIC_LEVELS, mac_address=mac_address
+            topic_prefix=prefix,
+            topic_levels=cls._MQTT_POSITION_TOPIC_LEVELS,
+            mac_address=mac_address,
         )
 
     def __init__(
@@ -143,7 +155,11 @@ class _CurtainMotor(_MQTTControlledActor):
     def _get_device(self) -> switchbot.SwitchbotDevice:
         return self.__device
 
-    def _report_position(self, mqtt_client: paho.mqtt.client.Client) -> None:
+    def _report_position(
+        self,
+        mqtt_client: paho.mqtt.client.Client,
+        mqtt_topic_prefix: str,
+    ) -> None:
         # > position_closed integer (Optional, default: 0)
         # > position_open integer (Optional, default: 100)
         # https://www.home-assistant.io/integrations/cover.mqtt/#position_closed
@@ -152,23 +168,32 @@ class _CurtainMotor(_MQTTControlledActor):
         # SwitchbotCurtain.update() fetches the real position via bluetooth.
         # https://github.com/Danielhiversen/pySwitchbot/blob/0.10.0/switchbot/__init__.py#L202
         self._mqtt_publish(
+            topic_prefix=mqtt_topic_prefix,
             topic_levels=self._MQTT_POSITION_TOPIC_LEVELS,
             payload=str(int(self.__device.get_position())).encode(),
             mqtt_client=mqtt_client,
         )
 
     def _update_and_report_device_info(  # pylint: disable=arguments-differ; report_position is optional
-        self, mqtt_client: paho.mqtt.client.Client, *, report_position: bool = True
+        self,
+        mqtt_client: paho.mqtt.client.Client,
+        mqtt_topic_prefix: str,
+        *,
+        report_position: bool = True,
     ) -> None:
-        super()._update_and_report_device_info(mqtt_client)
+        super()._update_and_report_device_info(mqtt_client, mqtt_topic_prefix)
         if report_position:
-            self._report_position(mqtt_client=mqtt_client)
+            self._report_position(
+                mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix
+            )
 
     def execute_command(
         self,
+        *,
         mqtt_message_payload: bytes,
         mqtt_client: paho.mqtt.client.Client,
         update_device_info: bool,
+        mqtt_topic_prefix: str,
     ) -> None:
         # https://www.home-assistant.io/integrations/cover.mqtt/#payload_open
         report_device_info, report_position = False, False
@@ -179,7 +204,11 @@ class _CurtainMotor(_MQTTControlledActor):
                 _LOGGER.info("switchbot curtain %s opening", self._mac_address)
                 # > state_opening string (Optional, default: opening)
                 # https://www.home-assistant.io/integrations/cover.mqtt/#state_opening
-                self.report_state(mqtt_client=mqtt_client, state=b"opening")
+                self.report_state(
+                    mqtt_client=mqtt_client,
+                    mqtt_topic_prefix=mqtt_topic_prefix,
+                    state=b"opening",
+                )
                 report_device_info = update_device_info
         elif mqtt_message_payload.lower() == b"close":
             if not self.__device.close():
@@ -187,7 +216,11 @@ class _CurtainMotor(_MQTTControlledActor):
             else:
                 _LOGGER.info("switchbot curtain %s closing", self._mac_address)
                 # https://www.home-assistant.io/integrations/cover.mqtt/#state_closing
-                self.report_state(mqtt_client=mqtt_client, state=b"closing")
+                self.report_state(
+                    mqtt_client=mqtt_client,
+                    mqtt_topic_prefix=mqtt_topic_prefix,
+                    state=b"closing",
+                )
                 report_device_info = update_device_info
         elif mqtt_message_payload.lower() == b"stop":
             if not self.__device.stop():
@@ -197,7 +230,11 @@ class _CurtainMotor(_MQTTControlledActor):
                 # no "stopped" state mentioned at
                 # https://www.home-assistant.io/integrations/cover.mqtt/#configuration-variables
                 # https://community.home-assistant.io/t/mqtt-how-to-remove-retained-messages/79029/2
-                self.report_state(mqtt_client=mqtt_client, state=b"")
+                self.report_state(
+                    mqtt_client=mqtt_client,
+                    mqtt_topic_prefix=mqtt_topic_prefix,
+                    state=b"",
+                )
                 report_device_info = update_device_info
                 report_position = True
         else:
@@ -207,7 +244,9 @@ class _CurtainMotor(_MQTTControlledActor):
             )
         if report_device_info:
             self._update_and_report_device_info(
-                mqtt_client=mqtt_client, report_position=report_position
+                mqtt_client=mqtt_client,
+                mqtt_topic_prefix=mqtt_topic_prefix,
+                report_position=report_position,
             )
 
     @classmethod
@@ -224,9 +263,9 @@ class _CurtainMotor(_MQTTControlledActor):
             _LOGGER.info("ignoring retained message on topic %s", message.topic)
             return
         actor = cls._init_from_topic(
-            userdata=userdata,
             topic=message.topic,
             expected_topic_levels=cls._MQTT_SET_POSITION_TOPIC_LEVELS,
+            settings=userdata,
         )
         if not actor:
             return  # warning in _init_from_topic

+ 47 - 20
switchbot_mqtt/_actors/_base.py

@@ -45,6 +45,9 @@ class _MQTTCallbackUserdata:
     retry_count: int
     device_passwords: typing.Dict[str, str]
     fetch_device_info: bool
+    # "homeassistant/" for historic reasons.
+    # will be parametrized via command-line argument in the future.
+    mqtt_topic_prefix: str = "homeassistant/"
 
 
 class _MQTTControlledActor(abc.ABC):
@@ -58,15 +61,17 @@ class _MQTTControlledActor(abc.ABC):
     ] = NotImplemented
 
     @classmethod
-    def get_mqtt_update_device_info_topic(cls, mac_address: str) -> str:
+    def get_mqtt_update_device_info_topic(cls, *, prefix: str, mac_address: str) -> str:
         return _join_mqtt_topic_levels(
+            topic_prefix=prefix,
             topic_levels=cls._MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS,
             mac_address=mac_address,
         )
 
     @classmethod
-    def get_mqtt_battery_percentage_topic(cls, mac_address: str) -> str:
+    def get_mqtt_battery_percentage_topic(cls, *, prefix: str, mac_address: str) -> str:
         return _join_mqtt_topic_levels(
+            topic_prefix=prefix,
             topic_levels=cls._MQTT_BATTERY_PERCENTAGE_TOPIC_LEVELS,
             mac_address=mac_address,
         )
@@ -121,31 +126,38 @@ class _MQTTControlledActor(abc.ABC):
                 ) from exc
             raise
 
-    def _report_battery_level(self, mqtt_client: paho.mqtt.client.Client) -> None:
+    def _report_battery_level(
+        self, mqtt_client: paho.mqtt.client.Client, mqtt_topic_prefix: str
+    ) -> None:
         # > battery: Percentage of battery that is left.
         # https://www.home-assistant.io/integrations/sensor/#device-class
         self._mqtt_publish(
+            topic_prefix=mqtt_topic_prefix,
             topic_levels=self._MQTT_BATTERY_PERCENTAGE_TOPIC_LEVELS,
             payload=str(self._get_device().get_battery_percent()).encode(),
             mqtt_client=mqtt_client,
         )
 
     def _update_and_report_device_info(
-        self, mqtt_client: paho.mqtt.client.Client
+        self, mqtt_client: paho.mqtt.client.Client, mqtt_topic_prefix: str
     ) -> None:
         self._update_device_info()
-        self._report_battery_level(mqtt_client=mqtt_client)
+        self._report_battery_level(
+            mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix
+        )
 
     @classmethod
     def _init_from_topic(
         cls,
-        userdata: _MQTTCallbackUserdata,
         topic: str,
         expected_topic_levels: typing.Collection[_MQTTTopicLevel],
+        settings: _MQTTCallbackUserdata,
     ) -> typing.Optional[_MQTTControlledActor]:
         try:
             mac_address = _parse_mqtt_topic(
-                topic=topic, expected_levels=expected_topic_levels
+                topic=topic,
+                expected_prefix=settings.mqtt_topic_prefix,
+                expected_levels=expected_topic_levels,
             )[_MQTTTopicPlaceholder.MAC_ADDRESS]
         except ValueError as exc:
             _LOGGER.warning(str(exc), exc_info=False)
@@ -155,8 +167,8 @@ class _MQTTControlledActor(abc.ABC):
             return None
         return cls(
             mac_address=mac_address,
-            retry_count=userdata.retry_count,
-            password=userdata.device_passwords.get(mac_address, None),
+            retry_count=settings.retry_count,
+            password=settings.device_passwords.get(mac_address, None),
         )
 
     @classmethod
@@ -173,20 +185,24 @@ class _MQTTControlledActor(abc.ABC):
             _LOGGER.info("ignoring retained message")
             return
         actor = cls._init_from_topic(
-            userdata=userdata,
             topic=message.topic,
             expected_topic_levels=cls._MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS,
+            settings=userdata,
         )
         if actor:
             # pylint: disable=protected-access; own instance
-            actor._update_and_report_device_info(mqtt_client)
+            actor._update_and_report_device_info(
+                mqtt_client=mqtt_client, mqtt_topic_prefix=userdata.mqtt_topic_prefix
+            )
 
     @abc.abstractmethod
     def execute_command(
         self,
+        *,
         mqtt_message_payload: bytes,
         mqtt_client: paho.mqtt.client.Client,
         update_device_info: bool,
+        mqtt_topic_prefix: str,
     ) -> None:
         raise NotImplementedError()
 
@@ -204,15 +220,16 @@ class _MQTTControlledActor(abc.ABC):
             _LOGGER.info("ignoring retained message")
             return
         actor = cls._init_from_topic(
-            userdata=userdata,
             topic=message.topic,
             expected_topic_levels=cls.MQTT_COMMAND_TOPIC_LEVELS,
+            settings=userdata,
         )
         if actor:
             actor.execute_command(
                 mqtt_message_payload=message.payload,
                 mqtt_client=mqtt_client,
                 update_device_info=userdata.fetch_device_info,
+                mqtt_topic_prefix=userdata.mqtt_topic_prefix,
             )
 
     @classmethod
@@ -234,15 +251,16 @@ class _MQTTControlledActor(abc.ABC):
 
     @classmethod
     def mqtt_subscribe(
-        cls,
-        mqtt_client: paho.mqtt.client.Client,
-        *,
-        enable_device_info_update_topic: bool,
+        cls, *, mqtt_client: paho.mqtt.client.Client, settings: _MQTTCallbackUserdata
     ) -> None:
         for topic_levels, callback in cls._get_mqtt_message_callbacks(
-            enable_device_info_update_topic=enable_device_info_update_topic
+            enable_device_info_update_topic=settings.fetch_device_info
         ).items():
-            topic = _join_mqtt_topic_levels(topic_levels, mac_address="+")
+            topic = _join_mqtt_topic_levels(
+                topic_prefix=settings.mqtt_topic_prefix,
+                topic_levels=topic_levels,
+                mac_address="+",
+            )
             _LOGGER.info("subscribing to MQTT topic %r", topic)
             mqtt_client.subscribe(topic)
             mqtt_client.message_callback_add(sub=topic, callback=callback)
@@ -250,12 +268,15 @@ class _MQTTControlledActor(abc.ABC):
     def _mqtt_publish(
         self,
         *,
+        topic_prefix: str,
         topic_levels: typing.Iterable[_MQTTTopicLevel],
         payload: bytes,
         mqtt_client: paho.mqtt.client.Client,
     ) -> None:
         topic = _join_mqtt_topic_levels(
-            topic_levels=topic_levels, mac_address=self._mac_address
+            topic_prefix=topic_prefix,
+            topic_levels=topic_levels,
+            mac_address=self._mac_address,
         )
         # https://pypi.org/project/paho-mqtt/#publishing
         _LOGGER.debug("publishing topic=%s payload=%r", topic, payload)
@@ -270,8 +291,14 @@ class _MQTTControlledActor(abc.ABC):
                 message_info.rc,
             )
 
-    def report_state(self, state: bytes, mqtt_client: paho.mqtt.client.Client) -> None:
+    def report_state(
+        self,
+        state: bytes,
+        mqtt_client: paho.mqtt.client.Client,
+        mqtt_topic_prefix: str,
+    ) -> None:
         self._mqtt_publish(
+            topic_prefix=mqtt_topic_prefix,
             topic_levels=self.MQTT_STATE_TOPIC_LEVELS,
             payload=state,
             mqtt_client=mqtt_client,

+ 27 - 11
switchbot_mqtt/_cli.py

@@ -26,6 +26,7 @@ import warnings
 import switchbot
 
 import switchbot_mqtt
+import switchbot_mqtt._actors._base  # rename {_->}base ?
 from switchbot_mqtt._actors import _ButtonAutomator, _CurtainMotor
 
 _MQTT_DEFAULT_PORT = 1883
@@ -80,22 +81,37 @@ def _main() -> None:
         help="Maximum number of attempts to send a command to a SwitchBot device"
         " (default: %(default)d)",
     )
+    _MQTTCallbackUserdata = (
+        switchbot_mqtt._actors._base._MQTTCallbackUserdata  # pylint: disable=protected-access; internal module & class
+    )
+    mqtt_topic_prefix = _MQTTCallbackUserdata.mqtt_topic_prefix
     argparser.add_argument(
         "--fetch-device-info",
         action="store_true",
-        help="Report devices' battery level on topic"
+        help="Report devices' battery level on topic "
         # pylint: disable=protected-access; internal
-        f" {_ButtonAutomator.get_mqtt_battery_percentage_topic(mac_address='MAC_ADDRESS')}"
-        " or, respectively,"
-        f" {_CurtainMotor.get_mqtt_battery_percentage_topic(mac_address='MAC_ADDRESS')}"
-        " after every command. Additionally report curtain motors' position on"
-        f" topic {_CurtainMotor.get_mqtt_position_topic(mac_address='MAC_ADDRESS')}"
-        " after executing stop commands."
+        + _ButtonAutomator.get_mqtt_battery_percentage_topic(
+            prefix=mqtt_topic_prefix, mac_address="MAC_ADDRESS"
+        )
+        + " or, respectively,"
+        + _CurtainMotor.get_mqtt_battery_percentage_topic(
+            prefix=mqtt_topic_prefix, mac_address="MAC_ADDRESS"
+        )
+        + " after every command. Additionally report curtain motors' position on topic "
+        + _CurtainMotor.get_mqtt_position_topic(
+            prefix=mqtt_topic_prefix, mac_address="MAC_ADDRESS"
+        )
+        + " after executing stop commands."
         " When this option is enabled, the mentioned reports may also be requested"
-        " by sending a MQTT message to the topic"
-        f" {_ButtonAutomator.get_mqtt_update_device_info_topic(mac_address='MAC_ADDRESS')}"
-        f" or {_CurtainMotor.get_mqtt_update_device_info_topic(mac_address='MAC_ADDRESS')}."
-        " This option can also be enabled by assigning a non-empty value to the"
+        " by sending a MQTT message to the topic "
+        + _ButtonAutomator.get_mqtt_update_device_info_topic(
+            prefix=mqtt_topic_prefix, mac_address="MAC_ADDRESS"
+        )
+        + " or "
+        + _CurtainMotor.get_mqtt_update_device_info_topic(
+            prefix=mqtt_topic_prefix, mac_address="MAC_ADDRESS"
+        )
+        + ". This option can also be enabled by assigning a non-empty value to the"
         " environment variable FETCH_DEVICE_INFO.",
     )
     argparser.add_argument("--debug", action="store_true")

+ 12 - 4
switchbot_mqtt/_utils.py

@@ -37,19 +37,27 @@ _MQTTTopicLevel = typing.Union[str, _MQTTTopicPlaceholder]
 
 
 def _join_mqtt_topic_levels(
-    topic_levels: typing.Iterable[_MQTTTopicLevel], mac_address: str
+    *,
+    topic_prefix: str,
+    topic_levels: typing.Iterable[_MQTTTopicLevel],
+    mac_address: str,
 ) -> str:
-    return "/".join(
+    return topic_prefix + "/".join(
         mac_address if l == _MQTTTopicPlaceholder.MAC_ADDRESS else typing.cast(str, l)
         for l in topic_levels
     )
 
 
 def _parse_mqtt_topic(
-    topic: str, expected_levels: typing.Collection[_MQTTTopicLevel]
+    *,
+    topic: str,
+    expected_prefix: str,
+    expected_levels: typing.Collection[_MQTTTopicLevel],
 ) -> typing.Dict[_MQTTTopicPlaceholder, str]:
+    if not topic.startswith(expected_prefix):
+        raise ValueError(f"expected topic prefix {expected_prefix}, got topic {topic}")
     attrs: typing.Dict[_MQTTTopicPlaceholder, str] = {}
-    topic_split = topic.split("/")
+    topic_split = topic[len(expected_prefix) :].split("/")
     if len(topic_split) != len(expected_levels):
         raise ValueError(f"unexpected topic {topic}")
     for given_part, expected_part in zip(topic_split, expected_levels):

+ 8 - 1
tests/test_actor_base.py

@@ -46,14 +46,18 @@ def test_execute_command_abstract() -> None:
 
         def execute_command(
             self,
+            *,
             mqtt_message_payload: bytes,
             mqtt_client: paho.mqtt.client.Client,
             update_device_info: bool,
+            mqtt_topic_prefix: str,
         ) -> None:
+            assert 21
             super().execute_command(
                 mqtt_message_payload=mqtt_message_payload,
                 mqtt_client=mqtt_client,
                 update_device_info=update_device_info,
+                mqtt_topic_prefix=mqtt_topic_prefix,
             )
 
         def _get_device(self) -> switchbot.SwitchbotDevice:
@@ -63,7 +67,10 @@ def test_execute_command_abstract() -> None:
     actor = _ActorMock(mac_address="aa:bb:cc:dd:ee:ff", retry_count=42, password=None)
     with pytest.raises(NotImplementedError):
         actor.execute_command(
-            mqtt_message_payload=b"dummy", mqtt_client="dummy", update_device_info=True
+            mqtt_message_payload=b"dummy",
+            mqtt_client="dummy",
+            update_device_info=True,
+            mqtt_topic_prefix="whatever",
         )
     with pytest.raises(NotImplementedError):
         actor._get_device()

+ 37 - 6
tests/test_mqtt.py

@@ -223,9 +223,11 @@ def _mock_actor_class(
 
         def execute_command(
             self,
+            *,
             mqtt_message_payload: bytes,
             mqtt_client: Client,
             update_device_info: bool,
+            mqtt_topic_prefix: str,
         ) -> None:
             pass
 
@@ -274,7 +276,9 @@ def test__mqtt_update_device_info_callback(
     init_mock.assert_called_once_with(
         mac_address=expected_mac_address, retry_count=21, password=None
     )
-    update_mock.assert_called_once_with("client_dummy")
+    update_mock.assert_called_once_with(
+        mqtt_client="client_dummy", mqtt_topic_prefix="homeassistant/"
+    )
     assert caplog.record_tuples == [
         (
             "switchbot_mqtt._actors._base",
@@ -320,45 +324,58 @@ def test__mqtt_update_device_info_callback_ignore_retained(
 
 
 @pytest.mark.parametrize(
-    ("command_topic_levels", "topic", "payload", "expected_mac_address"),
+    (
+        "topic_prefix",
+        "command_topic_levels",
+        "topic",
+        "payload",
+        "expected_mac_address",
+    ),
     [
         (
+            "homeassistant/",
             _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS,
             b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set",
             b"ON",
             "aa:bb:cc:dd:ee:ff",
         ),
         (
+            "homeassistant/",
             _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS,
             b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set",
             b"OFF",
             "aa:bb:cc:dd:ee:ff",
         ),
         (
+            "homeassistant/",
             _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS,
             b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set",
             b"on",
             "aa:bb:cc:dd:ee:ff",
         ),
         (
+            "homeassistant/",
             _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS,
             b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set",
             b"off",
             "aa:bb:cc:dd:ee:ff",
         ),
         (
+            "prefix-",
             _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS,
-            b"homeassistant/switch/switchbot/aa:01:23:45:67:89/set",
+            b"prefix-switch/switchbot/aa:01:23:45:67:89/set",
             b"ON",
             "aa:01:23:45:67:89",
         ),
         (
+            "",
             ["switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS],
             b"switchbot/aa:01:23:45:67:89",
             b"ON",
             "aa:01:23:45:67:89",
         ),
         (
+            "homeassistant/",
             _CurtainMotor.MQTT_COMMAND_TOPIC_LEVELS,
             b"homeassistant/cover/switchbot-curtain/aa:01:23:45:67:89/set",
             b"OPEN",
@@ -370,6 +387,7 @@ def test__mqtt_update_device_info_callback_ignore_retained(
 @pytest.mark.parametrize("fetch_device_info", [True, False])
 def test__mqtt_command_callback(
     caplog: _pytest.logging.LogCaptureFixture,
+    topic_prefix: str,
     command_topic_levels: typing.Tuple[_MQTTTopicLevel, ...],
     topic: bytes,
     payload: bytes,
@@ -384,6 +402,7 @@ def test__mqtt_command_callback(
         retry_count=retry_count,
         device_passwords={},
         fetch_device_info=fetch_device_info,
+        mqtt_topic_prefix=topic_prefix,
     )
     with unittest.mock.patch.object(
         ActorMock, "__init__", return_value=None
@@ -400,6 +419,7 @@ def test__mqtt_command_callback(
         mqtt_client="client_dummy",
         mqtt_message_payload=payload,
         update_device_info=fetch_device_info,
+        mqtt_topic_prefix=topic_prefix,
     )
     assert caplog.record_tuples == [
         (
@@ -424,7 +444,7 @@ def test__mqtt_command_callback_password(
     ActorMock = _mock_actor_class(
         command_topic_levels=("switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS)
     )
-    message = MQTTMessage(topic=b"switchbot/" + mac_address.encode())
+    message = MQTTMessage(topic=b"prefix-switchbot/" + mac_address.encode())
     message.payload = b"whatever"
     callback_userdata = _MQTTCallbackUserdata(
         retry_count=3,
@@ -434,6 +454,7 @@ def test__mqtt_command_callback_password(
             "11:22:33:dd:ee:ff": "äöü",
         },
         fetch_device_info=True,
+        mqtt_topic_prefix="prefix-",
     )
     with unittest.mock.patch.object(
         ActorMock, "__init__", return_value=None
@@ -448,6 +469,7 @@ def test__mqtt_command_callback_password(
         mqtt_client="client_dummy",
         mqtt_message_payload=b"whatever",
         update_device_info=True,
+        mqtt_topic_prefix="prefix-",
     )
 
 
@@ -577,15 +599,17 @@ def test__mqtt_command_callback_ignore_retained(
 
 
 @pytest.mark.parametrize(
-    ("state_topic_levels", "mac_address", "expected_topic"),
+    ("topic_prefix", "state_topic_levels", "mac_address", "expected_topic"),
     # https://www.home-assistant.io/docs/mqtt/discovery/#switches
     [
         (
+            "homeassistant/",
             _ButtonAutomator.MQTT_STATE_TOPIC_LEVELS,
             "aa:bb:cc:dd:ee:ff",
             "homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/state",
         ),
         (
+            "",
             ["switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS, "state"],
             "aa:bb:cc:dd:ee:gg",
             "switchbot/aa:bb:cc:dd:ee:gg/state",
@@ -596,6 +620,7 @@ def test__mqtt_command_callback_ignore_retained(
 @pytest.mark.parametrize("return_code", [MQTT_ERR_SUCCESS, MQTT_ERR_QUEUE_SIZE])
 def test__report_state(
     caplog: _pytest.logging.LogCaptureFixture,
+    topic_prefix: str,
     state_topic_levels: typing.Tuple[_MQTTTopicLevel, ...],
     mac_address: str,
     expected_topic: str,
@@ -615,9 +640,11 @@ def test__report_state(
 
         def execute_command(
             self,
+            *,
             mqtt_message_payload: bytes,
             mqtt_client: Client,
             update_device_info: bool,
+            mqtt_topic_prefix: str,
         ) -> None:
             pass
 
@@ -628,7 +655,11 @@ def test__report_state(
     mqtt_client_mock.publish.return_value.rc = return_code
     with caplog.at_level(logging.DEBUG):
         actor = _ActorMock(mac_address=mac_address, retry_count=3, password=None)
-        actor.report_state(state=state, mqtt_client=mqtt_client_mock)
+        actor.report_state(
+            state=state,
+            mqtt_client=mqtt_client_mock,
+            mqtt_topic_prefix=topic_prefix,
+        )
     mqtt_client_mock.publish.assert_called_once_with(
         topic=expected_topic, payload=state, retain=True
     )

+ 20 - 7
tests/test_switchbot_button_automator.py

@@ -30,32 +30,39 @@ from switchbot_mqtt._actors import _ButtonAutomator
 # pylint: disable=too-many-arguments; these are tests, no API
 
 
+@pytest.mark.parametrize("prefix", ["homeassistant/", "prefix-", ""])
 @pytest.mark.parametrize("mac_address", ["{MAC_ADDRESS}", "aa:bb:cc:dd:ee:ff"])
-def test_get_mqtt_battery_percentage_topic(mac_address: str) -> None:
+def test_get_mqtt_battery_percentage_topic(prefix: str, mac_address: str) -> None:
     assert (
-        _ButtonAutomator.get_mqtt_battery_percentage_topic(mac_address=mac_address)
-        == f"homeassistant/switch/switchbot/{mac_address}/battery-percentage"
+        _ButtonAutomator.get_mqtt_battery_percentage_topic(
+            prefix=prefix, mac_address=mac_address
+        )
+        == f"{prefix}switch/switchbot/{mac_address}/battery-percentage"
     )
 
 
+@pytest.mark.parametrize("topic_prefix", ["homeassistant/", "prefix-", ""])
 @pytest.mark.parametrize(("battery_percent", "battery_percent_encoded"), [(42, b"42")])
 def test__update_and_report_device_info(
-    battery_percent: int, battery_percent_encoded: bytes
+    topic_prefix: str, battery_percent: int, battery_percent_encoded: bytes
 ) -> None:
     with unittest.mock.patch("switchbot.SwitchbotCurtain.__init__", return_value=None):
         actor = _ButtonAutomator(mac_address="dummy", retry_count=21, password=None)
     actor._get_device()._switchbot_device_data = {"data": {"battery": battery_percent}}
     mqtt_client_mock = unittest.mock.MagicMock()
     with unittest.mock.patch("switchbot.Switchbot.update") as update_mock:
-        actor._update_and_report_device_info(mqtt_client=mqtt_client_mock)
+        actor._update_and_report_device_info(
+            mqtt_client=mqtt_client_mock, mqtt_topic_prefix=topic_prefix
+        )
     update_mock.assert_called_once_with()
     mqtt_client_mock.publish.assert_called_once_with(
-        topic="homeassistant/switch/switchbot/dummy/battery-percentage",
+        topic=f"{topic_prefix}switch/switchbot/dummy/battery-percentage",
         payload=battery_percent_encoded,
         retain=True,
     )
 
 
+@pytest.mark.parametrize("topic_prefix", ["homeassistant/"])
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"])
 @pytest.mark.parametrize("password", (None, "secret"))
 @pytest.mark.parametrize("retry_count", (3, 21))
@@ -74,6 +81,7 @@ def test__update_and_report_device_info(
 @pytest.mark.parametrize("command_successful", [True, False])
 def test_execute_command(
     caplog: _pytest.logging.LogCaptureFixture,
+    topic_prefix: str,
     mac_address: str,
     password: typing.Optional[str],
     retry_count: int,
@@ -99,6 +107,7 @@ def test_execute_command(
                 mqtt_client="dummy",
                 mqtt_message_payload=message_payload,
                 update_device_info=update_device_info,
+                mqtt_topic_prefix=topic_prefix,
             )
     device_init_mock.assert_called_once_with(
         mac=mac_address, password=password, retry_count=retry_count
@@ -113,7 +122,9 @@ def test_execute_command(
             )
         ]
         report_mock.assert_called_once_with(
-            mqtt_client="dummy", state=message_payload.upper()
+            mqtt_client="dummy",
+            mqtt_topic_prefix=topic_prefix,
+            state=message_payload.upper(),
         )
         assert update_device_info_mock.call_count == (1 if update_device_info else 0)
     else:
@@ -142,6 +153,7 @@ def test_execute_command_invalid_payload(
                 mqtt_client="dummy",
                 mqtt_message_payload=message_payload,
                 update_device_info=True,
+                mqtt_topic_prefix="dummy",
             )
     device_mock.assert_called_once_with(mac=mac_address, retry_count=21, password=None)
     assert not device_mock().mock_calls  # no methods called
@@ -178,6 +190,7 @@ def test_execute_command_bluetooth_error(
             mqtt_client="dummy",
             mqtt_message_payload=message_payload,
             update_device_info=True,
+            mqtt_topic_prefix="dummy",
         )
     assert len(caplog.records) == 2
     assert caplog.records[0].name == "switchbot"

+ 25 - 10
tests/test_switchbot_curtain_motor.py

@@ -34,7 +34,9 @@ from switchbot_mqtt._actors import _CurtainMotor
 @pytest.mark.parametrize("mac_address", ["{MAC_ADDRESS}", "aa:bb:cc:dd:ee:ff"])
 def test_get_mqtt_battery_percentage_topic(mac_address: str) -> None:
     assert (
-        _CurtainMotor.get_mqtt_battery_percentage_topic(mac_address=mac_address)
+        _CurtainMotor.get_mqtt_battery_percentage_topic(
+            prefix="homeassistant/", mac_address=mac_address
+        )
         == f"homeassistant/cover/switchbot-curtain/{mac_address}/battery-percentage"
     )
 
@@ -42,8 +44,8 @@ def test_get_mqtt_battery_percentage_topic(mac_address: str) -> None:
 @pytest.mark.parametrize("mac_address", ["{MAC_ADDRESS}", "aa:bb:cc:dd:ee:ff"])
 def test_get_mqtt_position_topic(mac_address: str) -> None:
     assert (
-        _CurtainMotor.get_mqtt_position_topic(mac_address=mac_address)
-        == f"homeassistant/cover/switchbot-curtain/{mac_address}/position"
+        _CurtainMotor.get_mqtt_position_topic(prefix="prfx-", mac_address=mac_address)
+        == f"prfx-cover/switchbot-curtain/{mac_address}/position"
     )
 
 
@@ -80,10 +82,10 @@ def test__report_position(
     ) as publish_mock, unittest.mock.patch(
         "switchbot.SwitchbotCurtain.get_position", return_value=position
     ):
-        actor._report_position(mqtt_client="dummy")
+        actor._report_position(mqtt_client="dummy", mqtt_topic_prefix="topic-prefix")
     publish_mock.assert_called_once_with(
+        topic_prefix="topic-prefix",
         topic_levels=(
-            "homeassistant",
             "cover",
             "switchbot-curtain",
             switchbot_mqtt._utils._MQTTTopicPlaceholder.MAC_ADDRESS,
@@ -112,14 +114,16 @@ def test__report_position_invalid(
     ), pytest.raises(
         ValueError
     ):
-        actor._report_position(mqtt_client="dummy")
+        actor._report_position(mqtt_client="dummy", mqtt_topic_prefix="dummy2")
     publish_mock.assert_not_called()
 
 
+@pytest.mark.parametrize("topic_prefix", ["", "homeassistant/"])
 @pytest.mark.parametrize(("battery_percent", "battery_percent_encoded"), [(42, b"42")])
 @pytest.mark.parametrize("report_position", [True, False])
 @pytest.mark.parametrize(("position", "position_encoded"), [(21, b"21")])
 def test__update_and_report_device_info(
+    topic_prefix: str,
     report_position: bool,
     battery_percent: int,
     battery_percent_encoded: bytes,
@@ -134,13 +138,15 @@ def test__update_and_report_device_info(
     mqtt_client_mock = unittest.mock.MagicMock()
     with unittest.mock.patch("switchbot.SwitchbotCurtain.update") as update_mock:
         actor._update_and_report_device_info(
-            mqtt_client=mqtt_client_mock, report_position=report_position
+            mqtt_client=mqtt_client_mock,
+            mqtt_topic_prefix=topic_prefix,
+            report_position=report_position,
         )
     update_mock.assert_called_once_with()
     assert mqtt_client_mock.publish.call_count == (1 + report_position)
     assert (
         unittest.mock.call(
-            topic="homeassistant/cover/switchbot-curtain/dummy/battery-percentage",
+            topic=topic_prefix + "cover/switchbot-curtain/dummy/battery-percentage",
             payload=battery_percent_encoded,
             retain=True,
         )
@@ -149,7 +155,7 @@ def test__update_and_report_device_info(
     if report_position:
         assert (
             unittest.mock.call(
-                topic="homeassistant/cover/switchbot-curtain/dummy/position",
+                topic=topic_prefix + "cover/switchbot-curtain/dummy/position",
                 payload=position_encoded,
                 retain=True,
             )
@@ -170,10 +176,13 @@ def test__update_and_report_device_info_update_error(exception: Exception) -> No
     with unittest.mock.patch.object(
         actor._get_device(), "update", side_effect=exception
     ), pytest.raises(type(exception)):
-        actor._update_and_report_device_info(mqtt_client_mock, report_position=True)
+        actor._update_and_report_device_info(
+            mqtt_client_mock, mqtt_topic_prefix="dummy", report_position=True
+        )
     mqtt_client_mock.publish.assert_not_called()
 
 
+@pytest.mark.parametrize("topic_prefix", ["topic-prfx"])
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"])
 @pytest.mark.parametrize("password", ["pa$$word", None])
 @pytest.mark.parametrize("retry_count", (2, 3))
@@ -195,6 +204,7 @@ def test__update_and_report_device_info_update_error(exception: Exception) -> No
 @pytest.mark.parametrize("command_successful", [True, False])
 def test_execute_command(
     caplog: _pytest.logging.LogCaptureFixture,
+    topic_prefix: str,
     mac_address: str,
     password: typing.Optional[str],
     retry_count: int,
@@ -220,6 +230,7 @@ def test_execute_command(
                 mqtt_client="dummy",
                 mqtt_message_payload=message_payload,
                 update_device_info=update_device_info,
+                mqtt_topic_prefix=topic_prefix,
             )
     device_init_mock.assert_called_once_with(
         mac=mac_address, password=password, retry_count=retry_count, reverse_mode=True
@@ -238,6 +249,7 @@ def test_execute_command(
         ]
         report_mock.assert_called_once_with(
             mqtt_client="dummy",
+            mqtt_topic_prefix=topic_prefix,
             # https://www.home-assistant.io/integrations/cover.mqtt/#state_opening
             state={b"open": b"opening", b"close": b"closing", b"stop": b""}[
                 message_payload.lower()
@@ -256,6 +268,7 @@ def test_execute_command(
         update_device_info_mock.assert_called_once_with(
             mqtt_client="dummy",
             report_position=(action_name == "switchbot.SwitchbotCurtain.stop"),
+            mqtt_topic_prefix=topic_prefix,
         )
     else:
         update_device_info_mock.assert_not_called()
@@ -279,6 +292,7 @@ def test_execute_command_invalid_payload(
                 mqtt_client="dummy",
                 mqtt_message_payload=message_payload,
                 update_device_info=True,
+                mqtt_topic_prefix="dummy",
             )
     device_mock.assert_called_once_with(
         mac=mac_address, password=password, retry_count=7, reverse_mode=True
@@ -317,6 +331,7 @@ def test_execute_command_bluetooth_error(
             mqtt_client="dummy",
             mqtt_message_payload=message_payload,
             update_device_info=True,
+            mqtt_topic_prefix="dummy",
         )
     assert len(caplog.records) == 2
     assert caplog.records[0].name == "switchbot"

+ 1 - 0
tests/test_switchbot_curtain_motor_position.py

@@ -134,6 +134,7 @@ def test__mqtt_set_position_callback_unexpected_topic(
         retry_count=3,
         device_passwords={},
         fetch_device_info=False,
+        mqtt_topic_prefix="",
     )
     message = MQTTMessage(topic=b"switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set")
     message.payload = b"42"

+ 48 - 8
tests/test_utils.py

@@ -44,60 +44,100 @@ def test__mac_address_valid(mac_address: str, valid: bool) -> None:
 
 
 @pytest.mark.parametrize(
-    ("expected_levels", "topic", "expected_attrs"),
+    ("expected_prefix", "expected_levels", "topic", "expected_attrs"),
     [
         (
+            "",
             ["switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS, "set"],
             "switchbot/aa:bb:cc:dd:ee:ff/set",
             {_MQTTTopicPlaceholder.MAC_ADDRESS: "aa:bb:cc:dd:ee:ff"},
         ),
         (
+            "",
             ["switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS, "set"],
             "switchbot//set",
             {_MQTTTopicPlaceholder.MAC_ADDRESS: ""},
         ),
         (
+            "",
             ["prefix", _MQTTTopicPlaceholder.MAC_ADDRESS],
-            "prefix/aa:bb:cc:dd:ee:ff",
-            {_MQTTTopicPlaceholder.MAC_ADDRESS: "aa:bb:cc:dd:ee:ff"},
+            "prefix/aa:bb:cc:dd:ee:f1",
+            {_MQTTTopicPlaceholder.MAC_ADDRESS: "aa:bb:cc:dd:ee:f1"},
         ),
         (
+            "",
             [_MQTTTopicPlaceholder.MAC_ADDRESS],
             "00:11:22:33:44:55",
             {_MQTTTopicPlaceholder.MAC_ADDRESS: "00:11:22:33:44:55"},
         ),
+        (
+            "prefix/",
+            [_MQTTTopicPlaceholder.MAC_ADDRESS],
+            "prefix/aa:bb:cc:dd:ee:f2",
+            {_MQTTTopicPlaceholder.MAC_ADDRESS: "aa:bb:cc:dd:ee:f2"},
+        ),
+        (
+            "prefix-",
+            ["test", _MQTTTopicPlaceholder.MAC_ADDRESS, "42"],
+            "prefix-test/aa:bb:cc:dd:ee:f3/42",
+            {_MQTTTopicPlaceholder.MAC_ADDRESS: "aa:bb:cc:dd:ee:f3"},
+        ),
     ],
 )
 def test__parse_mqtt_topic(
+    expected_prefix: str,
     expected_levels: typing.List[_MQTTTopicLevel],
     topic: str,
     expected_attrs: typing.Dict[_MQTTTopicPlaceholder, str],
 ) -> None:
     assert (
-        _parse_mqtt_topic(topic=topic, expected_levels=expected_levels)
+        _parse_mqtt_topic(
+            topic=topic,
+            expected_prefix=expected_prefix,
+            expected_levels=expected_levels,
+        )
         == expected_attrs
     )
 
 
+def test__parse_mqtt_topic_unexpected_prefix() -> None:
+    with pytest.raises(
+        ValueError,
+        match=r"^expected topic prefix abcdefg/, got topic abcdef/aa:bb:cc:dd:ee:ff$",
+    ):
+        _parse_mqtt_topic(
+            topic="abcdef/aa:bb:cc:dd:ee:ff",
+            expected_prefix="abcdefg/",
+            expected_levels=[_MQTTTopicPlaceholder.MAC_ADDRESS],
+        )
+
+
 @pytest.mark.parametrize(
-    ("expected_levels", "topic"),
+    ("expected_prefix", "expected_levels", "topic"),
     [
         (
+            "",
             ["switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS, "set"],
             "switchbot/aa:bb:cc:dd:ee:ff",
         ),
         (
+            "",
             ["switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS, "set"],
             "switchbot/aa:bb:cc:dd:ee:ff/change",
         ),
         (
+            "prfx",
             ["switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS, "set"],
-            "switchbot/aa:bb:cc:dd:ee:ff/set/suffix",
+            "prfx/switchbot/aa:bb:cc:dd:ee:ff/set/suffix",
         ),
     ],
 )
 def test__parse_mqtt_topic_fail(
-    expected_levels: typing.List[_MQTTTopicLevel], topic: str
+    expected_prefix: str, expected_levels: typing.List[_MQTTTopicLevel], topic: str
 ) -> None:
     with pytest.raises(ValueError):
-        _parse_mqtt_topic(topic=topic, expected_levels=expected_levels)
+        _parse_mqtt_topic(
+            topic=topic,
+            expected_prefix=expected_prefix,
+            expected_levels=expected_levels,
+        )