Browse Source

acquire shutdown inhibitor lock to reserve time to send mqtt msg before shutdown

Fabian Peter Hammerle 3 years ago
parent
commit
45fe3cfc23
3 changed files with 46 additions and 32 deletions
  1. 23 3
      systemctl_mqtt/__init__.py
  2. 23 1
      tests/test_mqtt.py
  3. 0 28
      tests/test_settings.py

+ 23 - 3
systemctl_mqtt/__init__.py

@@ -19,12 +19,15 @@ import argparse
 import datetime
 import functools
 import logging
+import os
 import pathlib
 import socket
+import threading
 import typing
 
 import dbus
 import dbus.mainloop.glib
+import dbus.types
 
 # black keeps inserting a blank line above
 # https://pygobject.readthedocs.io/en/latest/getting_started.html#ubuntu-logo-ubuntu-debian-logo-debian
@@ -110,16 +113,32 @@ def _schedule_shutdown(action: str) -> None:
 
 
 class _Settings:
-
-    # pylint: disable=too-few-public-methods
-
     def __init__(self, mqtt_topic_prefix: str) -> None:
         self._mqtt_topic_prefix = mqtt_topic_prefix
+        self._shutdown_lock = None  # type: typing.Optional[dbus.types.UnixFd]
+        self._shutdown_lock_mutex = threading.Lock()
 
     @property
     def mqtt_topic_prefix(self) -> str:
         return self._mqtt_topic_prefix
 
+    def acquire_shutdown_lock(self) -> None:
+        with self._shutdown_lock_mutex:
+            assert self._shutdown_lock is None
+            # https://www.freedesktop.org/wiki/Software/systemd/inhibit/
+            self._shutdown_lock = _get_login_manager().Inhibit(
+                "shutdown", "systemctl-mqtt", "Report shutdown via MQTT", "delay",
+            )
+            _LOGGER.debug("acquired shutdown inhibitor lock")
+
+    def release_shutdown_lock(self) -> None:
+        with self._shutdown_lock_mutex:
+            if self._shutdown_lock:
+                # https://dbus.freedesktop.org/doc/dbus-python/dbus.types.html#dbus.types.UnixFd.take
+                os.close(self._shutdown_lock.take())
+                _LOGGER.debug("released shutdown inhibitor lock")
+                self._shutdown_lock = None
+
 
 class _MQTTAction:
 
@@ -165,6 +184,7 @@ def _mqtt_on_connect(
     assert return_code == 0, return_code  # connection accepted
     mqtt_broker_host, mqtt_broker_port = mqtt_client.socket().getpeername()
     _LOGGER.debug("connected to MQTT broker %s:%d", mqtt_broker_host, mqtt_broker_port)
+    settings.acquire_shutdown_lock()
     for topic_suffix, action in _MQTT_TOPIC_SUFFIX_ACTION_MAPPING.items():
         topic = settings.mqtt_topic_prefix + "/" + topic_suffix
         _LOGGER.info("subscribing to %s", topic)

+ 23 - 1
tests/test_mqtt.py

@@ -77,8 +77,11 @@ def test__run(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
     mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
     with unittest.mock.patch(
         "paho.mqtt.client.Client.subscribe"
-    ) as mqtt_subscribe_mock:
+    ) as mqtt_subscribe_mock, unittest.mock.patch.object(
+        mqtt_client._userdata, "acquire_shutdown_lock"
+    ) as acquire_shutdown_lock_mock:
         mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
+    acquire_shutdown_lock_mock.assert_called_once_with()
     mqtt_subscribe_mock.assert_called_once_with(mqtt_topic_prefix + "/poweroff")
     assert mqtt_client.on_message is None
     assert (  # pylint: disable=comparison-with-callable
@@ -222,3 +225,22 @@ def test_mqtt_message_callback_poweroff_retained(
     )
     assert caplog.records[1].levelno == logging.INFO
     assert caplog.records[1].message == "ignoring retained message"
+
+
+def test_shutdown_lock():
+    settings = systemctl_mqtt._Settings(mqtt_topic_prefix="any")
+    lock_fd = unittest.mock.MagicMock()
+    with unittest.mock.patch(
+        "systemctl_mqtt._get_login_manager"
+    ) as get_login_manager_mock:
+        get_login_manager_mock.return_value.Inhibit.return_value = lock_fd
+        settings.acquire_shutdown_lock()
+    get_login_manager_mock.return_value.Inhibit.assert_called_once_with(
+        "shutdown", "systemctl-mqtt", "Report shutdown via MQTT", "delay",
+    )
+    assert settings._shutdown_lock == lock_fd
+    # https://dbus.freedesktop.org/doc/dbus-python/dbus.types.html#dbus.types.UnixFd.take
+    lock_fd.take.return_value = "fdnum"
+    with unittest.mock.patch("os.close") as close_mock:
+        settings.release_shutdown_lock()
+    close_mock.assert_called_once_with("fdnum")

+ 0 - 28
tests/test_settings.py

@@ -1,28 +0,0 @@
-# systemctl-mqtt - MQTT client triggering shutdown on systemd-based systems
-#
-# Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <https://www.gnu.org/licenses/>.
-
-import pytest
-
-import systemctl_mqtt
-
-# pylint: disable=protected-access
-
-
-@pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
-def test_mqtt_topic_action_mapping(mqtt_topic_prefix):
-    settings = systemctl_mqtt._Settings(mqtt_topic_prefix=mqtt_topic_prefix)
-    assert settings.mqtt_topic_prefix == mqtt_topic_prefix