7 次代碼提交 33e3a1ca97 ... 042f8c2a6e

作者 SHA1 備註 提交日期
  Fabian Peter Hammerle 042f8c2a6e changelog: fix & reword section of v1.1.0 2 周之前
  Fabian Peter Hammerle 2ec4624508 document release v1.1.0 2 周之前
  dependabot[bot] 9aca56cebf build(deps): bump docker/setup-qemu-action from 3.2.0 to 3.3.0 (#182) 2 周之前
  dependabot[bot] 21ae2d71af build(deps): bump docker/build-push-action from 6.11.0 to 6.12.0 (#185) 2 周之前
  Hanspeter Gosteli 85d150e3e9 Implement restart for units selected via --control-system-unit parameter (#180) 2 周之前
  dependabot[bot] 7e092e0627 build(deps): bump docker/build-push-action from 6.10.0 to 6.11.0 (#183) 3 周之前
  dependabot[bot] c06cf7f42a build(deps-dev): bump pytest-asyncio from 0.25.1 to 0.25.2 (#184) 3 周之前

+ 2 - 2
.github/workflows/container-image.yml

@@ -26,7 +26,7 @@ jobs:
           type=ref,event=pr
           type=sha,format=long
           type=raw,value=latest,enable=false
-    - uses: docker/setup-qemu-action@v3.2.0
+    - uses: docker/setup-qemu-action@v3.3.0
     - uses: docker/login-action@v3
       with:
         registry: ghcr.io
@@ -44,7 +44,7 @@ jobs:
     # https://github.com/marketplace/actions/build-and-push-docker-images
     # > The commit history is not preserved.
     # https://docs.docker.com/engine/reference/commandline/build/#git-repositories
-    - uses: docker/build-push-action@v6.10.0
+    - uses: docker/build-push-action@v6.12.0
       with:
         build-args: | # git history unavailable (see above)
           SETUPTOOLS_SCM_PRETEND_VERSION=0

+ 2 - 0
.gitignore

@@ -4,3 +4,5 @@
 build/
 dist/
 tags
+__pycache__
+

+ 9 - 2
CHANGELOG.md

@@ -6,9 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.1.0] - 2025-01-19
+### Added
+- ability to restart system units using `--control_system_unit <unit_name>`
+  ([#180](https://github.com/fphammerle/systemctl-mqtt/pull/180)
+  by Hanspeter Gosteli (hanspeter.gosteli@gmail.com))
+
 ### Documentation
 - added systemd user service config for autostart
-  (https://github.com/fphammerle/systemctl-mqtt/issues/66)
+  ([#66](https://github.com/fphammerle/systemctl-mqtt/issues/66))
 
 ## [1.0.0] - 2025-01-04
 ### Added
@@ -131,7 +137,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - MQTT message on topic `systemctl/hostname/poweroff`
   schedules a poweroff via systemd's dbus interface (4 seconds delay)
 
-[Unreleased]: https://github.com/fphammerle/systemctl-mqtt/compare/v1.0.0...HEAD
+[Unreleased]: https://github.com/fphammerle/systemctl-mqtt/compare/v1.1.0...HEAD
+[1.1.0]: https://github.com/fphammerle/systemctl-mqtt/compare/v1.0.0...v1.1.0
 [1.0.0]: https://github.com/fphammerle/systemctl-mqtt/compare/v0.5.0...v1.0.0
 [0.5.0]: https://github.com/fphammerle/systemctl-mqtt/compare/v0.4.0...v0.5.0
 [0.4.0]: https://github.com/fphammerle/systemctl-mqtt/compare/v0.3.0...v0.4.0

+ 3 - 3
Pipfile.lock

@@ -308,12 +308,12 @@
         },
         "pytest-asyncio": {
             "hashes": [
-                "sha256:79be8a72384b0c917677e00daa711e07db15259f4d23203c59012bcd989d4aee",
-                "sha256:c84878849ec63ff2ca509423616e071ef9cd8cc93c053aa33b5b8fb70a990671"
+                "sha256:0d0bb693f7b99da304a0634afc0a4b19e49d5e0de2d670f38dc4bfa5727c5075",
+                "sha256:3f8ef9a98f45948ea91a0ed3dc4268b5326c0e7bce73892acc654df4262ad45f"
             ],
             "index": "pypi",
             "markers": "python_version >= '3.9'",
-            "version": "==0.25.1"
+            "version": "==0.25.2"
         },
         "pytest-cov": {
             "hashes": [

+ 11 - 2
README.md

@@ -90,12 +90,19 @@ $ mosquitto_pub -h MQTT_BROKER -t systemctl/hostname/suspend -n
 ### Monitor `ActiveState` of System Units
 
 ```
-$ mosquitto_pub --monitor-system-unit foo.service \
-    --monitor-system-unit bar.service …
+$ systemctl-mqtt --monitor-system-unit foo.service
 ```
 enables reports on topic
 `systemctl/[hostname]/unit/system/[unit_name]/active-state`.
 
+### Restarting of System Units
+
+```
+$ systemctl-mqtt  --control-system-unit <unit_name>
+```
+enables that a system unit can be restarted by a message on topic
+`systemctl/[hostname]/unit/system/[unit_name]/restart`.
+
 ## Home Assistant 🏡
 
 When [MQTT Discovery](https://www.home-assistant.io/integrations/mqtt/#mqtt-discovery)
@@ -107,6 +114,8 @@ added automatically:
 - `button.[hostname]_logind_suspend`
 - `sensor.[hostname]_unit_system_[unit_name]_active_state`
   for `--monitor-system-unit [unit_name]`
+- `button.[hostname]_unit_system_[unit_name]_restart`
+  for `--control-system-unit [unit_name]`
 
 ![homeassistant entities_over_auto_discovery](docs/homeassistant/entities-after-auto-discovery.png)
 

+ 6 - 0
docker-apparmor-profile

@@ -60,4 +60,10 @@ profile systemctl-mqtt flags=(attach_disconnected) {
        interface=org.freedesktop.DBus.Properties
        member=Get
        peer=(label=unconfined),
+  dbus (send)
+       bus=system
+       path=/org/freedesktop/systemd1
+       interface=org.freedesktop.systemd1.Manager
+       member=RestartUnit
+       peer=(label=unconfined),
 }

+ 47 - 0
systemctl_mqtt/__init__.py

@@ -64,6 +64,7 @@ class _State:
         homeassistant_discovery_object_id: str,
         poweroff_delay: datetime.timedelta,
         monitored_system_unit_names: typing.List[str],
+        controlled_system_unit_names: typing.List[str],
     ) -> None:
         self._mqtt_topic_prefix = mqtt_topic_prefix
         self._homeassistant_discovery_prefix = homeassistant_discovery_prefix
@@ -75,6 +76,7 @@ class _State:
         self._shutdown_lock_mutex = threading.Lock()
         self.poweroff_delay = poweroff_delay
         self._monitored_system_unit_names = monitored_system_unit_names
+        self._controlled_system_unit_names = controlled_system_unit_names
 
     @property
     def mqtt_topic_prefix(self) -> str:
@@ -91,10 +93,17 @@ class _State:
     def get_system_unit_active_state_mqtt_topic(self, *, unit_name: str) -> str:
         return self._mqtt_topic_prefix + "/unit/system/" + unit_name + "/active-state"
 
+    def get_system_unit_restart_mqtt_topic(self, *, unit_name: str) -> str:
+        return self._mqtt_topic_prefix + "/unit/system/" + unit_name + "/restart"
+
     @property
     def monitored_system_unit_names(self) -> typing.List[str]:
         return self._monitored_system_unit_names
 
+    @property
+    def controlled_system_unit_names(self) -> typing.List[str]:
+        return self._controlled_system_unit_names
+
     @property
     def shutdown_lock_acquired(self) -> bool:
         return self._shutdown_lock is not None
@@ -222,6 +231,16 @@ class _State:
                     unit_name=unit_name
                 ),
             }
+        for unit_name in self._controlled_system_unit_names:
+            config["components"]["unit/system/" + unit_name + "/restart"] = {  # type: ignore
+                "unique_id": f"{unique_id_prefix}-unit-system-{unit_name}-restart",
+                "object_id": f"{hostname}_unit_system_{unit_name}_restart",
+                "name": f"{unit_name} restart",
+                "platform": "button",
+                "command_topic": self.get_system_unit_restart_mqtt_topic(
+                    unit_name=unit_name
+                ),
+            }
         _LOGGER.debug("publishing home assistant config on %s", discovery_topic)
         await mqtt_client.publish(
             topic=discovery_topic, payload=json.dumps(config), retain=False
@@ -246,6 +265,15 @@ class _MQTTActionSchedulePoweroff(_MQTTAction):
         )
 
 
+class _MQTTActionRestartUnit(_MQTTAction):
+    # pylint: disable=protected-access,too-few-public-methods
+    def __init__(self, unit_name: str):
+        self._unit_name = unit_name
+
+    def trigger(self, state: _State) -> None:
+        systemctl_mqtt._dbus.service_manager.restart_unit(unit_name=self._unit_name)
+
+
 class _MQTTActionLockAllSessions(_MQTTAction):
     # pylint: disable=too-few-public-methods
     def trigger(self, state: _State) -> None:
@@ -274,6 +302,14 @@ async def _mqtt_message_loop(*, state: _State, mqtt_client: aiomqtt.Client) -> N
         _LOGGER.info("subscribing to %s", topic)
         await mqtt_client.subscribe(topic)
         action_by_topic[topic] = action
+
+    for unit_name in state.controlled_system_unit_names:
+        topic = state.mqtt_topic_prefix + "/unit/system/" + unit_name + "/restart"
+        _LOGGER.info("subscribing to %s", topic)
+        await mqtt_client.subscribe(topic)
+        action = _MQTTActionRestartUnit(unit_name=unit_name)
+        action_by_topic[topic] = action
+
     async for message in mqtt_client.messages:
         if message.retain:
             _LOGGER.info("ignoring retained message on topic %r", message.topic.value)
@@ -415,6 +451,7 @@ async def _run(  # pylint: disable=too-many-arguments
     homeassistant_discovery_object_id: str,
     poweroff_delay: datetime.timedelta,
     monitored_system_unit_names: typing.List[str],
+    controlled_system_unit_names: typing.List[str],
     mqtt_disable_tls: bool = False,
 ) -> None:
     state = _State(
@@ -423,6 +460,7 @@ async def _run(  # pylint: disable=too-many-arguments
         homeassistant_discovery_object_id=homeassistant_discovery_object_id,
         poweroff_delay=poweroff_delay,
         monitored_system_unit_names=monitored_system_unit_names,
+        controlled_system_unit_names=controlled_system_unit_names,
     )
     _LOGGER.info(
         "connecting to MQTT broker %s:%d (TLS %s)",
@@ -537,6 +575,14 @@ def _main() -> None:
         action="append",
         help="e.g. --monitor-system-unit ssh.service --monitor-system-unit custom.service",
     )
+    argparser.add_argument(
+        "--control-system-unit",
+        type=str,
+        metavar="UNIT_NAME",
+        dest="controlled_system_unit_names",
+        action="append",
+        help="e.g. --control-system-unit ansible-pull.service --control-system-unit custom.service",
+    )
     args = argparser.parse_args()
     logging.root.setLevel(_ARGUMENT_LOG_LEVEL_MAPPING[args.log_level])
     if args.mqtt_port:
@@ -578,5 +624,6 @@ def _main() -> None:
             homeassistant_discovery_object_id=args.homeassistant_discovery_object_id,
             poweroff_delay=datetime.timedelta(seconds=args.poweroff_delay_seconds),
             monitored_system_unit_names=args.monitored_system_unit_names or [],
+            controlled_system_unit_names=args.controlled_system_unit_names or [],
         )
     )

+ 35 - 1
systemctl_mqtt/_dbus/service_manager.py

@@ -15,10 +15,12 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
+import logging
 import jeepney
-
 import systemctl_mqtt._dbus
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class ServiceManager(jeepney.MessageGenerator):
     """
@@ -41,6 +43,17 @@ class ServiceManager(jeepney.MessageGenerator):
             remote_obj=self, method="GetUnit", signature="s", body=(name,)
         )
 
+    def RestartUnit(self, name: str, mode: str) -> jeepney.low_level.Message:
+        return jeepney.new_method_call(
+            remote_obj=self,
+            method="RestartUnit",
+            signature="ss",
+            body=(
+                name,
+                mode,
+            ),
+        )
+
 
 class Unit(systemctl_mqtt._dbus.Properties):  # pylint: disable=protected-access
     """
@@ -55,3 +68,24 @@ class Unit(systemctl_mqtt._dbus.Properties):  # pylint: disable=protected-access
         super().__init__(object_path=object_path, bus_name="org.freedesktop.systemd1")
 
     # pylint: disable=invalid-name
+
+
+def restart_unit(unit_name: str):
+    proxy = get_service_manager_proxy()
+    try:
+        proxy.RestartUnit(unit_name, "replace")
+        _LOGGER.debug("Restarting unit: %s", unit_name)
+    # pylint: disable=broad-exception-caught
+    except jeepney.wrappers.DBusErrorResponse as exc:
+        _LOGGER.error("Failed to restart unit: %s because %s ", unit_name, exc.name)
+
+
+def get_service_manager_proxy() -> jeepney.io.blocking.Proxy:
+    # https://jeepney.readthedocs.io/en/latest/integrate.html
+    # https://gitlab.com/takluyver/jeepney/-/blob/master/examples/aio_notify.py
+    return jeepney.io.blocking.Proxy(
+        msggen=ServiceManager(),
+        connection=jeepney.io.blocking.open_dbus_connection(
+            bus="SYSTEM",
+        ),
+    )

+ 49 - 0
tests/dbus/message-generators/test_service_manager.py

@@ -15,6 +15,7 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
+import typing
 import unittest.mock
 
 import pytest
@@ -26,6 +27,13 @@ import systemctl_mqtt
 # pylint: disable=protected-access
 
 
+class DBusErrorResponseMock(jeepney.wrappers.DBusErrorResponse):
+    # pylint: disable=missing-class-docstring,super-init-not-called
+    def __init__(self, name: str, data: typing.Any):
+        self.name = name
+        self.data = data
+
+
 @pytest.mark.asyncio
 async def test__get_unit_path() -> None:
     router_mock = unittest.mock.AsyncMock()
@@ -55,3 +63,44 @@ async def test__get_unit_path() -> None:
     }
     assert msg.body == ("ssh.service",)
     assert not send_kwargs
+
+
+def test__restart_unit_proxy():
+    mock_proxy = unittest.mock.MagicMock()
+    with unittest.mock.patch(
+        "systemctl_mqtt._dbus.service_manager.get_service_manager_proxy",
+        return_value=mock_proxy,
+    ):
+        systemctl_mqtt._dbus.service_manager.restart_unit("foo.service")
+        mock_proxy.RestartUnit.assert_called_once_with("foo.service", "replace")
+
+
+def test__restart_unit_method_call():
+    with unittest.mock.patch(
+        "jeepney.new_method_call", return_value=unittest.mock.MagicMock()
+    ) as mock_method_call:
+        service_manager = systemctl_mqtt._dbus.service_manager.ServiceManager()
+        service_manager.RestartUnit("foo.service", "replace")
+        mock_method_call.assert_called_once_with(
+            remote_obj=service_manager,
+            method="RestartUnit",
+            signature="ss",
+            body=("foo.service", "replace"),
+        )
+
+
+def test_restart_unit_with_exception():
+    mock_proxy = unittest.mock.MagicMock()
+    mock_proxy.RestartUnit.side_effect = DBusErrorResponseMock(
+        "DBus error", ("mocked",)
+    )
+    with unittest.mock.patch(
+        "systemctl_mqtt._dbus.service_manager.get_service_manager_proxy",
+        return_value=mock_proxy,
+    ), unittest.mock.patch(
+        "systemctl_mqtt._dbus.service_manager._LOGGER"
+    ) as mock_logger:
+        systemctl_mqtt._dbus.service_manager.restart_unit("example.service")
+        mock_logger.error.assert_called_once_with(
+            "Failed to restart unit: %s because %s ", "example.service", "DBus error"
+        )

+ 2 - 0
tests/test_action.py

@@ -25,6 +25,7 @@ def test_poweroff_trigger(delay):
                 homeassistant_discovery_object_id="node",
                 poweroff_delay=delay,
                 monitored_system_unit_names=[],
+                controlled_system_unit_names=[],
             )
         )
     schedule_shutdown_mock.assert_called_once_with(action="poweroff", delay=delay)
@@ -47,6 +48,7 @@ def test_mqtt_topic_suffix_action_mapping_poweroff(topic_suffix, expected_action
                 homeassistant_discovery_object_id="node",
                 poweroff_delay=datetime.timedelta(),
                 monitored_system_unit_names=[],
+                controlled_system_unit_names=[],
             )
         )
     login_manager_mock.ScheduleShutdown.assert_called_once()

+ 2 - 0
tests/test_cli.py

@@ -183,6 +183,7 @@ def test__main(
         homeassistant_discovery_object_id="systemctl-mqtt-hostname",
         poweroff_delay=datetime.timedelta(seconds=4),
         monitored_system_unit_names=[],
+        controlled_system_unit_names=[],
     )
 
 
@@ -231,6 +232,7 @@ def test__main_password_file(tmpdir, password_file_content, expected_password):
         homeassistant_discovery_object_id="systemctl-mqtt-hostname",
         poweroff_delay=datetime.timedelta(seconds=4),
         monitored_system_unit_names=[],
+        controlled_system_unit_names=[],
     )
 
 

+ 7 - 0
tests/test_dbus.py

@@ -40,6 +40,12 @@ def test_get_login_manager_proxy():
     assert login_manager.CanPowerOff() in {("yes",), ("challenge",)}
 
 
+def test_get_service_manager_proxy():
+    service_manager = systemctl_mqtt._dbus.service_manager.get_service_manager_proxy()
+    assert isinstance(service_manager, jeepney.io.blocking.Proxy)
+    assert service_manager._msggen.interface == "org.freedesktop.systemd1.Manager"
+
+
 def test__log_shutdown_inhibitors_some(caplog):
     login_manager = unittest.mock.MagicMock()
     login_manager.ListInhibitors.return_value = (
@@ -386,6 +392,7 @@ async def test__dbus_signal_loop_unit() -> None:
         homeassistant_discovery_object_id="unused",
         poweroff_delay=datetime.timedelta(),
         monitored_system_unit_names=[],
+        controlled_system_unit_names=[],
     )
     mqtt_client_mock = unittest.mock.AsyncMock()
     dbus_router_mock = unittest.mock.AsyncMock()

+ 60 - 4
tests/test_mqtt.py

@@ -32,7 +32,7 @@ import systemctl_mqtt
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
-@pytest.mark.parametrize("mqtt_port", [1833])
+@pytest.mark.parametrize("mqtt_port", [1883])
 @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
 @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
 @pytest.mark.parametrize("homeassistant_discovery_object_id", ["host", "node"])
@@ -67,6 +67,7 @@ async def test__run(
             homeassistant_discovery_object_id=homeassistant_discovery_object_id,
             poweroff_delay=datetime.timedelta(),
             monitored_system_unit_names=[],
+            controlled_system_unit_names=[],
         )
     assert caplog.records[0].levelno == logging.INFO
     assert caplog.records[0].message == (
@@ -171,6 +172,7 @@ async def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
             homeassistant_discovery_object_id="host",
             poweroff_delay=datetime.timedelta(),
             monitored_system_unit_names=[],
+            controlled_system_unit_names=[],
         )
     mqtt_client_class_mock.assert_called_once()
     _, mqtt_client_init_kwargs = mqtt_client_class_mock.call_args
@@ -198,7 +200,7 @@ async def test__run_tls_default():
     ) as dbus_signal_loop_mock:
         await systemctl_mqtt._run(
             mqtt_host="mqtt-broker.local",
-            mqtt_port=1833,
+            mqtt_port=1883,
             # mqtt_disable_tls default,
             mqtt_username=None,
             mqtt_password=None,
@@ -207,6 +209,7 @@ async def test__run_tls_default():
             homeassistant_discovery_object_id="host",
             poweroff_delay=datetime.timedelta(),
             monitored_system_unit_names=[],
+            controlled_system_unit_names=[],
         )
     mqtt_client_class_mock.assert_called_once()
     # enabled by default
