test_mqtt.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. # systemctl-mqtt - MQTT client triggering & reporting shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import contextlib
  18. import datetime
  19. import logging
  20. import threading
  21. import time
  22. import typing
  23. import unittest.mock
  24. import jeepney.fds
  25. import jeepney.low_level
  26. import paho.mqtt.client
  27. import pytest
  28. from paho.mqtt.client import MQTTMessage
  29. import systemctl_mqtt
  30. # pylint: disable=protected-access,too-many-positional-arguments
  31. @contextlib.contextmanager
  32. def mock_open_dbus_connection() -> typing.Iterator[unittest.mock.MagicMock]:
  33. with unittest.mock.patch("jeepney.io.blocking.open_dbus_connection") as mock:
  34. add_match_reply = unittest.mock.Mock()
  35. add_match_reply.body = ()
  36. mock.return_value.send_and_get_reply.return_value = add_match_reply
  37. mock.return_value.recv_until_filtered.side_effect = []
  38. yield mock
  39. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  40. @pytest.mark.parametrize("mqtt_port", [1833])
  41. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  42. @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
  43. @pytest.mark.parametrize("homeassistant_discovery_object_id", ["host", "node"])
  44. def test__run(
  45. caplog,
  46. mqtt_host,
  47. mqtt_port,
  48. mqtt_topic_prefix,
  49. homeassistant_discovery_prefix,
  50. homeassistant_discovery_object_id,
  51. ):
  52. # pylint: disable=too-many-locals,too-many-arguments
  53. caplog.set_level(logging.DEBUG)
  54. login_manager_mock = unittest.mock.MagicMock()
  55. with unittest.mock.patch(
  56. "socket.create_connection"
  57. ) as create_socket_mock, unittest.mock.patch(
  58. "ssl.SSLContext.wrap_socket", autospec=True
  59. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  60. "paho.mqtt.client.Client.loop_forever", autospec=True
  61. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  62. "systemctl_mqtt._dbus.get_login_manager_proxy", return_value=login_manager_mock
  63. ), mock_open_dbus_connection() as open_dbus_connection_mock:
  64. ssl_wrap_socket_mock.return_value.send = len
  65. login_manager_mock.Inhibit.return_value = (jeepney.fds.FileDescriptor(-1),)
  66. login_manager_mock.Get.return_value = (("b", False),)
  67. with pytest.raises(StopIteration):
  68. systemctl_mqtt._run(
  69. mqtt_host=mqtt_host,
  70. mqtt_port=mqtt_port,
  71. mqtt_username=None,
  72. mqtt_password=None,
  73. mqtt_topic_prefix=mqtt_topic_prefix,
  74. homeassistant_discovery_prefix=homeassistant_discovery_prefix,
  75. homeassistant_discovery_object_id=homeassistant_discovery_object_id,
  76. poweroff_delay=datetime.timedelta(),
  77. )
  78. assert caplog.records[0].levelno == logging.INFO
  79. assert caplog.records[0].message == (
  80. f"connecting to MQTT broker {mqtt_host}:{mqtt_port} (TLS enabled)"
  81. )
  82. # correct remote?
  83. create_socket_mock.assert_called_once()
  84. create_socket_args, _ = create_socket_mock.call_args
  85. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  86. # ssl enabled?
  87. ssl_wrap_socket_mock.assert_called_once()
  88. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  89. assert ssl_context.check_hostname is True
  90. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  91. # loop started?
  92. while threading.active_count() > 1:
  93. time.sleep(0.01)
  94. mqtt_loop_forever_mock.assert_called_once()
  95. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  96. assert mqtt_client._tls_insecure is False
  97. # credentials
  98. assert mqtt_client._username is None
  99. assert mqtt_client._password is None
  100. # connect callback
  101. caplog.clear()
  102. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  103. with unittest.mock.patch(
  104. "paho.mqtt.client.Client.subscribe"
  105. ) as mqtt_subscribe_mock:
  106. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  107. open_dbus_connection_mock.assert_called_once_with(bus="SYSTEM")
  108. login_manager_mock.Inhibit.assert_called_once_with(
  109. what="shutdown",
  110. who="systemctl-mqtt",
  111. why="Report shutdown via MQTT",
  112. mode="delay",
  113. )
  114. login_manager_mock.Get.assert_called_once_with("PreparingForShutdown")
  115. open_dbus_connection_mock.return_value.send_and_get_reply.assert_called_once()
  116. assert sorted(mqtt_subscribe_mock.call_args_list) == [
  117. unittest.mock.call(mqtt_topic_prefix + "/lock-all-sessions"),
  118. unittest.mock.call(mqtt_topic_prefix + "/poweroff"),
  119. unittest.mock.call(mqtt_topic_prefix + "/suspend"),
  120. ]
  121. assert mqtt_client.on_message is None
  122. for suffix in ("poweroff", "lock-all-sessions"):
  123. assert ( # pylint: disable=comparison-with-callable
  124. mqtt_client._on_message_filtered[mqtt_topic_prefix + "/" + suffix]
  125. == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  126. suffix
  127. ].mqtt_message_callback
  128. )
  129. assert caplog.records[0].levelno == logging.DEBUG
  130. assert (
  131. caplog.records[0].message == f"connected to MQTT broker {mqtt_host}:{mqtt_port}"
  132. )
  133. assert caplog.records[1].levelno == logging.DEBUG
  134. assert caplog.records[1].message == "acquired shutdown inhibitor lock"
  135. assert caplog.records[2].levelno == logging.INFO
  136. assert (
  137. caplog.records[2].message
  138. == f"publishing 'false' on {mqtt_topic_prefix}/preparing-for-shutdown"
  139. )
  140. assert caplog.records[3].levelno == logging.DEBUG
  141. assert (
  142. caplog.records[3].message
  143. == "publishing home assistant config on "
  144. + homeassistant_discovery_prefix
  145. + "/device/"
  146. + homeassistant_discovery_object_id
  147. + "/config"
  148. )
  149. assert all(r.levelno == logging.INFO for r in caplog.records[4::2])
  150. assert {r.message for r in caplog.records[4::2]} == {
  151. f"subscribing to {mqtt_topic_prefix}/{s}"
  152. for s in ("poweroff", "lock-all-sessions", "suspend")
  153. }
  154. assert all(r.levelno == logging.DEBUG for r in caplog.records[5::2])
  155. assert {r.message for r in caplog.records[5::2]} == {
  156. f"registered MQTT callback for topic {mqtt_topic_prefix}/{s}"
  157. f" triggering {systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[s]}"
  158. for s in ("poweroff", "lock-all-sessions", "suspend")
  159. }
  160. open_dbus_connection_mock.return_value.filter.assert_called_once()
  161. # waited for mqtt loop to stop?
  162. assert mqtt_client._thread_terminate
  163. assert mqtt_client._thread is None
  164. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  165. @pytest.mark.parametrize("mqtt_port", [1833])
  166. @pytest.mark.parametrize("mqtt_disable_tls", [True, False])
  167. def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
  168. caplog.set_level(logging.INFO)
  169. with unittest.mock.patch(
  170. "paho.mqtt.client.Client"
  171. ) as mqtt_client_class, mock_open_dbus_connection(), pytest.raises(StopIteration):
  172. systemctl_mqtt._run(
  173. mqtt_host=mqtt_host,
  174. mqtt_port=mqtt_port,
  175. mqtt_disable_tls=mqtt_disable_tls,
  176. mqtt_username=None,
  177. mqtt_password=None,
  178. mqtt_topic_prefix="systemctl/hosts",
  179. homeassistant_discovery_prefix="homeassistant",
  180. homeassistant_discovery_object_id="host",
  181. poweroff_delay=datetime.timedelta(),
  182. )
  183. assert caplog.records[0].levelno == logging.INFO
  184. assert caplog.records[0].message == (
  185. f"connecting to MQTT broker {mqtt_host}:{mqtt_port}"
  186. f" (TLS {'disabled' if mqtt_disable_tls else 'enabled'})"
  187. )
  188. if mqtt_disable_tls:
  189. mqtt_client_class().tls_set.assert_not_called()
  190. else:
  191. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  192. def test__run_tls_default():
  193. with unittest.mock.patch(
  194. "paho.mqtt.client.Client"
  195. ) as mqtt_client_class, mock_open_dbus_connection(), pytest.raises(StopIteration):
  196. systemctl_mqtt._run(
  197. mqtt_host="mqtt-broker.local",
  198. mqtt_port=1833,
  199. # mqtt_disable_tls default,
  200. mqtt_username=None,
  201. mqtt_password=None,
  202. mqtt_topic_prefix="systemctl/hosts",
  203. homeassistant_discovery_prefix="homeassistant",
  204. homeassistant_discovery_object_id="host",
  205. poweroff_delay=datetime.timedelta(),
  206. )
  207. # enabled by default
  208. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  209. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  210. @pytest.mark.parametrize("mqtt_port", [1833])
  211. @pytest.mark.parametrize("mqtt_username", ["me"])
  212. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  213. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  214. def test__run_authentication(
  215. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  216. ):
  217. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  218. "ssl.SSLContext.wrap_socket"
  219. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  220. "paho.mqtt.client.Client.loop_forever", autospec=True
  221. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  222. "systemctl_mqtt._dbus.get_login_manager_proxy"
  223. ), mock_open_dbus_connection(), pytest.raises(
  224. StopIteration
  225. ):
  226. ssl_wrap_socket_mock.return_value.send = len
  227. systemctl_mqtt._run(
  228. mqtt_host=mqtt_host,
  229. mqtt_port=mqtt_port,
  230. mqtt_username=mqtt_username,
  231. mqtt_password=mqtt_password,
  232. mqtt_topic_prefix=mqtt_topic_prefix,
  233. homeassistant_discovery_prefix="discovery-prefix",
  234. homeassistant_discovery_object_id="node-id",
  235. poweroff_delay=datetime.timedelta(),
  236. )
  237. mqtt_loop_forever_mock.assert_called_once()
  238. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  239. assert mqtt_client._username.decode() == mqtt_username
  240. if mqtt_password:
  241. assert mqtt_client._password.decode() == mqtt_password
  242. else:
  243. assert mqtt_client._password is None
  244. def _initialize_mqtt_client(
  245. mqtt_host, mqtt_port, mqtt_topic_prefix
  246. ) -> paho.mqtt.client.Client:
  247. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  248. "ssl.SSLContext.wrap_socket"
  249. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  250. "paho.mqtt.client.Client.loop_forever", autospec=True
  251. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  252. "systemctl_mqtt._dbus.get_login_manager_proxy"
  253. ) as get_login_manager_mock, mock_open_dbus_connection(), pytest.raises(
  254. StopIteration
  255. ):
  256. ssl_wrap_socket_mock.return_value.send = len
  257. get_login_manager_mock.return_value.Inhibit.return_value = (
  258. jeepney.fds.FileDescriptor(-1),
  259. )
  260. get_login_manager_mock.return_value.Get.return_value = (("b", True),)
  261. systemctl_mqtt._run(
  262. mqtt_host=mqtt_host,
  263. mqtt_port=mqtt_port,
  264. mqtt_username=None,
  265. mqtt_password=None,
  266. mqtt_topic_prefix=mqtt_topic_prefix,
  267. homeassistant_discovery_prefix="discovery-prefix",
  268. homeassistant_discovery_object_id="node-id",
  269. poweroff_delay=datetime.timedelta(),
  270. )
  271. while threading.active_count() > 1:
  272. time.sleep(0.01)
  273. mqtt_loop_forever_mock.assert_called_once()
  274. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  275. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  276. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  277. return mqtt_client
  278. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  279. @pytest.mark.parametrize("mqtt_port", [1833])
  280. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  281. def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  282. mqtt_client = _initialize_mqtt_client(
  283. mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_topic_prefix=mqtt_topic_prefix
  284. )
  285. caplog.clear()
  286. caplog.set_level(logging.DEBUG)
  287. poweroff_message = MQTTMessage(topic=mqtt_topic_prefix.encode() + b"/poweroff")
  288. with unittest.mock.patch.object(
  289. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  290. ) as poweroff_trigger_mock:
  291. mqtt_client._handle_on_message(poweroff_message)
  292. poweroff_trigger_mock.assert_called_once_with(state=mqtt_client._userdata)
  293. assert all(r.levelno == logging.DEBUG for r in caplog.records)
  294. assert (
  295. caplog.records[0].message
  296. == f"received topic={poweroff_message.topic} payload=b''"
  297. )
  298. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  299. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  300. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  301. @pytest.mark.parametrize("mqtt_port", [1833])
  302. @pytest.mark.parametrize("mqtt_password", ["secret"])
  303. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  304. with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
  305. "systemctl_mqtt._dbus.get_login_manager_proxy"
  306. ), mock_open_dbus_connection():
  307. with pytest.raises(ValueError, match=r"^Missing MQTT username$"):
  308. systemctl_mqtt._run(
  309. mqtt_host=mqtt_host,
  310. mqtt_port=mqtt_port,
  311. mqtt_username=None,
  312. mqtt_password=mqtt_password,
  313. mqtt_topic_prefix="prefix",
  314. homeassistant_discovery_prefix="discovery-prefix",
  315. homeassistant_discovery_object_id="node-id",
  316. poweroff_delay=datetime.timedelta(),
  317. )
  318. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  319. @pytest.mark.parametrize("payload", [b"", b"junk"])
  320. def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
  321. message = MQTTMessage(topic=mqtt_topic.encode())
  322. message.payload = payload
  323. with unittest.mock.patch.object(
  324. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  325. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  326. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  327. "poweroff"
  328. ].mqtt_message_callback(
  329. None, "state_dummy", message # type: ignore
  330. )
  331. trigger_mock.assert_called_once_with(state="state_dummy")
  332. assert len(caplog.records) == 3
  333. assert caplog.records[0].levelno == logging.DEBUG
  334. assert caplog.records[0].message == (
  335. f"received topic={mqtt_topic} payload={payload!r}"
  336. )
  337. assert caplog.records[1].levelno == logging.DEBUG
  338. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  339. assert caplog.records[2].levelno == logging.DEBUG
  340. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  341. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  342. @pytest.mark.parametrize("payload", [b"", b"junk"])
  343. def test_mqtt_message_callback_poweroff_retained(
  344. caplog, mqtt_topic: str, payload: bytes
  345. ):
  346. message = MQTTMessage(topic=mqtt_topic.encode())
  347. message.payload = payload
  348. message.retain = True
  349. with unittest.mock.patch.object(
  350. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  351. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  352. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  353. "poweroff"
  354. ].mqtt_message_callback(
  355. None, None, message # type: ignore
  356. )
  357. trigger_mock.assert_not_called()
  358. assert len(caplog.records) == 2
  359. assert caplog.records[0].levelno == logging.DEBUG
  360. assert caplog.records[0].message == (
  361. f"received topic={mqtt_topic} payload={payload!r}"
  362. )
  363. assert caplog.records[1].levelno == logging.INFO
  364. assert caplog.records[1].message == "ignoring retained message"
  365. @pytest.mark.parametrize("active", [True, False])
  366. @pytest.mark.parametrize("block", [True, False])
  367. def test__publish_preparing_for_shutdown_blocking(active: bool, block: bool) -> None:
  368. login_manager_mock = unittest.mock.MagicMock()
  369. login_manager_mock.Get.return_value = (("b", active),)
  370. with unittest.mock.patch(
  371. "systemctl_mqtt._dbus.get_login_manager_proxy", return_value=login_manager_mock
  372. ):
  373. state = systemctl_mqtt._State(
  374. mqtt_topic_prefix="prefix",
  375. homeassistant_discovery_prefix="prefix",
  376. homeassistant_discovery_object_id="object-id",
  377. poweroff_delay=datetime.timedelta(),
  378. )
  379. mqtt_client_mock = unittest.mock.MagicMock()
  380. state._publish_preparing_for_shutdown(
  381. mqtt_client=mqtt_client_mock, active=active, block=block
  382. )
  383. mqtt_client_mock.publish.assert_called_once_with(
  384. topic="prefix/preparing-for-shutdown",
  385. payload="true" if active else "false",
  386. retain=True,
  387. )
  388. msg_info = mqtt_client_mock.publish.return_value
  389. if block:
  390. msg_info.wait_for_publish.assert_called_once()
  391. else:
  392. msg_info.wait_for_publish.assert_not_called()