Browse Source

Allow for unit start and stop (#223)

https://github.com/fphammerle/systemctl-mqtt/pull/223
Florian Eitel 1 month ago
parent
commit
4942421de2

+ 2 - 0
CHANGELOG.md

@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 ## Added
 - support jeepney v0.9
+- ability to start/stop system units using `--control_system_unit <unit_name>`
+  ([#223](https://github.com/fphammerle/systemctl-mqtt/pull/223)
 
 ## [1.1.0] - 2025-01-19
 ### Added

+ 57 - 5
systemctl_mqtt/__init__.py

@@ -93,6 +93,12 @@ class _State:
     def get_system_unit_active_state_mqtt_topic(self, *, unit_name: str) -> str:
         return self._mqtt_topic_prefix + "/unit/system/" + unit_name + "/active-state"
 
+    def get_system_unit_start_mqtt_topic(self, *, unit_name: str) -> str:
+        return self._mqtt_topic_prefix + "/unit/system/" + unit_name + "/start"
+
+    def get_system_unit_stop_mqtt_topic(self, *, unit_name: str) -> str:
+        return self._mqtt_topic_prefix + "/unit/system/" + unit_name + "/stop"
+
     def get_system_unit_restart_mqtt_topic(self, *, unit_name: str) -> str:
         return self._mqtt_topic_prefix + "/unit/system/" + unit_name + "/restart"
 
@@ -232,6 +238,24 @@ class _State:
                 ),
             }
         for unit_name in self._controlled_system_unit_names:
+            config["components"]["unit/system/" + unit_name + "/start"] = {  # type: ignore
+                "unique_id": f"{unique_id_prefix}-unit-system-{unit_name}-start",
+                "object_id": f"{hostname}_unit_system_{unit_name}_start",
+                "name": f"{unit_name} start",
+                "platform": "button",
+                "command_topic": self.get_system_unit_start_mqtt_topic(
+                    unit_name=unit_name
+                ),
+            }
+            config["components"]["unit/system/" + unit_name + "/stop"] = {  # type: ignore
+                "unique_id": f"{unique_id_prefix}-unit-system-{unit_name}-stop",
+                "object_id": f"{hostname}_unit_system_{unit_name}_stop",
+                "name": f"{unit_name} stop",
+                "platform": "button",
+                "command_topic": self.get_system_unit_stop_mqtt_topic(
+                    unit_name=unit_name
+                ),
+            }
             config["components"]["unit/system/" + unit_name + "/restart"] = {  # type: ignore
                 "unique_id": f"{unique_id_prefix}-unit-system-{unit_name}-restart",
                 "object_id": f"{hostname}_unit_system_{unit_name}_restart",
@@ -265,6 +289,24 @@ class _MQTTActionSchedulePoweroff(_MQTTAction):
         )
 
 
+class _MQTTActionStartUnit(_MQTTAction):
+    # pylint: disable=protected-access,too-few-public-methods
+    def __init__(self, unit_name: str):
+        self._unit_name = unit_name
+
+    def trigger(self, state: _State) -> None:
+        systemctl_mqtt._dbus.service_manager.start_unit(unit_name=self._unit_name)
+
+
+class _MQTTActionStopUnit(_MQTTAction):
+    # pylint: disable=protected-access,too-few-public-methods
+    def __init__(self, unit_name: str):
+        self._unit_name = unit_name
+
+    def trigger(self, state: _State) -> None:
+        systemctl_mqtt._dbus.service_manager.stop_unit(unit_name=self._unit_name)
+
+
 class _MQTTActionRestartUnit(_MQTTAction):
     # pylint: disable=protected-access,too-few-public-methods
     def __init__(self, unit_name: str):
@@ -304,11 +346,21 @@ async def _mqtt_message_loop(*, state: _State, mqtt_client: aiomqtt.Client) -> N
         action_by_topic[topic] = action
 
     for unit_name in state.controlled_system_unit_names:
-        topic = state.mqtt_topic_prefix + "/unit/system/" + unit_name + "/restart"
-        _LOGGER.info("subscribing to %s", topic)
-        await mqtt_client.subscribe(topic)
-        action = _MQTTActionRestartUnit(unit_name=unit_name)
-        action_by_topic[topic] = action
+        for topic_suffix, action_class in [
+            ("start", _MQTTActionStartUnit),
+            ("stop", _MQTTActionStopUnit),
+            ("restart", _MQTTActionRestartUnit),
+        ]:
+            topic = (
+                state.mqtt_topic_prefix
+                + "/unit/system/"
+                + unit_name
+                + "/"
+                + topic_suffix
+            )
+            _LOGGER.info("subscribing to %s", topic)
+            await mqtt_client.subscribe(topic)
+            action_by_topic[topic] = action_class(unit_name=unit_name)
 
     async for message in mqtt_client.messages:
         if message.retain:

+ 42 - 0
systemctl_mqtt/_dbus/service_manager.py

@@ -43,6 +43,28 @@ class ServiceManager(jeepney.MessageGenerator):
             remote_obj=self, method="GetUnit", signature="s", body=(name,)
         )
 