@@ -218,7 +221,7 @@ async def test__run_tls_default():
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
-@pytest.mark.parametrize("mqtt_port", [1833])
+@pytest.mark.parametrize("mqtt_port", [1883])
 @pytest.mark.parametrize("mqtt_username", ["me"])
 @pytest.mark.parametrize("mqtt_password", [None, "secret"])
 @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
@@ -240,6 +243,7 @@ async def test__run_authentication(
             homeassistant_discovery_object_id="node-id",
             poweroff_delay=datetime.timedelta(),
             monitored_system_unit_names=[],
+            controlled_system_unit_names=[],
         )
     mqtt_client_class_mock.assert_called_once()
     _, mqtt_client_init_kwargs = mqtt_client_class_mock.call_args
@@ -253,7 +257,7 @@ async def test__run_authentication(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
-@pytest.mark.parametrize("mqtt_port", [1833])
+@pytest.mark.parametrize("mqtt_port", [1883])
 @pytest.mark.parametrize("mqtt_password", ["secret"])
 async def test__run_authentication_missing_username(
     mqtt_host: str, mqtt_port: int, mqtt_password: str
@@ -272,6 +276,7 @@ async def test__run_authentication_missing_username(
                 homeassistant_discovery_object_id="node-id",
                 poweroff_delay=datetime.timedelta(),
                 monitored_system_unit_names=[],
+                controlled_system_unit_names=[],
             )
     dbus_signal_loop_mock.assert_not_called()
 
@@ -301,6 +306,7 @@ async def test__run_sigint(mqtt_topic_prefix: str):
                 homeassistant_discovery_object_id="host",
                 poweroff_delay=datetime.timedelta(),
                 monitored_system_unit_names=[],
+                controlled_system_unit_names=[],
             )
     async with mqtt_client_class_mock() as mqtt_client_mock:
         pass
@@ -334,6 +340,7 @@ async def test__mqtt_message_loop_trigger_poweroff(
         homeassistant_discovery_object_id="whatever",
         poweroff_delay=datetime.timedelta(seconds=21),
         monitored_system_unit_names=[],
+        controlled_system_unit_names=[],
     )
     mqtt_client_mock = unittest.mock.AsyncMock()
     mqtt_client_mock.messages.__aiter__.return_value = [
@@ -382,6 +389,7 @@ async def test__mqtt_message_loop_retained(
         homeassistant_discovery_object_id="whatever",
         poweroff_delay=datetime.timedelta(seconds=21),
         monitored_system_unit_names=[],
+        controlled_system_unit_names=[],
     )
     mqtt_client_mock = unittest.mock.AsyncMock()
     mqtt_client_mock.messages.__aiter__.return_value = [
@@ -423,8 +431,56 @@ def test_state_get_system_unit_active_state_mqtt_topic(
         homeassistant_discovery_object_id="whatever",
         poweroff_delay=datetime.timedelta(seconds=21),
         monitored_system_unit_names=[],
+        controlled_system_unit_names=[],
     )
     assert (
         state.get_system_unit_active_state_mqtt_topic(unit_name=unit_name)
         == f"{mqtt_topic_prefix}/unit/system/{unit_name}/active-state"
     )
+
+
+@pytest.mark.asyncio
+@pytest.mark.filterwarnings("ignore:coroutine '_dbus_signal_loop' was never awaited")
+@pytest.mark.filterwarnings("ignore:coroutine '_mqtt_message_loop' was never awaited")
+@pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
+@pytest.mark.parametrize("unit_name", ["foo.service", "bar.service"])
+async def test__mqtt_message_loop_trigger_restart(
+    caplog: pytest.LogCaptureFixture, mqtt_topic_prefix: str, unit_name: str
+) -> None:
+    state = systemctl_mqtt._State(
+        mqtt_topic_prefix=mqtt_topic_prefix,
+        homeassistant_discovery_prefix="homeassistant",
+        homeassistant_discovery_object_id="whatever",
+        poweroff_delay=datetime.timedelta(seconds=21),
+        monitored_system_unit_names=[],
+        controlled_system_unit_names=[unit_name],
+    )
+    mqtt_client_mock = unittest.mock.AsyncMock()
+    topic = f"{mqtt_topic_prefix}/unit/system/{unit_name}/restart"
+    mqtt_client_mock.messages.__aiter__.return_value = [
+        aiomqtt.Message(
+            topic=topic,
+            payload=b"some-payload",
+            qos=0,
+            retain=False,
+            mid=42 // 2,
+            properties=None,
+        )
+    ]
+    with unittest.mock.patch(
+        "systemctl_mqtt._dbus.service_manager.restart_unit"
+    ) as trigger_service_restart_mock, caplog.at_level(logging.DEBUG):
+        await systemctl_mqtt._mqtt_message_loop(
+            state=state, mqtt_client=mqtt_client_mock
+        )
+    assert unittest.mock.call(topic) in mqtt_client_mock.subscribe.await_args_list
+    trigger_service_restart_mock.assert_called_once_with(unit_name=unit_name)
+    assert [
+        t for t in caplog.record_tuples[2:] if not t[2].startswith("subscribing to ")
+    ] == [
+        (
+            "systemctl_mqtt",
+            logging.DEBUG,
+            f"received message on topic '{topic}': b'some-payload'",
+        ),
+    ]

+ 26 - 1
tests/test_state_dbus.py

@@ -41,6 +41,7 @@ def test_shutdown_lock():
             homeassistant_discovery_object_id=None,
             poweroff_delay=datetime.timedelta(),
             monitored_system_unit_names=[],
+            controlled_system_unit_names=[],
         )
         get_login_manager_mock.return_value.Inhibit.return_value = (lock_fd,)
         state.acquire_shutdown_lock()
@@ -68,6 +69,7 @@ async def test_preparing_for_shutdown_handler(active: bool) -> None:
             homeassistant_discovery_object_id="obj",
             poweroff_delay=datetime.timedelta(),
             monitored_system_unit_names=[],
+            controlled_system_unit_names=[],
         )
     mqtt_client_mock = unittest.mock.MagicMock()
     with unittest.mock.patch.object(
@@ -104,6 +106,7 @@ async def test_publish_preparing_for_shutdown(active: bool) -> None:
             homeassistant_discovery_object_id="obj",
             poweroff_delay=datetime.timedelta(),
             monitored_system_unit_names=[],
+            controlled_system_unit_names=[],
         )
     assert state._login_manager == login_manager_mock
     mqtt_client_mock = unittest.mock.AsyncMock()
@@ -137,6 +140,7 @@ async def test_publish_preparing_for_shutdown_get_fail(caplog):
             homeassistant_discovery_object_id=None,
             poweroff_delay=datetime.timedelta(),
             monitored_system_unit_names=[],
+            controlled_system_unit_names=[],
         )
     mqtt_client_mock = unittest.mock.MagicMock()
     await state.publish_preparing_for_shutdown(mqtt_client=None)
