Browse Source

refactor: make `_run` function async (to prepare for migration from paho-mqtt to aiomqtt)

Fabian Peter Hammerle 2 months ago
parent
commit
ffac98894a
5 changed files with 61 additions and 34 deletions
  1. 1 0
      Pipfile
  2. 10 1
      Pipfile.lock
  3. 14 11
      systemctl_mqtt/__init__.py
  4. 4 3
      tests/test_dbus.py
  5. 32 19
      tests/test_mqtt.py

+ 1 - 0
Pipfile

@@ -11,6 +11,7 @@ black = "*"
 mypy = "*"
 pylint = "*"
 pytest = "*"
+pytest-asyncio = "*"
 pytest-cov = "*"
 
 # python<3.11 compatibility

+ 10 - 1
Pipfile.lock

@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "006c2a13bf3537443cf16ece65026026f7f514bbfc1f984c5ef197cd4a03f68f"
+            "sha256": "915d7f16e1258ec66edd3c4f4fd972f563e23fc0e9f8f56eabd41bcb76c56821"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -290,6 +290,15 @@
             "markers": "python_version >= '3.8'",
             "version": "==8.3.4"
         },
+        "pytest-asyncio": {
+            "hashes": [
+                "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609",
+                "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3"
+            ],
+            "index": "pypi",
+            "markers": "python_version >= '3.9'",
+            "version": "==0.25.0"
+        },
         "pytest-cov": {
             "hashes": [
                 "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35",

+ 14 - 11
systemctl_mqtt/__init__.py

@@ -17,6 +17,7 @@
 
 import abc
 import argparse
+import asyncio
 import datetime
 import functools
 import importlib.metadata
@@ -291,7 +292,7 @@ def _mqtt_on_connect(
         )
 
 
-def _run(  # pylint: disable=too-many-arguments
+async def _run(  # pylint: disable=too-many-arguments
     *,
     mqtt_host: str,
     mqtt_port: int,
@@ -439,14 +440,16 @@ def _main() -> None:
             f" {systemctl_mqtt._homeassistant.NODE_ID_ALLOWED_CHARS})"
             "\nchange --homeassistant-discovery-object-id"
         )
-    _run(
-        mqtt_host=args.mqtt_host,
-        mqtt_port=mqtt_port,
-        mqtt_disable_tls=args.mqtt_disable_tls,
-        mqtt_username=args.mqtt_username,
-        mqtt_password=mqtt_password,
-        mqtt_topic_prefix=args.mqtt_topic_prefix,
-        homeassistant_discovery_prefix=args.homeassistant_discovery_prefix,
-        homeassistant_discovery_object_id=args.homeassistant_discovery_object_id,
-        poweroff_delay=datetime.timedelta(seconds=args.poweroff_delay_seconds),
+    asyncio.run(
+        _run(
+            mqtt_host=args.mqtt_host,
+            mqtt_port=mqtt_port,
+            mqtt_disable_tls=args.mqtt_disable_tls,
+            mqtt_username=args.mqtt_username,
+            mqtt_password=mqtt_password,
+            mqtt_topic_prefix=args.mqtt_topic_prefix,
+            homeassistant_discovery_prefix=args.homeassistant_discovery_prefix,
+            homeassistant_discovery_object_id=args.homeassistant_discovery_object_id,
+            poweroff_delay=datetime.timedelta(seconds=args.poweroff_delay_seconds),
+        )
     )

+ 4 - 3
tests/test_dbus.py

@@ -176,7 +176,8 @@ def test_lock_all_sessions(caplog):
     assert caplog.records[0].message == "instruct all sessions to activate screen locks"
 
 
-def test__run_signal_loop():
+@pytest.mark.asyncio
+async def test__run_signal_loop():
     # pylint: disable=too-many-locals,too-many-arguments
     login_manager_mock = unittest.mock.MagicMock()
     dbus_connection_mock = unittest.mock.MagicMock()
@@ -196,8 +197,8 @@ def test__run_signal_loop():
             jeepney.low_level.Message(header=None, body=(False,)),
         ]
         login_manager_mock.Inhibit.return_value = (jeepney.fds.FileDescriptor(-1),)
-        with pytest.raises(StopIteration):
-            systemctl_mqtt._run(
+        with pytest.raises(RuntimeError, match=r"^coroutine raised StopIteration$"):
+            await systemctl_mqtt._run(
                 mqtt_host="localhost",
                 mqtt_port=1833,
                 mqtt_username=None,

+ 32 - 19
tests/test_mqtt.py

@@ -44,12 +44,13 @@ def mock_open_dbus_connection() -> typing.Iterator[unittest.mock.MagicMock]:
         yield mock
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
 @pytest.mark.parametrize("mqtt_port", [1833])
 @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
 @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
 @pytest.mark.parametrize("homeassistant_discovery_object_id", ["host", "node"])
-def test__run(
+async def test__run(
     caplog,
     mqtt_host,
     mqtt_port,
@@ -72,8 +73,8 @@ def test__run(
         ssl_wrap_socket_mock.return_value.send = len
         login_manager_mock.Inhibit.return_value = (jeepney.fds.FileDescriptor(-1),)
         login_manager_mock.Get.return_value = (("b", False),)
-        with pytest.raises(StopIteration):
-            systemctl_mqtt._run(
+        with pytest.raises(RuntimeError, match=r"^coroutine raised StopIteration$"):
+            await systemctl_mqtt._run(
                 mqtt_host=mqtt_host,
                 mqtt_port=mqtt_port,
                 mqtt_username=None,
@@ -171,15 +172,18 @@ def test__run(
     assert mqtt_client._thread is None
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
 @pytest.mark.parametrize("mqtt_port", [1833])
 @pytest.mark.parametrize("mqtt_disable_tls", [True, False])
-def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
+async def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
     caplog.set_level(logging.INFO)
     with unittest.mock.patch(
         "paho.mqtt.client.Client"
-    ) as mqtt_client_class, mock_open_dbus_connection(), pytest.raises(StopIteration):
-        systemctl_mqtt._run(
+    ) as mqtt_client_class, mock_open_dbus_connection(), pytest.raises(
+        RuntimeError, match=r"^coroutine raised StopIteration$"
+    ):
+        await systemctl_mqtt._run(
             mqtt_host=mqtt_host,
             mqtt_port=mqtt_port,
             mqtt_disable_tls=mqtt_disable_tls,
@@ -201,11 +205,14 @@ def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
         mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
 
 
-def test__run_tls_default():
+@pytest.mark.asyncio
+async def test__run_tls_default():
     with unittest.mock.patch(
         "paho.mqtt.client.Client"
-    ) as mqtt_client_class, mock_open_dbus_connection(), pytest.raises(StopIteration):
-        systemctl_mqtt._run(
+    ) as mqtt_client_class, mock_open_dbus_connection(), pytest.raises(
+        RuntimeError, match=r"^coroutine raised StopIteration$"
+    ):
+        await systemctl_mqtt._run(
             mqtt_host="mqtt-broker.local",
             mqtt_port=1833,
             # mqtt_disable_tls default,
@@ -220,12 +227,13 @@ def test__run_tls_default():
     mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
 @pytest.mark.parametrize("mqtt_port", [1833])
 @pytest.mark.parametrize("mqtt_username", ["me"])
 @pytest.mark.parametrize("mqtt_password", [None, "secret"])
 @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
-def test__run_authentication(
+async def test__run_authentication(
     mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
 ):
     with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
@@ -235,10 +243,10 @@ def test__run_authentication(
     ) as mqtt_loop_forever_mock, unittest.mock.patch(
         "systemctl_mqtt._dbus.get_login_manager_proxy"
     ), mock_open_dbus_connection(), pytest.raises(
-        StopIteration
+        RuntimeError, match=r"^coroutine raised StopIteration$"
     ):
         ssl_wrap_socket_mock.return_value.send = len
-        systemctl_mqtt._run(
+        await systemctl_mqtt._run(
             mqtt_host=mqtt_host,
             mqtt_port=mqtt_port,
             mqtt_username=mqtt_username,
@@ -257,7 +265,8 @@ def test__run_authentication(
         assert mqtt_client._password is None
 
 
-def _initialize_mqtt_client(
+@pytest.mark.asyncio
+async def _initialize_mqtt_client(
     mqtt_host, mqtt_port, mqtt_topic_prefix
 ) -> paho.mqtt.client.Client:
     with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
@@ -267,14 +276,14 @@ def _initialize_mqtt_client(
     ) as mqtt_loop_forever_mock, unittest.mock.patch(
         "systemctl_mqtt._dbus.get_login_manager_proxy"
     ) as get_login_manager_mock, mock_open_dbus_connection(), pytest.raises(
-        StopIteration
+        RuntimeError, match=r"^coroutine raised StopIteration$"
     ):
         ssl_wrap_socket_mock.return_value.send = len
         get_login_manager_mock.return_value.Inhibit.return_value = (
             jeepney.fds.FileDescriptor(-1),
         )
         get_login_manager_mock.return_value.Get.return_value = (("b", True),)
-        systemctl_mqtt._run(
+        await systemctl_mqtt._run(
             mqtt_host=mqtt_host,
             mqtt_port=mqtt_port,
             mqtt_username=None,
@@ -293,11 +302,12 @@ def _initialize_mqtt_client(
     return mqtt_client
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
 @pytest.mark.parametrize("mqtt_port", [1833])
 @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
-def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
-    mqtt_client = _initialize_mqtt_client(
+async def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
+    mqtt_client = await _initialize_mqtt_client(
         mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_topic_prefix=mqtt_topic_prefix
     )
     caplog.clear()
@@ -317,15 +327,18 @@ def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix)
     assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
 @pytest.mark.parametrize("mqtt_port", [1833])
 @pytest.mark.parametrize("mqtt_password", ["secret"])
-def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
+async def test__run_authentication_missing_username(
+    mqtt_host, mqtt_port, mqtt_password
+):
     with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
         "systemctl_mqtt._dbus.get_login_manager_proxy"
     ), mock_open_dbus_connection():
         with pytest.raises(ValueError, match=r"^Missing MQTT username$"):
-            systemctl_mqtt._run(
+            await systemctl_mqtt._run(
                 mqtt_host=mqtt_host,
                 mqtt_port=mqtt_port,
                 mqtt_username=None,