test_mqtt.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. # systemctl-mqtt - MQTT client triggering & reporting shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import logging
  18. import threading
  19. import time
  20. import unittest.mock
  21. import dbus
  22. import paho.mqtt.client
  23. import pytest
  24. from paho.mqtt.client import MQTTMessage
  25. import systemctl_mqtt
  26. # pylint: disable=protected-access
  27. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  28. @pytest.mark.parametrize("mqtt_port", [1833])
  29. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  30. @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
  31. @pytest.mark.parametrize("homeassistant_node_id", ["host", "node"])
  32. def test__run(
  33. caplog,
  34. mqtt_host,
  35. mqtt_port,
  36. mqtt_topic_prefix,
  37. homeassistant_discovery_prefix,
  38. homeassistant_node_id,
  39. ):
  40. # pylint: disable=too-many-locals,too-many-arguments
  41. caplog.set_level(logging.DEBUG)
  42. with unittest.mock.patch(
  43. "socket.create_connection"
  44. ) as create_socket_mock, unittest.mock.patch(
  45. "ssl.SSLContext.wrap_socket", autospec=True,
  46. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  47. "paho.mqtt.client.Client.loop_forever", autospec=True,
  48. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  49. "gi.repository.GLib.MainLoop.run"
  50. ) as glib_loop_mock, unittest.mock.patch(
  51. "systemctl_mqtt._dbus.get_login_manager"
  52. ) as get_login_manager_mock:
  53. ssl_wrap_socket_mock.return_value.send = len
  54. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  55. systemctl_mqtt._run(
  56. mqtt_host=mqtt_host,
  57. mqtt_port=mqtt_port,
  58. mqtt_username=None,
  59. mqtt_password=None,
  60. mqtt_topic_prefix=mqtt_topic_prefix,
  61. homeassistant_discovery_prefix=homeassistant_discovery_prefix,
  62. homeassistant_node_id=homeassistant_node_id,
  63. )
  64. assert caplog.records[0].levelno == logging.INFO
  65. assert caplog.records[0].message == (
  66. "connecting to MQTT broker {}:{} (TLS enabled)".format(mqtt_host, mqtt_port)
  67. )
  68. # correct remote?
  69. assert create_socket_mock.call_count == 1
  70. create_socket_args, _ = create_socket_mock.call_args
  71. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  72. # ssl enabled?
  73. assert ssl_wrap_socket_mock.call_count == 1
  74. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  75. assert ssl_context.check_hostname is True
  76. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  77. # loop started?
  78. while threading.active_count() > 1:
  79. time.sleep(0.01)
  80. assert mqtt_loop_forever_mock.call_count == 1
  81. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  82. assert mqtt_client._tls_insecure is False
  83. # credentials
  84. assert mqtt_client._username is None
  85. assert mqtt_client._password is None
  86. # connect callback
  87. caplog.clear()
  88. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  89. with unittest.mock.patch(
  90. "paho.mqtt.client.Client.subscribe"
  91. ) as mqtt_subscribe_mock:
  92. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  93. state = mqtt_client._userdata
  94. assert (
  95. state._login_manager.connect_to_signal.call_args[1]["signal_name"]
  96. == "PrepareForShutdown"
  97. )
  98. mqtt_subscribe_mock.assert_called_once_with(mqtt_topic_prefix + "/poweroff")
  99. assert mqtt_client.on_message is None
  100. assert ( # pylint: disable=comparison-with-callable
  101. mqtt_client._on_message_filtered[mqtt_topic_prefix + "/poweroff"]
  102. == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  103. "poweroff"
  104. ].mqtt_message_callback
  105. )
  106. assert caplog.records[0].levelno == logging.DEBUG
  107. assert caplog.records[0].message == "connected to MQTT broker {}:{}".format(
  108. mqtt_host, mqtt_port
  109. )
  110. assert caplog.records[1].levelno == logging.DEBUG
  111. assert caplog.records[1].message == "acquired shutdown inhibitor lock"
  112. assert caplog.records[2].levelno == logging.INFO
  113. assert caplog.records[2].message == "publishing 'false' on {}".format(
  114. mqtt_topic_prefix + "/preparing-for-shutdown"
  115. )
  116. assert caplog.records[3].levelno == logging.DEBUG
  117. assert (
  118. caplog.records[3].message
  119. == "publishing home assistant config on "
  120. + homeassistant_discovery_prefix
  121. + "/binary_sensor/"
  122. + homeassistant_node_id
  123. + "/preparing-for-shutdown/config"
  124. )
  125. assert caplog.records[4].levelno == logging.INFO
  126. assert caplog.records[4].message == "subscribing to {}".format(
  127. mqtt_topic_prefix + "/poweroff"
  128. )
  129. assert caplog.records[5].levelno == logging.DEBUG
  130. assert caplog.records[5].message == "registered MQTT callback for topic {}".format(
  131. mqtt_topic_prefix + "/poweroff"
  132. ) + " triggering {}".format(
  133. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"].action
  134. )
  135. # dbus loop started?
  136. glib_loop_mock.assert_called_once_with()
  137. # waited for mqtt loop to stop?
  138. assert mqtt_client._thread_terminate
  139. assert mqtt_client._thread is None
  140. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  141. @pytest.mark.parametrize("mqtt_port", [1833])
  142. @pytest.mark.parametrize("mqtt_disable_tls", [True, False])
  143. def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
  144. caplog.set_level(logging.INFO)
  145. with unittest.mock.patch(
  146. "paho.mqtt.client.Client"
  147. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  148. systemctl_mqtt._run(
  149. mqtt_host=mqtt_host,
  150. mqtt_port=mqtt_port,
  151. mqtt_disable_tls=mqtt_disable_tls,
  152. mqtt_username=None,
  153. mqtt_password=None,
  154. mqtt_topic_prefix="systemctl/hosts",
  155. homeassistant_discovery_prefix="homeassistant",
  156. homeassistant_node_id="host",
  157. )
  158. assert caplog.records[0].levelno == logging.INFO
  159. assert caplog.records[0].message == (
  160. "connecting to MQTT broker {}:{} (TLS {})".format(
  161. mqtt_host, mqtt_port, "disabled" if mqtt_disable_tls else "enabled"
  162. )
  163. )
  164. if mqtt_disable_tls:
  165. mqtt_client_class().tls_set.assert_not_called()
  166. else:
  167. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  168. def test__run_tls_default():
  169. with unittest.mock.patch(
  170. "paho.mqtt.client.Client"
  171. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  172. systemctl_mqtt._run(
  173. mqtt_host="mqtt-broker.local",
  174. mqtt_port=1833,
  175. # mqtt_disable_tls default,
  176. mqtt_username=None,
  177. mqtt_password=None,
  178. mqtt_topic_prefix="systemctl/hosts",
  179. homeassistant_discovery_prefix="homeassistant",
  180. homeassistant_node_id="host",
  181. )
  182. # enabled by default
  183. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  184. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  185. @pytest.mark.parametrize("mqtt_port", [1833])
  186. @pytest.mark.parametrize("mqtt_username", ["me"])
  187. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  188. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  189. def test__run_authentication(
  190. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  191. ):
  192. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  193. "ssl.SSLContext.wrap_socket"
  194. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  195. "paho.mqtt.client.Client.loop_forever", autospec=True,
  196. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  197. "gi.repository.GLib.MainLoop.run"
  198. ), unittest.mock.patch(
  199. "systemctl_mqtt._dbus.get_login_manager"
  200. ):
  201. ssl_wrap_socket_mock.return_value.send = len
  202. systemctl_mqtt._run(
  203. mqtt_host=mqtt_host,
  204. mqtt_port=mqtt_port,
  205. mqtt_username=mqtt_username,
  206. mqtt_password=mqtt_password,
  207. mqtt_topic_prefix=mqtt_topic_prefix,
  208. homeassistant_discovery_prefix="discovery-prefix",
  209. homeassistant_node_id="node-id",
  210. )
  211. assert mqtt_loop_forever_mock.call_count == 1
  212. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  213. assert mqtt_client._username.decode() == mqtt_username
  214. if mqtt_password:
  215. assert mqtt_client._password.decode() == mqtt_password
  216. else:
  217. assert mqtt_client._password is None
  218. def _initialize_mqtt_client(
  219. mqtt_host, mqtt_port, mqtt_topic_prefix
  220. ) -> paho.mqtt.client.Client:
  221. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  222. "ssl.SSLContext.wrap_socket",
  223. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  224. "paho.mqtt.client.Client.loop_forever", autospec=True,
  225. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  226. "gi.repository.GLib.MainLoop.run"
  227. ), unittest.mock.patch(
  228. "systemctl_mqtt._dbus.get_login_manager"
  229. ) as get_login_manager_mock:
  230. ssl_wrap_socket_mock.return_value.send = len
  231. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  232. systemctl_mqtt._run(
  233. mqtt_host=mqtt_host,
  234. mqtt_port=mqtt_port,
  235. mqtt_username=None,
  236. mqtt_password=None,
  237. mqtt_topic_prefix=mqtt_topic_prefix,
  238. homeassistant_discovery_prefix="discovery-prefix",
  239. homeassistant_node_id="node-id",
  240. )
  241. while threading.active_count() > 1:
  242. time.sleep(0.01)
  243. assert mqtt_loop_forever_mock.call_count == 1
  244. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  245. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  246. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  247. return mqtt_client
  248. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  249. @pytest.mark.parametrize("mqtt_port", [1833])
  250. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  251. def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  252. mqtt_client = _initialize_mqtt_client(
  253. mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_topic_prefix=mqtt_topic_prefix
  254. )
  255. caplog.clear()
  256. caplog.set_level(logging.DEBUG)
  257. poweroff_message = MQTTMessage(topic=mqtt_topic_prefix.encode() + b"/poweroff")
  258. with unittest.mock.patch.object(
  259. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
  260. ) as poweroff_action_mock:
  261. mqtt_client._handle_on_message(poweroff_message)
  262. poweroff_action_mock.assert_called_once_with()
  263. assert all(r.levelno == logging.DEBUG for r in caplog.records)
  264. assert caplog.records[0].message == "received topic={} payload=b''".format(
  265. poweroff_message.topic
  266. )
  267. assert caplog.records[1].message.startswith("executing action poweroff")
  268. assert caplog.records[2].message.startswith("completed action poweroff")
  269. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  270. @pytest.mark.parametrize("mqtt_port", [1833])
  271. @pytest.mark.parametrize("mqtt_password", ["secret"])
  272. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  273. with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
  274. "systemctl_mqtt._dbus.get_login_manager"
  275. ):
  276. with pytest.raises(ValueError):
  277. systemctl_mqtt._run(
  278. mqtt_host=mqtt_host,
  279. mqtt_port=mqtt_port,
  280. mqtt_username=None,
  281. mqtt_password=mqtt_password,
  282. mqtt_topic_prefix="prefix",
  283. homeassistant_discovery_prefix="discovery-prefix",
  284. homeassistant_node_id="node-id",
  285. )
  286. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  287. @pytest.mark.parametrize("payload", [b"", b"junk"])
  288. def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
  289. message = MQTTMessage(topic=mqtt_topic.encode())
  290. message.payload = payload
  291. with unittest.mock.patch.object(
  292. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
  293. ) as action_mock, caplog.at_level(logging.DEBUG):
  294. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  295. "poweroff"
  296. ].mqtt_message_callback(
  297. None, None, message # type: ignore
  298. )
  299. action_mock.assert_called_once_with()
  300. assert len(caplog.records) == 3
  301. assert caplog.records[0].levelno == logging.DEBUG
  302. assert caplog.records[0].message == (
  303. "received topic={} payload={!r}".format(mqtt_topic, payload)
  304. )
  305. assert caplog.records[1].levelno == logging.DEBUG
  306. assert caplog.records[1].message.startswith(
  307. "executing action {} ({!r})".format("poweroff", action_mock)
  308. )
  309. assert caplog.records[2].levelno == logging.DEBUG
  310. assert caplog.records[2].message.startswith(
  311. "completed action {} ({!r})".format("poweroff", action_mock)
  312. )
  313. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  314. @pytest.mark.parametrize("payload", [b"", b"junk"])
  315. def test_mqtt_message_callback_poweroff_retained(
  316. caplog, mqtt_topic: str, payload: bytes
  317. ):
  318. message = MQTTMessage(topic=mqtt_topic.encode())
  319. message.payload = payload
  320. message.retain = True
  321. with unittest.mock.patch.object(
  322. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
  323. ) as action_mock, caplog.at_level(logging.DEBUG):
  324. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  325. "poweroff"
  326. ].mqtt_message_callback(
  327. None, None, message # type: ignore
  328. )
  329. action_mock.assert_not_called()
  330. assert len(caplog.records) == 2
  331. assert caplog.records[0].levelno == logging.DEBUG
  332. assert caplog.records[0].message == (
  333. "received topic={} payload={!r}".format(mqtt_topic, payload)
  334. )
  335. assert caplog.records[1].levelno == logging.INFO
  336. assert caplog.records[1].message == "ignoring retained message"