1
0

test_mqtt.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. # systemctl-mqtt - MQTT client triggering & reporting shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import datetime
  18. import logging
  19. import ssl
  20. import unittest.mock
  21. import aiomqtt
  22. import jeepney.fds
  23. import jeepney.low_level
  24. import pytest
  25. import systemctl_mqtt
  26. # pylint: disable=protected-access,too-many-positional-arguments
  27. @pytest.mark.asyncio
  28. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  29. @pytest.mark.parametrize("mqtt_port", [1833])
  30. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  31. @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
  32. @pytest.mark.parametrize("homeassistant_discovery_object_id", ["host", "node"])
  33. async def test__run(
  34. caplog,
  35. mqtt_host,
  36. mqtt_port,
  37. mqtt_topic_prefix,
  38. homeassistant_discovery_prefix,
  39. homeassistant_discovery_object_id,
  40. ):
  41. # pylint: disable=too-many-locals,too-many-arguments
  42. caplog.set_level(logging.DEBUG)
  43. login_manager_mock = unittest.mock.MagicMock()
  44. with unittest.mock.patch(
  45. "aiomqtt.Client", autospec=False
  46. ) as mqtt_client_class_mock, unittest.mock.patch(
  47. "systemctl_mqtt._dbus.login_manager.get_login_manager_proxy",
  48. return_value=login_manager_mock,
  49. ), unittest.mock.patch(
  50. "systemctl_mqtt._dbus_signal_loop"
  51. ) as dbus_signal_loop_mock:
  52. login_manager_mock.Inhibit.return_value = (jeepney.fds.FileDescriptor(-1),)
  53. login_manager_mock.Get.return_value = (("b", False),)
  54. await systemctl_mqtt._run(
  55. mqtt_host=mqtt_host,
  56. mqtt_port=mqtt_port,
  57. mqtt_username=None,
  58. mqtt_password=None,
  59. mqtt_topic_prefix=mqtt_topic_prefix,
  60. homeassistant_discovery_prefix=homeassistant_discovery_prefix,
  61. homeassistant_discovery_object_id=homeassistant_discovery_object_id,
  62. poweroff_delay=datetime.timedelta(),
  63. monitored_system_unit_names=[],
  64. )
  65. assert caplog.records[0].levelno == logging.INFO
  66. assert caplog.records[0].message == (
  67. f"connecting to MQTT broker {mqtt_host}:{mqtt_port} (TLS enabled)"
  68. )
  69. mqtt_client_class_mock.assert_called_once()
  70. _, mqtt_client_init_kwargs = mqtt_client_class_mock.call_args
  71. assert mqtt_client_init_kwargs.pop("hostname") == mqtt_host
  72. assert mqtt_client_init_kwargs.pop("port") == mqtt_port
  73. assert isinstance(mqtt_client_init_kwargs.pop("tls_context"), ssl.SSLContext)
  74. assert mqtt_client_init_kwargs.pop("username") is None
  75. assert mqtt_client_init_kwargs.pop("password") is None
  76. assert mqtt_client_init_kwargs.pop("will") == aiomqtt.Will(
  77. topic=mqtt_topic_prefix + "/status",
  78. payload="offline",
  79. qos=0,
  80. retain=True,
  81. properties=None,
  82. )
  83. assert not mqtt_client_init_kwargs
  84. login_manager_mock.Inhibit.assert_called_once_with(
  85. what="shutdown",
  86. who="systemctl-mqtt",
  87. why="Report shutdown via MQTT",
  88. mode="delay",
  89. )
  90. login_manager_mock.Get.assert_called_once_with("PreparingForShutdown")
  91. async with mqtt_client_class_mock() as mqtt_client_mock:
  92. pass
  93. assert mqtt_client_mock.publish.call_count == 4
  94. assert (
  95. mqtt_client_mock.publish.call_args_list[0][1]["topic"]
  96. == f"{homeassistant_discovery_prefix}/device/{homeassistant_discovery_object_id}/config"
  97. )
  98. assert mqtt_client_mock.publish.call_args_list[1] == unittest.mock.call(
  99. topic=mqtt_topic_prefix + "/preparing-for-shutdown",
  100. payload="false",
  101. retain=False,
  102. )
  103. assert mqtt_client_mock.publish.call_args_list[2][1] == {
  104. "topic": mqtt_topic_prefix + "/status",
  105. "payload": "online",
  106. "retain": True,
  107. }
  108. assert mqtt_client_mock.publish.call_args_list[3][1] == {
  109. "topic": mqtt_topic_prefix + "/status",
  110. "payload": "offline",
  111. "retain": True,
  112. }
  113. assert sorted(mqtt_client_mock.subscribe.call_args_list) == [
  114. unittest.mock.call(mqtt_topic_prefix + "/lock-all-sessions"),
  115. unittest.mock.call(mqtt_topic_prefix + "/poweroff"),
  116. unittest.mock.call(mqtt_topic_prefix + "/suspend"),
  117. ]
  118. assert caplog.records[1].levelno == logging.DEBUG
  119. assert (
  120. caplog.records[1].message == f"connected to MQTT broker {mqtt_host}:{mqtt_port}"
  121. )
  122. assert caplog.records[2].levelno == logging.DEBUG
  123. assert caplog.records[2].message == "acquired shutdown inhibitor lock"
  124. assert caplog.records[3].levelno == logging.DEBUG
  125. assert (
  126. caplog.records[3].message
  127. == "publishing home assistant config on "
  128. + homeassistant_discovery_prefix
  129. + "/device/"
  130. + homeassistant_discovery_object_id
  131. + "/config"
  132. )
  133. assert caplog.records[4].levelno == logging.INFO
  134. assert (
  135. caplog.records[4].message
  136. == f"publishing 'false' on {mqtt_topic_prefix}/preparing-for-shutdown"
  137. )
  138. assert all(r.levelno == logging.INFO for r in caplog.records[5::2])
  139. assert {r.message for r in caplog.records[5:]} == {
  140. f"subscribing to {mqtt_topic_prefix}/{s}"
  141. for s in ("poweroff", "lock-all-sessions", "suspend")
  142. }
  143. dbus_signal_loop_mock.assert_awaited_once()
  144. @pytest.mark.asyncio
  145. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  146. @pytest.mark.parametrize("mqtt_port", [1833])
  147. @pytest.mark.parametrize("mqtt_disable_tls", [True, False])
  148. async def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
  149. caplog.set_level(logging.INFO)
  150. with unittest.mock.patch(
  151. "aiomqtt.Client"
  152. ) as mqtt_client_class_mock, unittest.mock.patch(
  153. "systemctl_mqtt._dbus_signal_loop"
  154. ) as dbus_signal_loop_mock:
  155. await systemctl_mqtt._run(
  156. mqtt_host=mqtt_host,
  157. mqtt_port=mqtt_port,
  158. mqtt_disable_tls=mqtt_disable_tls,
  159. mqtt_username=None,
  160. mqtt_password=None,
  161. mqtt_topic_prefix="systemctl/hosts",
  162. homeassistant_discovery_prefix="homeassistant",
  163. homeassistant_discovery_object_id="host",
  164. poweroff_delay=datetime.timedelta(),
  165. monitored_system_unit_names=[],
  166. )
  167. mqtt_client_class_mock.assert_called_once()
  168. _, mqtt_client_init_kwargs = mqtt_client_class_mock.call_args
  169. assert mqtt_client_init_kwargs.pop("hostname") == mqtt_host
  170. assert mqtt_client_init_kwargs.pop("port") == mqtt_port
  171. if mqtt_disable_tls:
  172. assert mqtt_client_init_kwargs.pop("tls_context") is None
  173. else:
  174. assert isinstance(mqtt_client_init_kwargs.pop("tls_context"), ssl.SSLContext)
  175. assert set(mqtt_client_init_kwargs.keys()) == {"username", "password", "will"}
  176. assert caplog.records[0].levelno == logging.INFO
  177. assert caplog.records[0].message == (
  178. f"connecting to MQTT broker {mqtt_host}:{mqtt_port}"
  179. f" (TLS {'disabled' if mqtt_disable_tls else 'enabled'})"
  180. )
  181. dbus_signal_loop_mock.assert_awaited_once()
  182. @pytest.mark.asyncio
  183. async def test__run_tls_default():
  184. with unittest.mock.patch(
  185. "aiomqtt.Client"
  186. ) as mqtt_client_class_mock, unittest.mock.patch(
  187. "systemctl_mqtt._dbus_signal_loop"
  188. ) as dbus_signal_loop_mock:
  189. await systemctl_mqtt._run(
  190. mqtt_host="mqtt-broker.local",
  191. mqtt_port=1833,
  192. # mqtt_disable_tls default,
  193. mqtt_username=None,
  194. mqtt_password=None,
  195. mqtt_topic_prefix="systemctl/hosts",
  196. homeassistant_discovery_prefix="homeassistant",
  197. homeassistant_discovery_object_id="host",
  198. poweroff_delay=datetime.timedelta(),
  199. monitored_system_unit_names=[],
  200. )
  201. mqtt_client_class_mock.assert_called_once()
  202. # enabled by default
  203. assert isinstance(
  204. mqtt_client_class_mock.call_args[1]["tls_context"], ssl.SSLContext
  205. )
  206. dbus_signal_loop_mock.assert_awaited_once()
  207. @pytest.mark.asyncio
  208. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  209. @pytest.mark.parametrize("mqtt_port", [1833])
  210. @pytest.mark.parametrize("mqtt_username", ["me"])
  211. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  212. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  213. async def test__run_authentication(
  214. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  215. ):
  216. with unittest.mock.patch(
  217. "aiomqtt.Client"
  218. ) as mqtt_client_class_mock, unittest.mock.patch(
  219. "systemctl_mqtt._dbus_signal_loop"
  220. ) as dbus_signal_loop_mock:
  221. await systemctl_mqtt._run(
  222. mqtt_host=mqtt_host,
  223. mqtt_port=mqtt_port,
  224. mqtt_username=mqtt_username,
  225. mqtt_password=mqtt_password,
  226. mqtt_topic_prefix=mqtt_topic_prefix,
  227. homeassistant_discovery_prefix="discovery-prefix",
  228. homeassistant_discovery_object_id="node-id",
  229. poweroff_delay=datetime.timedelta(),
  230. monitored_system_unit_names=[],
  231. )
  232. mqtt_client_class_mock.assert_called_once()
  233. _, mqtt_client_init_kwargs = mqtt_client_class_mock.call_args
  234. assert mqtt_client_init_kwargs["username"] == mqtt_username
  235. if mqtt_password:
  236. assert mqtt_client_init_kwargs["password"] == mqtt_password
  237. else:
  238. assert mqtt_client_init_kwargs["password"] is None
  239. dbus_signal_loop_mock.assert_awaited_once()
  240. @pytest.mark.asyncio
  241. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  242. @pytest.mark.parametrize("mqtt_port", [1833])
  243. @pytest.mark.parametrize("mqtt_password", ["secret"])
  244. async def test__run_authentication_missing_username(
  245. mqtt_host: str, mqtt_port: int, mqtt_password: str
  246. ) -> None:
  247. with unittest.mock.patch("aiomqtt.Client"), unittest.mock.patch(
  248. "systemctl_mqtt._dbus.login_manager.get_login_manager_proxy"
  249. ), unittest.mock.patch("systemctl_mqtt._dbus_signal_loop") as dbus_signal_loop_mock:
  250. with pytest.raises(ValueError, match=r"^Missing MQTT username$"):
  251. await systemctl_mqtt._run(
  252. mqtt_host=mqtt_host,
  253. mqtt_port=mqtt_port,
  254. mqtt_username=None,
  255. mqtt_password=mqtt_password,
  256. mqtt_topic_prefix="prefix",
  257. homeassistant_discovery_prefix="discovery-prefix",
  258. homeassistant_discovery_object_id="node-id",
  259. poweroff_delay=datetime.timedelta(),
  260. monitored_system_unit_names=[],
  261. )
  262. dbus_signal_loop_mock.assert_not_called()
  263. @pytest.mark.asyncio
  264. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  265. async def test__run_sigint(mqtt_topic_prefix: str):
  266. login_manager_mock = unittest.mock.MagicMock()
  267. with unittest.mock.patch(
  268. "aiomqtt.Client", autospec=False
  269. ) as mqtt_client_class_mock, unittest.mock.patch(
  270. "systemctl_mqtt._dbus.login_manager.get_login_manager_proxy",
  271. return_value=login_manager_mock,
  272. ), unittest.mock.patch(
  273. "asyncio.gather", side_effect=KeyboardInterrupt
  274. ):
  275. login_manager_mock.Inhibit.return_value = (jeepney.fds.FileDescriptor(-1),)
  276. login_manager_mock.Get.return_value = (("b", False),)
  277. with pytest.raises(KeyboardInterrupt):
  278. await systemctl_mqtt._run(
  279. mqtt_host="mqtt-broker.local",
  280. mqtt_port=1883,
  281. mqtt_username=None,
  282. mqtt_password=None,
  283. mqtt_topic_prefix=mqtt_topic_prefix,
  284. homeassistant_discovery_prefix="homeassistant",
  285. homeassistant_discovery_object_id="host",
  286. poweroff_delay=datetime.timedelta(),
  287. monitored_system_unit_names=[],
  288. )
  289. async with mqtt_client_class_mock() as mqtt_client_mock:
  290. pass
  291. assert mqtt_client_mock.publish.call_count == 4
  292. assert mqtt_client_mock.publish.call_args_list[0][1]["topic"].endswith("/config")
  293. assert mqtt_client_mock.publish.call_args_list[1][1]["topic"].endswith(
  294. "/preparing-for-shutdown"
  295. )
  296. assert mqtt_client_mock.publish.call_args_list[2][1] == {
  297. "topic": mqtt_topic_prefix + "/status",
  298. "payload": "online",
  299. "retain": True,
  300. }
  301. assert mqtt_client_mock.publish.call_args_list[3][1] == {
  302. "topic": mqtt_topic_prefix + "/status",
  303. "payload": "offline",
  304. "retain": True,
  305. }
  306. @pytest.mark.asyncio
  307. @pytest.mark.filterwarnings("ignore:coroutine '_dbus_signal_loop' was never awaited")
  308. @pytest.mark.filterwarnings("ignore:coroutine '_mqtt_message_loop' was never awaited")
  309. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  310. async def test__mqtt_message_loop_trigger_poweroff(
  311. caplog: pytest.LogCaptureFixture, mqtt_topic_prefix: str
  312. ) -> None:
  313. state = systemctl_mqtt._State(
  314. mqtt_topic_prefix=mqtt_topic_prefix,
  315. homeassistant_discovery_prefix="homeassistant",
  316. homeassistant_discovery_object_id="whatever",
  317. poweroff_delay=datetime.timedelta(seconds=21),
  318. monitored_system_unit_names=[],
  319. )
  320. mqtt_client_mock = unittest.mock.AsyncMock()
  321. mqtt_client_mock.messages.__aiter__.return_value = [
  322. aiomqtt.Message(
  323. topic=mqtt_topic_prefix + "/poweroff",
  324. payload=b"some-payload",
  325. qos=0,
  326. retain=False,
  327. mid=42 // 2,
  328. properties=None,
  329. )
  330. ]
  331. with unittest.mock.patch(
  332. "systemctl_mqtt._dbus.login_manager.schedule_shutdown"
  333. ) as schedule_shutdown_mock, caplog.at_level(logging.DEBUG):
  334. await systemctl_mqtt._mqtt_message_loop(
  335. state=state, mqtt_client=mqtt_client_mock
  336. )
  337. assert sorted(mqtt_client_mock.subscribe.await_args_list) == [
  338. unittest.mock.call(mqtt_topic_prefix + "/lock-all-sessions"),
  339. unittest.mock.call(mqtt_topic_prefix + "/poweroff"),
  340. unittest.mock.call(mqtt_topic_prefix + "/suspend"),
  341. ]
  342. schedule_shutdown_mock.assert_called_once_with(
  343. action="poweroff", delay=datetime.timedelta(seconds=21)
  344. )
  345. assert [
  346. t for t in caplog.record_tuples[2:] if not t[2].startswith("subscribing to ")
  347. ] == [
  348. (
  349. "systemctl_mqtt",
  350. logging.DEBUG,
  351. f"received message on topic '{mqtt_topic_prefix}/poweroff': b'some-payload'",
  352. ),
  353. ]
  354. @pytest.mark.asyncio
  355. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  356. async def test__mqtt_message_loop_retained(
  357. caplog: pytest.LogCaptureFixture, mqtt_topic_prefix: str
  358. ) -> None:
  359. state = systemctl_mqtt._State(
  360. mqtt_topic_prefix=mqtt_topic_prefix,
  361. homeassistant_discovery_prefix="homeassistant",
  362. homeassistant_discovery_object_id="whatever",
  363. poweroff_delay=datetime.timedelta(seconds=21),
  364. monitored_system_unit_names=[],
  365. )
  366. mqtt_client_mock = unittest.mock.AsyncMock()
  367. mqtt_client_mock.messages.__aiter__.return_value = [
  368. aiomqtt.Message(
  369. topic=mqtt_topic_prefix + "/poweroff",
  370. payload=b"some-payload",
  371. qos=0,
  372. retain=True,
  373. mid=42 // 2,
  374. properties=None,
  375. )
  376. ]
  377. with unittest.mock.patch(
  378. "systemctl_mqtt._dbus.login_manager.schedule_shutdown"
  379. ) as schedule_shutdown_mock, caplog.at_level(logging.DEBUG):
  380. await systemctl_mqtt._mqtt_message_loop(
  381. state=state, mqtt_client=mqtt_client_mock
  382. )
  383. schedule_shutdown_mock.assert_not_called()
  384. assert [
  385. t for t in caplog.record_tuples[2:] if not t[2].startswith("subscribing to ")
  386. ] == [
  387. (
  388. "systemctl_mqtt",
  389. logging.INFO,
  390. "ignoring retained message on topic 'systemctl/host/poweroff'",
  391. ),
  392. ]
  393. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "systemd/raspberrypi"])
  394. @pytest.mark.parametrize("unit_name", ["foo.service", "bar.service"])
  395. def test_state_get_system_unit_active_state_mqtt_topic(
  396. mqtt_topic_prefix: str, unit_name: str
  397. ) -> None:
  398. state = systemctl_mqtt._State(
  399. mqtt_topic_prefix=mqtt_topic_prefix,
  400. homeassistant_discovery_prefix="homeassistant",
  401. homeassistant_discovery_object_id="whatever",
  402. poweroff_delay=datetime.timedelta(seconds=21),
  403. monitored_system_unit_names=[],
  404. )
  405. assert (
  406. state.get_system_unit_active_state_mqtt_topic(unit_name=unit_name)
  407. == f"{mqtt_topic_prefix}/unit/system/{unit_name}/active-state"
  408. )