1
0

test_mqtt.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # systemctl-mqtt - MQTT client triggering & reporting shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import datetime
  18. import logging
  19. import threading
  20. import time
  21. import unittest.mock
  22. import dbus
  23. import paho.mqtt.client
  24. import pytest
  25. from paho.mqtt.client import MQTTMessage
  26. import systemctl_mqtt
  27. # pylint: disable=protected-access
  28. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  29. @pytest.mark.parametrize("mqtt_port", [1833])
  30. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  31. @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
  32. @pytest.mark.parametrize("homeassistant_node_id", ["host", "node"])
  33. def test__run(
  34. caplog,
  35. mqtt_host,
  36. mqtt_port,
  37. mqtt_topic_prefix,
  38. homeassistant_discovery_prefix,
  39. homeassistant_node_id,
  40. ):
  41. # pylint: disable=too-many-locals,too-many-arguments
  42. caplog.set_level(logging.DEBUG)
  43. with unittest.mock.patch(
  44. "socket.create_connection"
  45. ) as create_socket_mock, unittest.mock.patch(
  46. "ssl.SSLContext.wrap_socket", autospec=True
  47. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  48. "paho.mqtt.client.Client.loop_forever", autospec=True
  49. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  50. "gi.repository.GLib.MainLoop.run"
  51. ) as glib_loop_mock, unittest.mock.patch(
  52. "systemctl_mqtt._dbus.get_login_manager"
  53. ) as get_login_manager_mock:
  54. ssl_wrap_socket_mock.return_value.send = len
  55. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  56. systemctl_mqtt._run(
  57. mqtt_host=mqtt_host,
  58. mqtt_port=mqtt_port,
  59. mqtt_username=None,
  60. mqtt_password=None,
  61. mqtt_topic_prefix=mqtt_topic_prefix,
  62. homeassistant_discovery_prefix=homeassistant_discovery_prefix,
  63. homeassistant_node_id=homeassistant_node_id,
  64. poweroff_delay=datetime.timedelta(),
  65. )
  66. assert caplog.records[0].levelno == logging.INFO
  67. assert caplog.records[0].message == (
  68. "connecting to MQTT broker {}:{} (TLS enabled)".format(mqtt_host, mqtt_port)
  69. )
  70. # correct remote?
  71. assert create_socket_mock.call_count == 1
  72. create_socket_args, _ = create_socket_mock.call_args
  73. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  74. # ssl enabled?
  75. assert ssl_wrap_socket_mock.call_count == 1
  76. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  77. assert ssl_context.check_hostname is True
  78. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  79. # loop started?
  80. while threading.active_count() > 1:
  81. time.sleep(0.01)
  82. assert mqtt_loop_forever_mock.call_count == 1
  83. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  84. assert mqtt_client._tls_insecure is False
  85. # credentials
  86. assert mqtt_client._username is None
  87. assert mqtt_client._password is None
  88. # connect callback
  89. caplog.clear()
  90. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  91. with unittest.mock.patch(
  92. "paho.mqtt.client.Client.subscribe"
  93. ) as mqtt_subscribe_mock:
  94. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  95. state = mqtt_client._userdata
  96. assert (
  97. state._login_manager.connect_to_signal.call_args[1]["signal_name"]
  98. == "PrepareForShutdown"
  99. )
  100. assert sorted(mqtt_subscribe_mock.call_args_list) == [
  101. unittest.mock.call(mqtt_topic_prefix + "/lock-all-sessions"),
  102. unittest.mock.call(mqtt_topic_prefix + "/poweroff"),
  103. ]
  104. assert mqtt_client.on_message is None
  105. for suffix in ("poweroff", "lock-all-sessions"):
  106. assert ( # pylint: disable=comparison-with-callable
  107. mqtt_client._on_message_filtered[mqtt_topic_prefix + "/" + suffix]
  108. == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  109. suffix
  110. ].mqtt_message_callback
  111. )
  112. assert caplog.records[0].levelno == logging.DEBUG
  113. assert caplog.records[0].message == "connected to MQTT broker {}:{}".format(
  114. mqtt_host, mqtt_port
  115. )
  116. assert caplog.records[1].levelno == logging.DEBUG
  117. assert caplog.records[1].message == "acquired shutdown inhibitor lock"
  118. assert caplog.records[2].levelno == logging.INFO
  119. assert caplog.records[2].message == "publishing 'false' on {}".format(
  120. mqtt_topic_prefix + "/preparing-for-shutdown"
  121. )
  122. assert caplog.records[3].levelno == logging.DEBUG
  123. assert (
  124. caplog.records[3].message
  125. == "publishing home assistant config on "
  126. + homeassistant_discovery_prefix
  127. + "/binary_sensor/"
  128. + homeassistant_node_id
  129. + "/preparing-for-shutdown/config"
  130. )
  131. assert all(r.levelno == logging.INFO for r in caplog.records[4::2])
  132. assert {r.message for r in caplog.records[4::2]} == {
  133. "subscribing to {}/{}".format(mqtt_topic_prefix, s)
  134. for s in ("poweroff", "lock-all-sessions")
  135. }
  136. assert all(r.levelno == logging.DEBUG for r in caplog.records[5::2])
  137. assert {r.message for r in caplog.records[5::2]} == {
  138. "registered MQTT callback for topic {}".format(mqtt_topic_prefix + "/" + s)
  139. + " triggering {}".format(systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[s])
  140. for s in ("poweroff", "lock-all-sessions")
  141. }
  142. # dbus loop started?
  143. glib_loop_mock.assert_called_once_with()
  144. # waited for mqtt loop to stop?
  145. assert mqtt_client._thread_terminate
  146. assert mqtt_client._thread is None
  147. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  148. @pytest.mark.parametrize("mqtt_port", [1833])
  149. @pytest.mark.parametrize("mqtt_disable_tls", [True, False])
  150. def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
  151. caplog.set_level(logging.INFO)
  152. with unittest.mock.patch(
  153. "paho.mqtt.client.Client"
  154. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  155. systemctl_mqtt._run(
  156. mqtt_host=mqtt_host,
  157. mqtt_port=mqtt_port,
  158. mqtt_disable_tls=mqtt_disable_tls,
  159. mqtt_username=None,
  160. mqtt_password=None,
  161. mqtt_topic_prefix="systemctl/hosts",
  162. homeassistant_discovery_prefix="homeassistant",
  163. homeassistant_node_id="host",
  164. poweroff_delay=datetime.timedelta(),
  165. )
  166. assert caplog.records[0].levelno == logging.INFO
  167. assert caplog.records[0].message == (
  168. "connecting to MQTT broker {}:{} (TLS {})".format(
  169. mqtt_host, mqtt_port, "disabled" if mqtt_disable_tls else "enabled"
  170. )
  171. )
  172. if mqtt_disable_tls:
  173. mqtt_client_class().tls_set.assert_not_called()
  174. else:
  175. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  176. def test__run_tls_default():
  177. with unittest.mock.patch(
  178. "paho.mqtt.client.Client"
  179. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  180. systemctl_mqtt._run(
  181. mqtt_host="mqtt-broker.local",
  182. mqtt_port=1833,
  183. # mqtt_disable_tls default,
  184. mqtt_username=None,
  185. mqtt_password=None,
  186. mqtt_topic_prefix="systemctl/hosts",
  187. homeassistant_discovery_prefix="homeassistant",
  188. homeassistant_node_id="host",
  189. poweroff_delay=datetime.timedelta(),
  190. )
  191. # enabled by default
  192. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  193. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  194. @pytest.mark.parametrize("mqtt_port", [1833])
  195. @pytest.mark.parametrize("mqtt_username", ["me"])
  196. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  197. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  198. def test__run_authentication(
  199. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  200. ):
  201. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  202. "ssl.SSLContext.wrap_socket"
  203. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  204. "paho.mqtt.client.Client.loop_forever", autospec=True
  205. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  206. "gi.repository.GLib.MainLoop.run"
  207. ), unittest.mock.patch(
  208. "systemctl_mqtt._dbus.get_login_manager"
  209. ):
  210. ssl_wrap_socket_mock.return_value.send = len
  211. systemctl_mqtt._run(
  212. mqtt_host=mqtt_host,
  213. mqtt_port=mqtt_port,
  214. mqtt_username=mqtt_username,
  215. mqtt_password=mqtt_password,
  216. mqtt_topic_prefix=mqtt_topic_prefix,
  217. homeassistant_discovery_prefix="discovery-prefix",
  218. homeassistant_node_id="node-id",
  219. poweroff_delay=datetime.timedelta(),
  220. )
  221. assert mqtt_loop_forever_mock.call_count == 1
  222. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  223. assert mqtt_client._username.decode() == mqtt_username
  224. if mqtt_password:
  225. assert mqtt_client._password.decode() == mqtt_password
  226. else:
  227. assert mqtt_client._password is None
  228. def _initialize_mqtt_client(
  229. mqtt_host, mqtt_port, mqtt_topic_prefix
  230. ) -> paho.mqtt.client.Client:
  231. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  232. "ssl.SSLContext.wrap_socket"
  233. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  234. "paho.mqtt.client.Client.loop_forever", autospec=True
  235. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  236. "gi.repository.GLib.MainLoop.run"
  237. ), unittest.mock.patch(
  238. "systemctl_mqtt._dbus.get_login_manager"
  239. ) as get_login_manager_mock:
  240. ssl_wrap_socket_mock.return_value.send = len
  241. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  242. systemctl_mqtt._run(
  243. mqtt_host=mqtt_host,
  244. mqtt_port=mqtt_port,
  245. mqtt_username=None,
  246. mqtt_password=None,
  247. mqtt_topic_prefix=mqtt_topic_prefix,
  248. homeassistant_discovery_prefix="discovery-prefix",
  249. homeassistant_node_id="node-id",
  250. poweroff_delay=datetime.timedelta(),
  251. )
  252. while threading.active_count() > 1:
  253. time.sleep(0.01)
  254. assert mqtt_loop_forever_mock.call_count == 1
  255. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  256. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  257. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  258. return mqtt_client
  259. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  260. @pytest.mark.parametrize("mqtt_port", [1833])
  261. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  262. def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  263. mqtt_client = _initialize_mqtt_client(
  264. mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_topic_prefix=mqtt_topic_prefix
  265. )
  266. caplog.clear()
  267. caplog.set_level(logging.DEBUG)
  268. poweroff_message = MQTTMessage(topic=mqtt_topic_prefix.encode() + b"/poweroff")
  269. with unittest.mock.patch.object(
  270. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  271. ) as poweroff_trigger_mock:
  272. mqtt_client._handle_on_message(poweroff_message)
  273. poweroff_trigger_mock.assert_called_once_with(state=mqtt_client._userdata)
  274. assert all(r.levelno == logging.DEBUG for r in caplog.records)
  275. assert caplog.records[0].message == "received topic={} payload=b''".format(
  276. poweroff_message.topic
  277. )
  278. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  279. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  280. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  281. @pytest.mark.parametrize("mqtt_port", [1833])
  282. @pytest.mark.parametrize("mqtt_password", ["secret"])
  283. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  284. with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
  285. "systemctl_mqtt._dbus.get_login_manager"
  286. ):
  287. with pytest.raises(ValueError, match=r"^Missing MQTT username$"):
  288. systemctl_mqtt._run(
  289. mqtt_host=mqtt_host,
  290. mqtt_port=mqtt_port,
  291. mqtt_username=None,
  292. mqtt_password=mqtt_password,
  293. mqtt_topic_prefix="prefix",
  294. homeassistant_discovery_prefix="discovery-prefix",
  295. homeassistant_node_id="node-id",
  296. poweroff_delay=datetime.timedelta(),
  297. )
  298. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  299. @pytest.mark.parametrize("payload", [b"", b"junk"])
  300. def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
  301. message = MQTTMessage(topic=mqtt_topic.encode())
  302. message.payload = payload
  303. with unittest.mock.patch.object(
  304. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  305. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  306. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  307. "poweroff"
  308. ].mqtt_message_callback(
  309. None, "state_dummy", message # type: ignore
  310. )
  311. trigger_mock.assert_called_once_with(state="state_dummy")
  312. assert len(caplog.records) == 3
  313. assert caplog.records[0].levelno == logging.DEBUG
  314. assert caplog.records[0].message == (
  315. "received topic={} payload={!r}".format(mqtt_topic, payload)
  316. )
  317. assert caplog.records[1].levelno == logging.DEBUG
  318. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  319. assert caplog.records[2].levelno == logging.DEBUG
  320. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  321. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  322. @pytest.mark.parametrize("payload", [b"", b"junk"])
  323. def test_mqtt_message_callback_poweroff_retained(
  324. caplog, mqtt_topic: str, payload: bytes
  325. ):
  326. message = MQTTMessage(topic=mqtt_topic.encode())
  327. message.payload = payload
  328. message.retain = True
  329. with unittest.mock.patch.object(
  330. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  331. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  332. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  333. "poweroff"
  334. ].mqtt_message_callback(
  335. None, None, message # type: ignore
  336. )
  337. trigger_mock.assert_not_called()
  338. assert len(caplog.records) == 2
  339. assert caplog.records[0].levelno == logging.DEBUG
  340. assert caplog.records[0].message == (
  341. "received topic={} payload={!r}".format(mqtt_topic, payload)
  342. )
  343. assert caplog.records[1].levelno == logging.INFO
  344. assert caplog.records[1].message == "ignoring retained message"