Browse Source

migrate to jeepney's asyncio integration (to prepare for migration from paho-mqtt to aiomqtt)

https://github.com/fphammerle/systemctl-mqtt/commit/ffac98894a5b321578dc79d5af91cd3c93f0e212
Fabian Peter Hammerle 3 months ago
parent
commit
4a4cc551d4
4 changed files with 125 additions and 96 deletions
  1. 27 17
      systemctl_mqtt/__init__.py
  2. 37 40
      tests/test_dbus.py
  3. 30 39
      tests/test_mqtt.py
  4. 31 0
      tests/test_state_dbus.py

+ 27 - 17
systemctl_mqtt/__init__.py

@@ -31,7 +31,7 @@ import typing
 
 import jeepney
 import jeepney.bus_messages
-import jeepney.io.blocking
+import jeepney.io.asyncio
 import paho.mqtt.client
 
 import systemctl_mqtt._dbus
@@ -292,6 +292,31 @@ def _mqtt_on_connect(
         )
 
 
+async def _dbus_signal_loop(
+    *, state: _State, mqtt_client: paho.mqtt.client.Client
+) -> None:
+    async with jeepney.io.asyncio.open_dbus_router(bus="SYSTEM") as router:
+        # router: jeepney.io.asyncio.DBusRouter
+        bus_proxy = jeepney.io.asyncio.Proxy(
+            msggen=jeepney.bus_messages.message_bus, router=router
+        )
+        preparing_for_shutdown_match_rule = (
+            # pylint: disable=protected-access
+            systemctl_mqtt._dbus.get_login_manager_signal_match_rule(
+                "PrepareForShutdown"
+            )
+        )
+        assert await bus_proxy.AddMatch(preparing_for_shutdown_match_rule) == ()
+        with router.filter(preparing_for_shutdown_match_rule) as queue:
+            while True:
+                message: jeepney.low_level.Message = await queue.get()
+                (preparing_for_shutdown,) = message.body
+                state.preparing_for_shutdown_handler(
+                    active=preparing_for_shutdown, mqtt_client=mqtt_client
+                )
+                queue.task_done()
+
+
 async def _run(  # pylint: disable=too-many-arguments
     *,
     mqtt_host: str,
@@ -304,16 +329,6 @@ async def _run(  # pylint: disable=too-many-arguments
     poweroff_delay: datetime.timedelta,
     mqtt_disable_tls: bool = False,
 ) -> None:
-    # pylint: disable=too-many-locals; will be split up when switching to async mqtt
-    dbus_connection = jeepney.io.blocking.open_dbus_connection(bus="SYSTEM")
-    bus_proxy = jeepney.io.blocking.Proxy(
-        msggen=jeepney.bus_messages.message_bus, connection=dbus_connection
-    )
-    preparing_for_shutdown_match_rule = (
-        # pylint: disable=protected-access
-        systemctl_mqtt._dbus.get_login_manager_signal_match_rule("PrepareForShutdown")
-    )
-    assert bus_proxy.AddMatch(preparing_for_shutdown_match_rule) == ()
     state = _State(
         mqtt_topic_prefix=mqtt_topic_prefix,
         homeassistant_discovery_prefix=homeassistant_discovery_prefix,
@@ -342,12 +357,7 @@ async def _run(  # pylint: disable=too-many-arguments
     # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1744
     mqtt_client.loop_start()
     try:
-        with dbus_connection.filter(preparing_for_shutdown_match_rule) as queue:
-            while True:
-                (preparing_for_sleep,) = dbus_connection.recv_until_filtered(queue).body
-                state.preparing_for_shutdown_handler(
-                    active=preparing_for_sleep, mqtt_client=mqtt_client
-                )
+        await _dbus_signal_loop(state=state, mqtt_client=mqtt_client)
     finally:
         # blocks until loop_forever stops
         _LOGGER.debug("waiting for MQTT loop to stop")

+ 37 - 40
tests/test_dbus.py

@@ -15,6 +15,7 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
+import asyncio
 import datetime
 import logging
 import typing
@@ -177,40 +178,44 @@ def test_lock_all_sessions(caplog):
 
 
 @pytest.mark.asyncio
-async def test__run_signal_loop():
+async def test__dbus_signal_loop():
     # pylint: disable=too-many-locals,too-many-arguments
-    login_manager_mock = unittest.mock.MagicMock()
-    dbus_connection_mock = unittest.mock.MagicMock()
+    state_mock = unittest.mock.MagicMock()
     with unittest.mock.patch(
-        "paho.mqtt.client.Client"
-    ) as mqtt_client_mock, unittest.mock.patch(
-        "systemctl_mqtt._dbus.get_login_manager_proxy", return_value=login_manager_mock
-    ), unittest.mock.patch(
-        "jeepney.io.blocking.open_dbus_connection", return_value=dbus_connection_mock
-    ) as open_dbus_connection_mock:
+        "jeepney.io.asyncio.open_dbus_router",
+    ) as open_dbus_router_mock:
+        async with open_dbus_router_mock() as dbus_router_mock:
+            pass
         add_match_reply = unittest.mock.Mock()
         add_match_reply.body = ()
-        dbus_connection_mock.send_and_get_reply.return_value = add_match_reply
-        dbus_connection_mock.recv_until_filtered.side_effect = [
-            jeepney.low_level.Message(header=None, body=(False,)),
-            jeepney.low_level.Message(header=None, body=(True,)),
-            jeepney.low_level.Message(header=None, body=(False,)),
-        ]
-        login_manager_mock.Inhibit.return_value = (jeepney.fds.FileDescriptor(-1),)
-        with pytest.raises(RuntimeError, match=r"^coroutine raised StopIteration$"):
-            await systemctl_mqtt._run(
-                mqtt_host="localhost",
-                mqtt_port=1833,
-                mqtt_username=None,
-                mqtt_password=None,
-                mqtt_topic_prefix="systemctl/host",
-                homeassistant_discovery_prefix="homeassistant",
-                homeassistant_discovery_object_id="test",
-                poweroff_delay=datetime.timedelta(),
+        dbus_router_mock.send_and_get_reply.return_value = add_match_reply
+        msg_queue = asyncio.Queue()
+        await msg_queue.put(jeepney.low_level.Message(header=None, body=(False,)))
+        await msg_queue.put(jeepney.low_level.Message(header=None, body=(True,)))
+        await msg_queue.put(jeepney.low_level.Message(header=None, body=(False,)))
+        dbus_router_mock.filter = unittest.mock.MagicMock()
+        dbus_router_mock.filter.return_value.__enter__.return_value = msg_queue
+        # asyncio.TaskGroup added in python3.11
+        loop_task = asyncio.create_task(
+            systemctl_mqtt._dbus_signal_loop(
+                state=state_mock, mqtt_client=unittest.mock.MagicMock()
             )
-    open_dbus_connection_mock.assert_called_once_with(bus="SYSTEM")
-    dbus_connection_mock.send_and_get_reply.assert_called_once()
-    add_match_msg = dbus_connection_mock.send_and_get_reply.call_args[0][0]
+        )
+
+        async def _abort_after_msg_queue():
+            await msg_queue.join()
+            loop_task.cancel()
+
+        with pytest.raises(asyncio.exceptions.CancelledError):
+            await asyncio.gather(*(loop_task, _abort_after_msg_queue()))
+    assert unittest.mock.call(bus="SYSTEM") in open_dbus_router_mock.call_args_list
+    dbus_router_mock.filter.assert_called_once()
+    (filter_match_rule,) = dbus_router_mock.filter.call_args[0]
+    assert (
+        filter_match_rule.header_fields["interface"] == "org.freedesktop.login1.Manager"
+    )
+    assert filter_match_rule.header_fields["member"] == "PrepareForShutdown"
+    add_match_msg = dbus_router_mock.send_and_get_reply.call_args[0][0]
     assert (
         add_match_msg.header.fields[jeepney.low_level.HeaderFields.member] == "AddMatch"
     )
@@ -218,14 +223,6 @@ async def test__run_signal_loop():
         "interface='org.freedesktop.login1.Manager',member='PrepareForShutdown'"
         ",path='/org/freedesktop/login1',type='signal'",
     )
-    assert mqtt_client_mock().publish.call_args_list == [
-        unittest.mock.call(
-            topic="systemctl/host/preparing-for-shutdown", payload="false", retain=True
-        ),
-        unittest.mock.call(
-            topic="systemctl/host/preparing-for-shutdown", payload="true", retain=True
-        ),
-        unittest.mock.call(
-            topic="systemctl/host/preparing-for-shutdown", payload="false", retain=True
-        ),
-    ]
+    assert [
+        c[1]["active"] for c in state_mock.preparing_for_shutdown_handler.call_args_list
+    ] == [False, True, False]

+ 30 - 39
tests/test_mqtt.py

@@ -15,12 +15,10 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
-import contextlib
 import datetime
 import logging
 import threading
 import time
-import typing
 import unittest.mock
 
 import jeepney.fds
@@ -34,16 +32,6 @@ import systemctl_mqtt
 # pylint: disable=protected-access,too-many-positional-arguments
 
 
-@contextlib.contextmanager
-def mock_open_dbus_connection() -> typing.Iterator[unittest.mock.MagicMock]:
-    with unittest.mock.patch("jeepney.io.blocking.open_dbus_connection") as mock:
-        add_match_reply = unittest.mock.Mock()
-        add_match_reply.body = ()
-        mock.return_value.send_and_get_reply.return_value = add_match_reply
-        mock.return_value.recv_until_filtered.side_effect = []
-        yield mock
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
 @pytest.mark.parametrize("mqtt_port", [1833])
@@ -69,21 +57,22 @@ async def test__run(
         "paho.mqtt.client.Client.loop_forever", autospec=True
     ) as mqtt_loop_forever_mock, unittest.mock.patch(
         "systemctl_mqtt._dbus.get_login_manager_proxy", return_value=login_manager_mock
-    ), mock_open_dbus_connection() as open_dbus_connection_mock:
+    ), unittest.mock.patch(
+        "systemctl_mqtt._dbus_signal_loop"
+    ) as dbus_signal_loop_mock:
         ssl_wrap_socket_mock.return_value.send = len
         login_manager_mock.Inhibit.return_value = (jeepney.fds.FileDescriptor(-1),)
         login_manager_mock.Get.return_value = (("b", False),)
-        with pytest.raises(RuntimeError, match=r"^coroutine raised StopIteration$"):
-            await systemctl_mqtt._run(
-                mqtt_host=mqtt_host,
-                mqtt_port=mqtt_port,
-                mqtt_username=None,
-                mqtt_password=None,
-                mqtt_topic_prefix=mqtt_topic_prefix,
-                homeassistant_discovery_prefix=homeassistant_discovery_prefix,
-                homeassistant_discovery_object_id=homeassistant_discovery_object_id,
-                poweroff_delay=datetime.timedelta(),
-            )
+        await systemctl_mqtt._run(
+            mqtt_host=mqtt_host,
+            mqtt_port=mqtt_port,
+            mqtt_username=None,
+            mqtt_password=None,
+            mqtt_topic_prefix=mqtt_topic_prefix,
+            homeassistant_discovery_prefix=homeassistant_discovery_prefix,
+            homeassistant_discovery_object_id=homeassistant_discovery_object_id,
+            poweroff_delay=datetime.timedelta(),
+        )
     assert caplog.records[0].levelno == logging.INFO
     assert caplog.records[0].message == (
         f"connecting to MQTT broker {mqtt_host}:{mqtt_port} (TLS enabled)"
@@ -113,7 +102,6 @@ async def test__run(
         "paho.mqtt.client.Client.subscribe"
     ) as mqtt_subscribe_mock:
         mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
-    open_dbus_connection_mock.assert_called_once_with(bus="SYSTEM")
     login_manager_mock.Inhibit.assert_called_once_with(
         what="shutdown",
         who="systemctl-mqtt",
@@ -121,7 +109,6 @@ async def test__run(
         mode="delay",
     )
     login_manager_mock.Get.assert_called_once_with("PreparingForShutdown")
-    open_dbus_connection_mock.return_value.send_and_get_reply.assert_called_once()
     assert sorted(mqtt_subscribe_mock.call_args_list) == [
         unittest.mock.call(mqtt_topic_prefix + "/lock-all-sessions"),
         unittest.mock.call(mqtt_topic_prefix + "/poweroff"),
@@ -166,7 +153,7 @@ async def test__run(
         f" triggering {systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[s]}"
         for s in ("poweroff", "lock-all-sessions", "suspend")
     }
-    open_dbus_connection_mock.return_value.filter.assert_called_once()
+    dbus_signal_loop_mock.assert_awaited_once()
     # waited for mqtt loop to stop?
     assert mqtt_client._thread_terminate
     assert mqtt_client._thread is None
@@ -180,9 +167,9 @@ async def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
     caplog.set_level(logging.INFO)
     with unittest.mock.patch(
         "paho.mqtt.client.Client"
-    ) as mqtt_client_class, mock_open_dbus_connection(), pytest.raises(
-        RuntimeError, match=r"^coroutine raised StopIteration$"
-    ):
+    ) as mqtt_client_class, unittest.mock.patch(
+        "systemctl_mqtt._dbus_signal_loop"
+    ) as dbus_signal_loop_mock:
         await systemctl_mqtt._run(
             mqtt_host=mqtt_host,
             mqtt_port=mqtt_port,
@@ -203,15 +190,16 @@ async def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
         mqtt_client_class().tls_set.assert_not_called()
     else:
         mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
+    dbus_signal_loop_mock.assert_awaited_once()
 
 
 @pytest.mark.asyncio
 async def test__run_tls_default():
     with unittest.mock.patch(
         "paho.mqtt.client.Client"
-    ) as mqtt_client_class, mock_open_dbus_connection(), pytest.raises(
-        RuntimeError, match=r"^coroutine raised StopIteration$"
-    ):
+    ) as mqtt_client_class, unittest.mock.patch(
+        "systemctl_mqtt._dbus_signal_loop"
+    ) as dbus_signal_loop_mock:
         await systemctl_mqtt._run(
             mqtt_host="mqtt-broker.local",
             mqtt_port=1833,
@@ -225,6 +213,7 @@ async def test__run_tls_default():
         )
     # enabled by default
     mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
+    dbus_signal_loop_mock.assert_awaited_once()
 
 
 @pytest.mark.asyncio
@@ -242,9 +231,9 @@ async def test__run_authentication(
         "paho.mqtt.client.Client.loop_forever", autospec=True
     ) as mqtt_loop_forever_mock, unittest.mock.patch(
         "systemctl_mqtt._dbus.get_login_manager_proxy"
-    ), mock_open_dbus_connection(), pytest.raises(
-        RuntimeError, match=r"^coroutine raised StopIteration$"
-    ):
+    ), unittest.mock.patch(
+        "systemctl_mqtt._dbus_signal_loop"
+    ) as dbus_signal_loop_mock:
         ssl_wrap_socket_mock.return_value.send = len
         await systemctl_mqtt._run(
             mqtt_host=mqtt_host,
@@ -263,6 +252,7 @@ async def test__run_authentication(
         assert mqtt_client._password.decode() == mqtt_password
     else:
         assert mqtt_client._password is None
+    dbus_signal_loop_mock.assert_awaited_once()
 
 
 @pytest.mark.asyncio
@@ -275,8 +265,8 @@ async def _initialize_mqtt_client(
         "paho.mqtt.client.Client.loop_forever", autospec=True
     ) as mqtt_loop_forever_mock, unittest.mock.patch(
         "systemctl_mqtt._dbus.get_login_manager_proxy"
-    ) as get_login_manager_mock, mock_open_dbus_connection(), pytest.raises(
-        RuntimeError, match=r"^coroutine raised StopIteration$"
+    ) as get_login_manager_mock, unittest.mock.patch(
+        "systemctl_mqtt._dbus_signal_loop"
     ):
         ssl_wrap_socket_mock.return_value.send = len
         get_login_manager_mock.return_value.Inhibit.return_value = (
@@ -336,7 +326,7 @@ async def test__run_authentication_missing_username(
 ):
     with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
         "systemctl_mqtt._dbus.get_login_manager_proxy"
-    ), mock_open_dbus_connection():
+    ), unittest.mock.patch("systemctl_mqtt._dbus_signal_loop") as dbus_signal_loop_mock:
         with pytest.raises(ValueError, match=r"^Missing MQTT username$"):
             await systemctl_mqtt._run(
                 mqtt_host=mqtt_host,
@@ -348,6 +338,7 @@ async def test__run_authentication_missing_username(
                 homeassistant_discovery_object_id="node-id",
                 poweroff_delay=datetime.timedelta(),
             )
+    dbus_signal_loop_mock.assert_not_called()
 
 
 @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])

+ 31 - 0
tests/test_state_dbus.py

@@ -55,6 +55,37 @@ def test_shutdown_lock():
     lock_fd.close.assert_called_once_with()
 
 
+@pytest.mark.parametrize("active", [True, False])
+def test_preparing_for_shutdown_handler(active: bool) -> None:
+    with unittest.mock.patch("systemctl_mqtt._dbus.get_login_manager_proxy"):
+        state = systemctl_mqtt._State(
+            mqtt_topic_prefix="any",
+            homeassistant_discovery_prefix="pre/fix",
+            homeassistant_discovery_object_id="obj",
+            poweroff_delay=datetime.timedelta(),
+        )
+    mqtt_client_mock = unittest.mock.MagicMock()
+    with unittest.mock.patch.object(
+        state, "_publish_preparing_for_shutdown"
+    ) as publish_mock, unittest.mock.patch.object(
+        state, "acquire_shutdown_lock"
+    ) as acquire_lock_mock, unittest.mock.patch.object(
+        state, "release_shutdown_lock"
+    ) as release_lock_mock:
+        state.preparing_for_shutdown_handler(
+            active=active, mqtt_client=mqtt_client_mock
+        )
+    publish_mock.assert_called_once_with(
+        mqtt_client=mqtt_client_mock, active=active, block=True
+    )
+    if active:
+        acquire_lock_mock.assert_not_called()
+        release_lock_mock.assert_called_once_with()
+    else:
+        acquire_lock_mock.assert_called_once_with()
+        release_lock_mock.assert_not_called()
+
+
 @pytest.mark.parametrize("active", [True, False])
 def test_publish_preparing_for_shutdown(active: bool) -> None:
     login_manager_mock = unittest.mock.MagicMock()