Browse Source

refactor: replaced paho.mqtt.client.Client.on_message with specific .message_callback_add

Fabian Peter Hammerle 3 years ago
parent
commit
6fc954ba94
4 changed files with 119 additions and 100 deletions
  1. 49 33
      systemctl_mqtt/__init__.py
  2. 2 2
      tests/test_dbus.py
  3. 67 56
      tests/test_mqtt.py
  4. 1 9
      tests/test_settings.py

+ 49 - 33
systemctl_mqtt/__init__.py

@@ -77,18 +77,49 @@ def _schedule_shutdown(action: str) -> None:
             _LOGGER.error("failed to schedule %s: %s", action, exc_msg)
 
 
-_MQTT_TOPIC_SUFFIX_ACTION_MAPPING = {
-    "poweroff": functools.partial(_schedule_shutdown, action="poweroff"),
-}
-
-
 class _Settings:
+
     # pylint: disable=too-few-public-methods
+
     def __init__(self, mqtt_topic_prefix: str) -> None:
-        self.mqtt_topic_action_mapping = {}  # type: typing.Dict[str, typing.Callable]
-        for topic_suffix, action in _MQTT_TOPIC_SUFFIX_ACTION_MAPPING.items():
-            topic = mqtt_topic_prefix + "/" + topic_suffix
-            self.mqtt_topic_action_mapping[topic] = action
+        self._mqtt_topic_prefix = mqtt_topic_prefix
+
+    @property
+    def mqtt_topic_prefix(self) -> str:
+        return self._mqtt_topic_prefix
+
+
+class _MQTTAction:
+
+    # pylint: disable=too-few-public-methods
+
+    def __init__(self, name: str, action: typing.Callable) -> None:
+        self.name = name
+        self.action = action
+
+    def mqtt_message_callback(
+        self,
+        mqtt_client: paho.mqtt.client.Client,
+        settings: _Settings,
+        message: paho.mqtt.client.MQTTMessage,
+    ) -> None:
+        # pylint: disable=unused-argument; callback
+        # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L3416
+        # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L469
+        _LOGGER.debug("received topic=%s payload=%r", message.topic, message.payload)
+        if message.retain:
+            _LOGGER.info("ignoring retained message")
+            return
+        _LOGGER.debug("executing action %s (%r)", self.name, self.action)
+        self.action()
+        _LOGGER.debug("completed action %s (%r)", self.name, self.action)
+
+
+_MQTT_TOPIC_SUFFIX_ACTION_MAPPING = {
+    "poweroff": _MQTTAction(
+        name="poweroff", action=functools.partial(_schedule_shutdown, action="poweroff")
+    ),
+}
 
 
 def _mqtt_on_connect(
@@ -102,30 +133,16 @@ def _mqtt_on_connect(
     assert return_code == 0, return_code  # connection accepted
     mqtt_broker_host, mqtt_broker_port = mqtt_client.socket().getpeername()
     _LOGGER.debug("connected to MQTT broker %s:%d", mqtt_broker_host, mqtt_broker_port)
-    for topic in settings.mqtt_topic_action_mapping.keys():
-        _LOGGER.debug("subscribing to %s", topic)
+    for topic_suffix, action in _MQTT_TOPIC_SUFFIX_ACTION_MAPPING.items():
+        topic = settings.mqtt_topic_prefix + "/" + topic_suffix
+        _LOGGER.info("subscribing to %s", topic)
         mqtt_client.subscribe(topic)
-
-
-def _mqtt_on_message(
-    mqtt_client: paho.mqtt.client.Client,
-    settings: _Settings,
-    message: paho.mqtt.client.MQTTMessage,
-) -> None:
-    # pylint: disable=unused-argument; callback
-    # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L469
-    _LOGGER.debug("received topic=%s payload=%r", message.topic, message.payload)
-    if message.retain:
-        _LOGGER.info("ignoring retained message")
-        return
-    try:
-        action = settings.mqtt_topic_action_mapping[message.topic]
-    except KeyError:
-        _LOGGER.warning("unexpected topic %s", message.topic)
-        return
-    _LOGGER.debug("executing action %r", action)
-    action()
-    _LOGGER.debug("completed action %r", action)
+        mqtt_client.message_callback_add(
+            sub=topic, callback=action.mqtt_message_callback
+        )
+        _LOGGER.debug(
+            "registered MQTT callback for topic %s triggering %r", topic, action.action
+        )
 
 
 def _run(
@@ -140,7 +157,6 @@ def _run(
         userdata=_Settings(mqtt_topic_prefix=mqtt_topic_prefix)
     )
     mqtt_client.on_connect = _mqtt_on_connect
-    mqtt_client.on_message = _mqtt_on_message
     mqtt_client.tls_set(ca_certs=None)  # enable tls trusting default system certs
     _LOGGER.info(
         "connecting to MQTT broker %s:%d", mqtt_host, mqtt_port,

+ 2 - 2
tests/test_dbus.py

@@ -92,12 +92,12 @@ def test__schedule_shutdown_fail(caplog, action, exception_message, log_message)
     ("topic_suffix", "expected_action_arg"), [("poweroff", "poweroff")]
 )
 def test_mqtt_topic_suffix_action_mapping(topic_suffix, expected_action_arg):
-    action = systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[topic_suffix]
+    mqtt_action = systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[topic_suffix]
     login_manager_mock = unittest.mock.MagicMock()
     with unittest.mock.patch(
         "systemctl_mqtt._get_login_manager", return_value=login_manager_mock
     ):
-        action()
+        mqtt_action.action()
     assert login_manager_mock.ScheduleShutdown.call_count == 1
     schedule_args, schedule_kwargs = login_manager_mock.ScheduleShutdown.call_args
     assert len(schedule_args) == 2

+ 67 - 56
tests/test_mqtt.py

@@ -29,16 +29,15 @@ import systemctl_mqtt
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
 @pytest.mark.parametrize("mqtt_port", [1833])
 @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
-def test__run(mqtt_host, mqtt_port, mqtt_topic_prefix):
+def test__run(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
+    caplog.set_level(logging.DEBUG)
     with unittest.mock.patch(
         "socket.create_connection"
     ) as create_socket_mock, unittest.mock.patch(
         "ssl.SSLContext.wrap_socket", autospec=True,
     ) as ssl_wrap_socket_mock, unittest.mock.patch(
         "paho.mqtt.client.Client.loop_forever", autospec=True,
-    ) as mqtt_loop_forever_mock, unittest.mock.patch(
-        "systemctl_mqtt._mqtt_on_message"
-    ) as message_handler_mock:
+    ) as mqtt_loop_forever_mock:
         ssl_wrap_socket_mock.return_value.send = len
         systemctl_mqtt._run(
             mqtt_host=mqtt_host,
@@ -47,6 +46,10 @@ def test__run(mqtt_host, mqtt_port, mqtt_topic_prefix):
             mqtt_password=None,
             mqtt_topic_prefix=mqtt_topic_prefix,
         )
+    assert caplog.records[0].levelno == logging.INFO
+    assert caplog.records[0].message == "connecting to MQTT broker {}:{}".format(
+        mqtt_host, mqtt_port
+    )
     # correct remote?
     assert create_socket_mock.call_count == 1
     create_socket_args, _ = create_socket_mock.call_args
@@ -64,19 +67,48 @@ def test__run(mqtt_host, mqtt_port, mqtt_topic_prefix):
     assert mqtt_client._username is None
     assert mqtt_client._password is None
     # connect callback
+    caplog.clear()
     mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
     with unittest.mock.patch(
         "paho.mqtt.client.Client.subscribe"
     ) as mqtt_subscribe_mock:
         mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
     mqtt_subscribe_mock.assert_called_once_with(mqtt_topic_prefix + "/poweroff")
+    assert mqtt_client.on_message is None
+    assert (  # pylint: disable=comparison-with-callable
+        mqtt_client._on_message_filtered[mqtt_topic_prefix + "/poweroff"]
+        == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
+            "poweroff"
+        ].mqtt_message_callback
+    )
+    assert caplog.records[0].levelno == logging.DEBUG
+    assert caplog.records[0].message == "connected to MQTT broker {}:{}".format(
+        mqtt_host, mqtt_port
+    )
+    assert caplog.records[1].levelno == logging.INFO
+    assert caplog.records[1].message == "subscribing to {}".format(
+        mqtt_topic_prefix + "/poweroff"
+    )
+    assert caplog.records[2].levelno == logging.DEBUG
+    assert caplog.records[2].message == "registered MQTT callback for topic {}".format(
+        mqtt_topic_prefix + "/poweroff"
+    ) + " triggering {}".format(
+        systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"].action
+    )
     # message callback
-    test_message = MQTTMessage(topic=b"test")
-    message_handler_mock.assert_not_called()
-    mqtt_client._handle_on_message(test_message)
-    message_handler_mock.assert_called_once_with(
-        mqtt_client, mqtt_client._userdata, test_message
+    caplog.clear()
+    poweroff_message = MQTTMessage(topic=mqtt_topic_prefix.encode() + b"/poweroff")
+    with unittest.mock.patch.object(
+        systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
+    ) as poweroff_action_mock:
+        mqtt_client._handle_on_message(poweroff_message)
+    poweroff_action_mock.assert_called_once_with()
+    assert all(r.levelno == logging.DEBUG for r in caplog.records)
+    assert caplog.records[0].message == "received topic={} payload=b''".format(
+        poweroff_message.topic
     )
+    assert caplog.records[1].message.startswith("executing action poweroff")
+    assert caplog.records[2].message.startswith("completed action poweroff")
 
 
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
@@ -124,19 +156,20 @@ def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_passwor
             )
 
 
-@pytest.mark.parametrize("mqtt_topic_prefix", ["system/command"])
+@pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
 @pytest.mark.parametrize("payload", [b"", b"junk"])
-def test__mqtt_on_message_poweroff(caplog, mqtt_topic_prefix: str, payload: bytes):
-    mqtt_topic = mqtt_topic_prefix + "/poweroff"
+def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
     message = MQTTMessage(topic=mqtt_topic.encode())
     message.payload = payload
-    settings = systemctl_mqtt._Settings(mqtt_topic_prefix=mqtt_topic_prefix)
-    action_mock = unittest.mock.MagicMock()
-    settings.mqtt_topic_action_mapping[mqtt_topic] = action_mock  # functools.partial
-    with caplog.at_level(logging.DEBUG):
-        systemctl_mqtt._mqtt_on_message(
-            None, settings, message,
+    with unittest.mock.patch.object(
+        systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
+    ) as action_mock, caplog.at_level(logging.DEBUG):
+        systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
+            "poweroff"
+        ].mqtt_message_callback(
+            None, None, message  # type: ignore
         )
+    action_mock.assert_called_once_with()
     assert len(caplog.records) == 3
     assert caplog.records[0].levelno == logging.DEBUG
     assert caplog.records[0].message == (
@@ -144,57 +177,35 @@ def test__mqtt_on_message_poweroff(caplog, mqtt_topic_prefix: str, payload: byte
     )
     assert caplog.records[1].levelno == logging.DEBUG
     assert caplog.records[1].message.startswith(
-        "executing action {!r}".format(action_mock)
+        "executing action {} ({!r})".format("poweroff", action_mock)
     )
     assert caplog.records[2].levelno == logging.DEBUG
     assert caplog.records[2].message.startswith(
-        "completed action {!r}".format(action_mock)
+        "completed action {} ({!r})".format("poweroff", action_mock)
     )
-    action_mock.assert_called_once_with()
 
 
-@pytest.mark.parametrize(
-    ("topic", "payload"), [("system/poweroff", b""), ("system/poweroff", "payload"),],
-)
-def test__mqtt_on_message_ignored(
-    caplog, topic: str, payload: bytes,
-):
-    message = MQTTMessage(topic=topic.encode())
-    message.payload = payload
-    settings = systemctl_mqtt._Settings(mqtt_topic_prefix="system/command")
-    settings.mqtt_topic_action_mapping = {}  # provoke KeyError on access
-    with caplog.at_level(logging.DEBUG):
-        systemctl_mqtt._mqtt_on_message(
-            None, settings, message,
-        )
-    assert len(caplog.records) == 2
-    assert caplog.records[0].levelno == logging.DEBUG
-    assert caplog.records[0].message == (
-        "received topic={} payload={!r}".format(topic, payload)
-    )
-    assert caplog.records[1].levelno == logging.WARNING
-    assert caplog.records[1].message == "unexpected topic {}".format(topic)
-
-
-@pytest.mark.parametrize(
-    ("topic", "payload"), [("system/command/poweroff", b"")],
-)
-def test__mqtt_on_message_ignored_retained(
-    caplog, topic: str, payload: bytes,
+@pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
+@pytest.mark.parametrize("payload", [b"", b"junk"])
+def test_mqtt_message_callback_poweroff_retained(
+    caplog, mqtt_topic: str, payload: bytes
 ):
-    message = MQTTMessage(topic=topic.encode())
+    message = MQTTMessage(topic=mqtt_topic.encode())
     message.payload = payload
     message.retain = True
-    settings = systemctl_mqtt._Settings(mqtt_topic_prefix="system/command")
-    settings.mqtt_topic_action_mapping = {}  # provoke KeyError on access
-    with caplog.at_level(logging.DEBUG):
-        systemctl_mqtt._mqtt_on_message(
-            None, settings, message,
+    with unittest.mock.patch.object(
+        systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
+    ) as action_mock, caplog.at_level(logging.DEBUG):
+        systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
+            "poweroff"
+        ].mqtt_message_callback(
+            None, None, message  # type: ignore
         )
+    action_mock.assert_not_called()
     assert len(caplog.records) == 2
     assert caplog.records[0].levelno == logging.DEBUG
     assert caplog.records[0].message == (
-        "received topic={} payload={!r}".format(topic, payload)
+        "received topic={} payload={!r}".format(mqtt_topic, payload)
     )
     assert caplog.records[1].levelno == logging.INFO
     assert caplog.records[1].message == "ignoring retained message"

+ 1 - 9
tests/test_settings.py

@@ -15,8 +15,6 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
-import functools
-
 import pytest
 
 import systemctl_mqtt
@@ -27,10 +25,4 @@ import systemctl_mqtt
 @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
 def test_mqtt_topic_action_mapping(mqtt_topic_prefix):
     settings = systemctl_mqtt._Settings(mqtt_topic_prefix=mqtt_topic_prefix)
-    assert len(settings.mqtt_topic_action_mapping) == 1
-    action = settings.mqtt_topic_action_mapping[mqtt_topic_prefix + "/poweroff"]
-    assert isinstance(action, functools.partial)
-    # pylint: disable=comparison-with-callable
-    assert action.func == systemctl_mqtt._schedule_shutdown
-    assert not action.args
-    assert action.keywords == {"action": "poweroff"}
+    assert settings.mqtt_topic_prefix == mqtt_topic_prefix