test_mqtt.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. # systemctl-mqtt - MQTT client triggering & reporting shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import datetime
  18. import logging
  19. import threading
  20. import time
  21. import unittest.mock
  22. import jeepney.fds
  23. import jeepney.low_level
  24. import paho.mqtt.client
  25. import pytest
  26. from paho.mqtt.client import MQTTMessage
  27. import systemctl_mqtt
  28. # pylint: disable=protected-access,too-many-positional-arguments
  29. @pytest.mark.asyncio
  30. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  31. @pytest.mark.parametrize("mqtt_port", [1833])
  32. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  33. @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
  34. @pytest.mark.parametrize("homeassistant_discovery_object_id", ["host", "node"])
  35. async def test__run(
  36. caplog,
  37. mqtt_host,
  38. mqtt_port,
  39. mqtt_topic_prefix,
  40. homeassistant_discovery_prefix,
  41. homeassistant_discovery_object_id,
  42. ):
  43. # pylint: disable=too-many-locals,too-many-arguments
  44. caplog.set_level(logging.DEBUG)
  45. login_manager_mock = unittest.mock.MagicMock()
  46. with unittest.mock.patch(
  47. "socket.create_connection"
  48. ) as create_socket_mock, unittest.mock.patch(
  49. "ssl.SSLContext.wrap_socket", autospec=True
  50. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  51. "paho.mqtt.client.Client.loop_forever", autospec=True
  52. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  53. "systemctl_mqtt._dbus.get_login_manager_proxy", return_value=login_manager_mock
  54. ), unittest.mock.patch(
  55. "systemctl_mqtt._dbus_signal_loop"
  56. ) as dbus_signal_loop_mock:
  57. ssl_wrap_socket_mock.return_value.send = len
  58. login_manager_mock.Inhibit.return_value = (jeepney.fds.FileDescriptor(-1),)
  59. login_manager_mock.Get.return_value = (("b", False),)
  60. await systemctl_mqtt._run(
  61. mqtt_host=mqtt_host,
  62. mqtt_port=mqtt_port,
  63. mqtt_username=None,
  64. mqtt_password=None,
  65. mqtt_topic_prefix=mqtt_topic_prefix,
  66. homeassistant_discovery_prefix=homeassistant_discovery_prefix,
  67. homeassistant_discovery_object_id=homeassistant_discovery_object_id,
  68. poweroff_delay=datetime.timedelta(),
  69. )
  70. assert caplog.records[0].levelno == logging.INFO
  71. assert caplog.records[0].message == (
  72. f"connecting to MQTT broker {mqtt_host}:{mqtt_port} (TLS enabled)"
  73. )
  74. # correct remote?
  75. create_socket_mock.assert_called_once()
  76. create_socket_args, _ = create_socket_mock.call_args
  77. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  78. # ssl enabled?
  79. ssl_wrap_socket_mock.assert_called_once()
  80. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  81. assert ssl_context.check_hostname is True
  82. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  83. # loop started?
  84. while threading.active_count() > 1:
  85. time.sleep(0.01)
  86. mqtt_loop_forever_mock.assert_called_once()
  87. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  88. assert mqtt_client._tls_insecure is False
  89. # credentials
  90. assert mqtt_client._username is None
  91. assert mqtt_client._password is None
  92. # connect callback
  93. caplog.clear()
  94. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  95. with unittest.mock.patch(
  96. "paho.mqtt.client.Client.subscribe"
  97. ) as mqtt_subscribe_mock:
  98. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  99. login_manager_mock.Inhibit.assert_called_once_with(
  100. what="shutdown",
  101. who="systemctl-mqtt",
  102. why="Report shutdown via MQTT",
  103. mode="delay",
  104. )
  105. login_manager_mock.Get.assert_called_once_with("PreparingForShutdown")
  106. assert sorted(mqtt_subscribe_mock.call_args_list) == [
  107. unittest.mock.call(mqtt_topic_prefix + "/lock-all-sessions"),
  108. unittest.mock.call(mqtt_topic_prefix + "/poweroff"),
  109. unittest.mock.call(mqtt_topic_prefix + "/suspend"),
  110. ]
  111. assert mqtt_client.on_message is None
  112. for suffix in ("poweroff", "lock-all-sessions"):
  113. assert ( # pylint: disable=comparison-with-callable
  114. mqtt_client._on_message_filtered[mqtt_topic_prefix + "/" + suffix]
  115. == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  116. suffix
  117. ].mqtt_message_callback
  118. )
  119. assert caplog.records[0].levelno == logging.DEBUG
  120. assert (
  121. caplog.records[0].message == f"connected to MQTT broker {mqtt_host}:{mqtt_port}"
  122. )
  123. assert caplog.records[1].levelno == logging.DEBUG
  124. assert caplog.records[1].message == "acquired shutdown inhibitor lock"
  125. assert caplog.records[2].levelno == logging.INFO
  126. assert (
  127. caplog.records[2].message
  128. == f"publishing 'false' on {mqtt_topic_prefix}/preparing-for-shutdown"
  129. )
  130. assert caplog.records[3].levelno == logging.DEBUG
  131. assert (
  132. caplog.records[3].message
  133. == "publishing home assistant config on "
  134. + homeassistant_discovery_prefix
  135. + "/device/"
  136. + homeassistant_discovery_object_id
  137. + "/config"
  138. )
  139. assert all(r.levelno == logging.INFO for r in caplog.records[4::2])
  140. assert {r.message for r in caplog.records[4::2]} == {
  141. f"subscribing to {mqtt_topic_prefix}/{s}"
  142. for s in ("poweroff", "lock-all-sessions", "suspend")
  143. }
  144. assert all(r.levelno == logging.DEBUG for r in caplog.records[5::2])
  145. assert {r.message for r in caplog.records[5::2]} == {
  146. f"registered MQTT callback for topic {mqtt_topic_prefix}/{s}"
  147. f" triggering {systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[s]}"
  148. for s in ("poweroff", "lock-all-sessions", "suspend")
  149. }
  150. dbus_signal_loop_mock.assert_awaited_once()
  151. # waited for mqtt loop to stop?
  152. assert mqtt_client._thread_terminate
  153. assert mqtt_client._thread is None
  154. @pytest.mark.asyncio
  155. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  156. @pytest.mark.parametrize("mqtt_port", [1833])
  157. @pytest.mark.parametrize("mqtt_disable_tls", [True, False])
  158. async def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
  159. caplog.set_level(logging.INFO)
  160. with unittest.mock.patch(
  161. "paho.mqtt.client.Client"
  162. ) as mqtt_client_class, unittest.mock.patch(
  163. "systemctl_mqtt._dbus_signal_loop"
  164. ) as dbus_signal_loop_mock:
  165. await systemctl_mqtt._run(
  166. mqtt_host=mqtt_host,
  167. mqtt_port=mqtt_port,
  168. mqtt_disable_tls=mqtt_disable_tls,
  169. mqtt_username=None,
  170. mqtt_password=None,
  171. mqtt_topic_prefix="systemctl/hosts",
  172. homeassistant_discovery_prefix="homeassistant",
  173. homeassistant_discovery_object_id="host",
  174. poweroff_delay=datetime.timedelta(),
  175. )
  176. assert caplog.records[0].levelno == logging.INFO
  177. assert caplog.records[0].message == (
  178. f"connecting to MQTT broker {mqtt_host}:{mqtt_port}"
  179. f" (TLS {'disabled' if mqtt_disable_tls else 'enabled'})"
  180. )
  181. if mqtt_disable_tls:
  182. mqtt_client_class().tls_set.assert_not_called()
  183. else:
  184. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  185. dbus_signal_loop_mock.assert_awaited_once()
  186. @pytest.mark.asyncio
  187. async def test__run_tls_default():
  188. with unittest.mock.patch(
  189. "paho.mqtt.client.Client"
  190. ) as mqtt_client_class, unittest.mock.patch(
  191. "systemctl_mqtt._dbus_signal_loop"
  192. ) as dbus_signal_loop_mock:
  193. await systemctl_mqtt._run(
  194. mqtt_host="mqtt-broker.local",
  195. mqtt_port=1833,
  196. # mqtt_disable_tls default,
  197. mqtt_username=None,
  198. mqtt_password=None,
  199. mqtt_topic_prefix="systemctl/hosts",
  200. homeassistant_discovery_prefix="homeassistant",
  201. homeassistant_discovery_object_id="host",
  202. poweroff_delay=datetime.timedelta(),
  203. )
  204. # enabled by default
  205. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  206. dbus_signal_loop_mock.assert_awaited_once()
  207. @pytest.mark.asyncio
  208. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  209. @pytest.mark.parametrize("mqtt_port", [1833])
  210. @pytest.mark.parametrize("mqtt_username", ["me"])
  211. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  212. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  213. async def test__run_authentication(
  214. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  215. ):
  216. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  217. "ssl.SSLContext.wrap_socket"
  218. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  219. "paho.mqtt.client.Client.loop_forever", autospec=True
  220. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  221. "systemctl_mqtt._dbus.get_login_manager_proxy"
  222. ), unittest.mock.patch(
  223. "systemctl_mqtt._dbus_signal_loop"
  224. ) as dbus_signal_loop_mock:
  225. ssl_wrap_socket_mock.return_value.send = len
  226. await systemctl_mqtt._run(
  227. mqtt_host=mqtt_host,
  228. mqtt_port=mqtt_port,
  229. mqtt_username=mqtt_username,
  230. mqtt_password=mqtt_password,
  231. mqtt_topic_prefix=mqtt_topic_prefix,
  232. homeassistant_discovery_prefix="discovery-prefix",
  233. homeassistant_discovery_object_id="node-id",
  234. poweroff_delay=datetime.timedelta(),
  235. )
  236. mqtt_loop_forever_mock.assert_called_once()
  237. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  238. assert mqtt_client._username.decode() == mqtt_username
  239. if mqtt_password:
  240. assert mqtt_client._password.decode() == mqtt_password
  241. else:
  242. assert mqtt_client._password is None
  243. dbus_signal_loop_mock.assert_awaited_once()
  244. @pytest.mark.asyncio
  245. async def _initialize_mqtt_client(
  246. mqtt_host, mqtt_port, mqtt_topic_prefix
  247. ) -> paho.mqtt.client.Client:
  248. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  249. "ssl.SSLContext.wrap_socket"
  250. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  251. "paho.mqtt.client.Client.loop_forever", autospec=True
  252. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  253. "systemctl_mqtt._dbus.get_login_manager_proxy"
  254. ) as get_login_manager_mock, unittest.mock.patch(
  255. "systemctl_mqtt._dbus_signal_loop"
  256. ):
  257. ssl_wrap_socket_mock.return_value.send = len
  258. get_login_manager_mock.return_value.Inhibit.return_value = (
  259. jeepney.fds.FileDescriptor(-1),
  260. )
  261. get_login_manager_mock.return_value.Get.return_value = (("b", True),)
  262. await systemctl_mqtt._run(
  263. mqtt_host=mqtt_host,
  264. mqtt_port=mqtt_port,
  265. mqtt_username=None,
  266. mqtt_password=None,
  267. mqtt_topic_prefix=mqtt_topic_prefix,
  268. homeassistant_discovery_prefix="discovery-prefix",
  269. homeassistant_discovery_object_id="node-id",
  270. poweroff_delay=datetime.timedelta(),
  271. )
  272. while threading.active_count() > 1:
  273. time.sleep(0.01)
  274. mqtt_loop_forever_mock.assert_called_once()
  275. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  276. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  277. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  278. return mqtt_client
  279. @pytest.mark.asyncio
  280. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  281. @pytest.mark.parametrize("mqtt_port", [1833])
  282. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  283. async def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  284. mqtt_client = await _initialize_mqtt_client(
  285. mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_topic_prefix=mqtt_topic_prefix
  286. )
  287. caplog.clear()
  288. caplog.set_level(logging.DEBUG)
  289. poweroff_message = MQTTMessage(topic=mqtt_topic_prefix.encode() + b"/poweroff")
  290. with unittest.mock.patch.object(
  291. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  292. ) as poweroff_trigger_mock:
  293. mqtt_client._handle_on_message(poweroff_message)
  294. poweroff_trigger_mock.assert_called_once_with(state=mqtt_client._userdata)
  295. assert all(r.levelno == logging.DEBUG for r in caplog.records)
  296. assert (
  297. caplog.records[0].message
  298. == f"received topic={poweroff_message.topic} payload=b''"
  299. )
  300. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  301. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  302. @pytest.mark.asyncio
  303. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  304. @pytest.mark.parametrize("mqtt_port", [1833])
  305. @pytest.mark.parametrize("mqtt_password", ["secret"])
  306. async def test__run_authentication_missing_username(
  307. mqtt_host, mqtt_port, mqtt_password
  308. ):
  309. with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
  310. "systemctl_mqtt._dbus.get_login_manager_proxy"
  311. ), unittest.mock.patch("systemctl_mqtt._dbus_signal_loop") as dbus_signal_loop_mock:
  312. with pytest.raises(ValueError, match=r"^Missing MQTT username$"):
  313. await systemctl_mqtt._run(
  314. mqtt_host=mqtt_host,
  315. mqtt_port=mqtt_port,
  316. mqtt_username=None,
  317. mqtt_password=mqtt_password,
  318. mqtt_topic_prefix="prefix",
  319. homeassistant_discovery_prefix="discovery-prefix",
  320. homeassistant_discovery_object_id="node-id",
  321. poweroff_delay=datetime.timedelta(),
  322. )
  323. dbus_signal_loop_mock.assert_not_called()
  324. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  325. @pytest.mark.parametrize("payload", [b"", b"junk"])
  326. def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
  327. message = MQTTMessage(topic=mqtt_topic.encode())
  328. message.payload = payload
  329. with unittest.mock.patch.object(
  330. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  331. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  332. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  333. "poweroff"
  334. ].mqtt_message_callback(
  335. None, "state_dummy", message # type: ignore
  336. )
  337. trigger_mock.assert_called_once_with(state="state_dummy")
  338. assert len(caplog.records) == 3
  339. assert caplog.records[0].levelno == logging.DEBUG
  340. assert caplog.records[0].message == (
  341. f"received topic={mqtt_topic} payload={payload!r}"
  342. )
  343. assert caplog.records[1].levelno == logging.DEBUG
  344. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  345. assert caplog.records[2].levelno == logging.DEBUG
  346. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  347. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  348. @pytest.mark.parametrize("payload", [b"", b"junk"])
  349. def test_mqtt_message_callback_poweroff_retained(
  350. caplog, mqtt_topic: str, payload: bytes
  351. ):
  352. message = MQTTMessage(topic=mqtt_topic.encode())
  353. message.payload = payload
  354. message.retain = True
  355. with unittest.mock.patch.object(
  356. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  357. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  358. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  359. "poweroff"
  360. ].mqtt_message_callback(
  361. None, None, message # type: ignore
  362. )
  363. trigger_mock.assert_not_called()
  364. assert len(caplog.records) == 2
  365. assert caplog.records[0].levelno == logging.DEBUG
  366. assert caplog.records[0].message == (
  367. f"received topic={mqtt_topic} payload={payload!r}"
  368. )
  369. assert caplog.records[1].levelno == logging.INFO
  370. assert caplog.records[1].message == "ignoring retained message"
  371. @pytest.mark.parametrize("active", [True, False])
  372. @pytest.mark.parametrize("block", [True, False])
  373. def test__publish_preparing_for_shutdown_blocking(active: bool, block: bool) -> None:
  374. login_manager_mock = unittest.mock.MagicMock()
  375. login_manager_mock.Get.return_value = (("b", active),)
  376. with unittest.mock.patch(
  377. "systemctl_mqtt._dbus.get_login_manager_proxy", return_value=login_manager_mock
  378. ):
  379. state = systemctl_mqtt._State(
  380. mqtt_topic_prefix="prefix",
  381. homeassistant_discovery_prefix="prefix",
  382. homeassistant_discovery_object_id="object-id",
  383. poweroff_delay=datetime.timedelta(),
  384. )
  385. mqtt_client_mock = unittest.mock.MagicMock()
  386. state._publish_preparing_for_shutdown(
  387. mqtt_client=mqtt_client_mock, active=active, block=block
  388. )
  389. mqtt_client_mock.publish.assert_called_once_with(
  390. topic="prefix/preparing-for-shutdown",
  391. payload="true" if active else "false",
  392. retain=True,
  393. )
  394. msg_info = mqtt_client_mock.publish.return_value
  395. if block:
  396. msg_info.wait_for_publish.assert_called_once()
  397. else:
  398. msg_info.wait_for_publish.assert_not_called()