@@ -155,14 +159,23 @@ async def test_publish_preparing_for_shutdown_get_fail(caplog):
 @pytest.mark.parametrize("object_id", ["raspberrypi", "debian21"])
 @pytest.mark.parametrize("hostname", ["hostname", "host-name"])
 @pytest.mark.parametrize(
-    "monitored_system_unit_names", [[], ["foo.service", "bar.service"]]
+    ("monitored_system_unit_names", "controlled_system_unit_names"),
+    [
+        ([], []),
+        (
+            ["foo.service", "bar.service"],
+            ["foo-control.service", "bar-control.service"],
+        ),
+    ],
 )
 async def test_publish_homeassistant_device_config(
+    # pylint: disable=too-many-arguments,too-many-positional-arguments
     topic_prefix: str,
     discovery_prefix: str,
     object_id: str,
     hostname: str,
     monitored_system_unit_names: typing.List[str],
+    controlled_system_unit_names: typing.List[str],
 ) -> None:
     with unittest.mock.patch("jeepney.io.blocking.open_dbus_connection"):
         state = systemctl_mqtt._State(
@@ -171,8 +184,10 @@ async def test_publish_homeassistant_device_config(
             homeassistant_discovery_object_id=object_id,
             poweroff_delay=datetime.timedelta(),
             monitored_system_unit_names=monitored_system_unit_names,
+            controlled_system_unit_names=controlled_system_unit_names,
         )
     assert state.monitored_system_unit_names == monitored_system_unit_names
+    assert state.controlled_system_unit_names == controlled_system_unit_names
     mqtt_client = unittest.mock.AsyncMock()
     with unittest.mock.patch(
         "systemctl_mqtt._utils.get_hostname", return_value=hostname
@@ -235,5 +250,15 @@ async def test_publish_homeassistant_device_config(
                 "state_topic": f"{topic_prefix}/unit/system/{n}/active-state",
             }
             for n in monitored_system_unit_names
+        }
+        | {
+            f"unit/system/{n}/restart": {
+                "unique_id": f"systemctl-mqtt-{hostname}-unit-system-{n}-restart",
+                "object_id": f"{hostname}_unit_system_{n}_restart",
+                "name": f"{n} restart",
+                "platform": "button",
+                "command_topic": f"{topic_prefix}/unit/system/{n}/restart",
+            }
+            for n in controlled_system_unit_names
         },
     }