+    def StartUnit(self, name: str, mode: str) -> jeepney.low_level.Message:
+        return jeepney.new_method_call(
+            remote_obj=self,
+            method="StartUnit",
+            signature="ss",
+            body=(
+                name,
+                mode,
+            ),
+        )
+
+    def StopUnit(self, name: str, mode: str) -> jeepney.low_level.Message:
+        return jeepney.new_method_call(
+            remote_obj=self,
+            method="StopUnit",
+            signature="ss",
+            body=(
+                name,
+                mode,
+            ),
+        )
+
     def RestartUnit(self, name: str, mode: str) -> jeepney.low_level.Message:
         return jeepney.new_method_call(
             remote_obj=self,
@@ -70,6 +92,26 @@ class Unit(systemctl_mqtt._dbus.Properties):  # pylint: disable=protected-access
     # pylint: disable=invalid-name
 
 
+def start_unit(unit_name: str):
+    proxy = get_service_manager_proxy()
+    try:
+        proxy.StartUnit(unit_name, "replace")
+        _LOGGER.debug("Starting unit: %s", unit_name)
+    # pylint: disable=broad-exception-caught
+    except jeepney.wrappers.DBusErrorResponse as exc:
+        _LOGGER.error("Failed to start unit: %s because %s ", unit_name, exc.name)
+
+
+def stop_unit(unit_name: str):
+    proxy = get_service_manager_proxy()
+    try:
+        proxy.StopUnit(unit_name, "replace")
+        _LOGGER.debug("Stopping unit: %s", unit_name)
+    # pylint: disable=broad-exception-caught
+    except jeepney.wrappers.DBusErrorResponse as exc:
+        _LOGGER.error("Failed to stop unit: %s because %s ", unit_name, exc.name)
+
+
 def restart_unit(unit_name: str):
     proxy = get_service_manager_proxy()
     try:

+ 42 - 12
tests/dbus/message-generators/test_service_manager.py

@@ -65,42 +65,72 @@ async def test__get_unit_path() -> None:
     assert not send_kwargs
 
 
-def test__restart_unit_proxy():
+@pytest.mark.parametrize(
+    "action,method",
+    [
+        ("start", "StartUnit"),
+        ("stop", "StopUnit"),
+        ("restart", "RestartUnit"),
+    ],
+)
+def test__unit_proxy(action, method):
     mock_proxy = unittest.mock.MagicMock()
     with unittest.mock.patch(
         "systemctl_mqtt._dbus.service_manager.get_service_manager_proxy",
         return_value=mock_proxy,
     ):
-        systemctl_mqtt._dbus.service_manager.restart_unit("foo.service")
-        mock_proxy.RestartUnit.assert_called_once_with("foo.service", "replace")
+        # call the wrapper function dynamically
+        getattr(systemctl_mqtt._dbus.service_manager, f"{action}_unit")("foo.service")
+        getattr(mock_proxy, method).assert_called_once_with("foo.service", "replace")
 
 
-def test__restart_unit_method_call():
+@pytest.mark.parametrize(
+    "method",
+    [
+        "StartUnit",
+        "StopUnit",
+        "RestartUnit",
+    ],
+)
+def test__unit_method_call(method):
     with unittest.mock.patch(
         "jeepney.new_method_call", return_value=unittest.mock.MagicMock()
     ) as mock_method_call:
-        service_manager = systemctl_mqtt._dbus.service_manager.ServiceManager()
-        service_manager.RestartUnit("foo.service", "replace")
+        mgr = systemctl_mqtt._dbus.service_manager.ServiceManager()
+        getattr(mgr, method)("foo.service", "replace")
         mock_method_call.assert_called_once_with(
-            remote_obj=service_manager,
-            method="RestartUnit",
+            remote_obj=mgr,
+            method=method,
             signature="ss",
             body=("foo.service", "replace"),
         )
 
 
-def test_restart_unit_with_exception():
+@pytest.mark.parametrize(
+    "action,method",
+    [
+        ("start", "StartUnit"),
+        ("stop", "StopUnit"),
+        ("restart", "RestartUnit"),
+    ],
+)
+def test__unit_with_exception(action, method):
     mock_proxy = unittest.mock.MagicMock()
-    mock_proxy.RestartUnit.side_effect = DBusErrorResponseMock(
+    getattr(mock_proxy, method).side_effect = DBusErrorResponseMock(
         "DBus error", ("mocked",)
     )
+
     with unittest.mock.patch(
         "systemctl_mqtt._dbus.service_manager.get_service_manager_proxy",
         return_value=mock_proxy,
     ), unittest.mock.patch(
         "systemctl_mqtt._dbus.service_manager._LOGGER"
     ) as mock_logger:
-        systemctl_mqtt._dbus.service_manager.restart_unit("example.service")
+        getattr(systemctl_mqtt._dbus.service_manager, f"{action}_unit")(
+            "example.service"
+        )
         mock_logger.error.assert_called_once_with(
-            "Failed to restart unit: %s because %s ", "example.service", "DBus error"
+            f"Failed to {action} unit: %s because %s ",
+            "example.service",
+            "DBus error",
         )

+ 18 - 6
tests/test_mqtt.py

@@ -444,8 +444,12 @@ def test_state_get_system_unit_active_state_mqtt_topic(
 @pytest.mark.filterwarnings("ignore:coroutine '_mqtt_message_loop' was never awaited")
 @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
 @pytest.mark.parametrize("unit_name", ["foo.service", "bar.service"])
-async def test__mqtt_message_loop_trigger_restart(
-    caplog: pytest.LogCaptureFixture, mqtt_topic_prefix: str, unit_name: str
+@pytest.mark.parametrize("action", ["restart", "start", "stop"])
+async def test__mqtt_message_loop_triggers_unit_action(
+    caplog: pytest.LogCaptureFixture,
+    mqtt_topic_prefix: str,
+    unit_name: str,
+    action: str,
 ) -> None:
     state = systemctl_mqtt._State(
         mqtt_topic_prefix=mqtt_topic_prefix,
@@ -455,8 +459,9 @@ async def test__mqtt_message_loop_trigger_restart(
         monitored_system_unit_names=[],
         controlled_system_unit_names=[unit_name],
     )
+
     mqtt_client_mock = unittest.mock.AsyncMock()
-    topic = f"{mqtt_topic_prefix}/unit/system/{unit_name}/restart"
+    topic = f"{mqtt_topic_prefix}/unit/system/{unit_name}/{action}"
     mqtt_client_mock.messages.__aiter__.return_value = [
         aiomqtt.Message(
             topic=topic,
@@ -467,14 +472,21 @@ async def test__mqtt_message_loop_trigger_restart(
             properties=None,
         )
     ]
+
     with unittest.mock.patch(
-        "systemctl_mqtt._dbus.service_manager.restart_unit"
-    ) as trigger_service_restart_mock, caplog.at_level(logging.DEBUG):
+        f"systemctl_mqtt._dbus.service_manager.{action}_unit"
+    ) as trigger_service_mock, caplog.at_level(logging.DEBUG):
         await systemctl_mqtt._mqtt_message_loop(
             state=state, mqtt_client=mqtt_client_mock
         )
+
+    # check subscription
     assert unittest.mock.call(topic) in mqtt_client_mock.subscribe.await_args_list
-    trigger_service_restart_mock.assert_called_once_with(unit_name=unit_name)
+
+    # check correct action method called
+    trigger_service_mock.assert_called_once_with(unit_name=unit_name)
+
+    # check logs (skip "subscribing to ..." chatter)
     assert [
         t for t in caplog.record_tuples[2:] if not t[2].startswith("subscribing to ")
     ] == [

+ 6 - 5
tests/test_state_dbus.py

@@ -252,13 +252,14 @@ async def test_publish_homeassistant_device_config(
             for n in monitored_system_unit_names
         }
         | {
-            f"unit/system/{n}/restart": {
-                "unique_id": f"systemctl-mqtt-{hostname}-unit-system-{n}-restart",
-                "object_id": f"{hostname}_unit_system_{n}_restart",
-                "name": f"{n} restart",
+            f"unit/system/{n}/{action}": {
+                "unique_id": f"systemctl-mqtt-{hostname}-unit-system-{n}-{action}",
+                "object_id": f"{hostname}_unit_system_{n}_{action}",
+                "name": f"{n} {action}",
                 "platform": "button",
-                "command_topic": f"{topic_prefix}/unit/system/{n}/restart",
+                "command_topic": f"{topic_prefix}/unit/system/{n}/{action}",
             }
             for n in controlled_system_unit_names
+            for action in ["restart", "start", "stop"]
         },
     }