1
0

test_mqtt.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
# systemctl-mqtt - MQTT client triggering & reporting shutdown on systemd-based systems
#
# Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import datetime
  18. import logging
  19. import threading
  20. import time
  21. import unittest.mock
  22. import dbus
  23. import paho.mqtt.client
  24. import pytest
  25. from paho.mqtt.client import MQTTMessage
  26. import systemctl_mqtt
  27. # pylint: disable=protected-access
  28. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  29. @pytest.mark.parametrize("mqtt_port", [1833])
  30. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  31. @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
  32. @pytest.mark.parametrize("homeassistant_node_id", ["host", "node"])
  33. def test__run(
  34. caplog,
  35. mqtt_host,
  36. mqtt_port,
  37. mqtt_topic_prefix,
  38. homeassistant_discovery_prefix,
  39. homeassistant_node_id,
  40. ):
  41. # pylint: disable=too-many-locals,too-many-arguments
  42. caplog.set_level(logging.DEBUG)
  43. with unittest.mock.patch(
  44. "socket.create_connection"
  45. ) as create_socket_mock, unittest.mock.patch(
  46. "ssl.SSLContext.wrap_socket", autospec=True
  47. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  48. "paho.mqtt.client.Client.loop_forever", autospec=True
  49. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  50. "gi.repository.GLib.MainLoop.run"
  51. ) as glib_loop_mock, unittest.mock.patch(
  52. "systemctl_mqtt._dbus.get_login_manager"
  53. ) as get_login_manager_mock:
  54. ssl_wrap_socket_mock.return_value.send = len
  55. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  56. systemctl_mqtt._run(
  57. mqtt_host=mqtt_host,
  58. mqtt_port=mqtt_port,
  59. mqtt_username=None,
  60. mqtt_password=None,
  61. mqtt_topic_prefix=mqtt_topic_prefix,
  62. homeassistant_discovery_prefix=homeassistant_discovery_prefix,
  63. homeassistant_node_id=homeassistant_node_id,
  64. poweroff_delay=datetime.timedelta(),
  65. )
  66. assert caplog.records[0].levelno == logging.INFO
  67. assert caplog.records[0].message == (
  68. f"connecting to MQTT broker {mqtt_host}:{mqtt_port} (TLS enabled)"
  69. )
  70. # correct remote?
  71. create_socket_mock.assert_called_once()
  72. create_socket_args, _ = create_socket_mock.call_args
  73. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  74. # ssl enabled?
  75. ssl_wrap_socket_mock.assert_called_once()
  76. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  77. assert ssl_context.check_hostname is True
  78. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  79. # loop started?
  80. while threading.active_count() > 1:
  81. time.sleep(0.01)
  82. mqtt_loop_forever_mock.assert_called_once()
  83. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  84. assert mqtt_client._tls_insecure is False
  85. # credentials
  86. assert mqtt_client._username is None
  87. assert mqtt_client._password is None
  88. # connect callback
  89. caplog.clear()
  90. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  91. with unittest.mock.patch(
  92. "paho.mqtt.client.Client.subscribe"
  93. ) as mqtt_subscribe_mock:
  94. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  95. state = mqtt_client._userdata
  96. assert (
  97. state._login_manager.connect_to_signal.call_args[1]["signal_name"]
  98. == "PrepareForShutdown"
  99. )
  100. assert sorted(mqtt_subscribe_mock.call_args_list) == [
  101. unittest.mock.call(mqtt_topic_prefix + "/lock-all-sessions"),
  102. unittest.mock.call(mqtt_topic_prefix + "/poweroff"),
  103. ]
  104. assert mqtt_client.on_message is None
  105. for suffix in ("poweroff", "lock-all-sessions"):
  106. assert ( # pylint: disable=comparison-with-callable
  107. mqtt_client._on_message_filtered[mqtt_topic_prefix + "/" + suffix]
  108. == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  109. suffix
  110. ].mqtt_message_callback
  111. )
  112. assert caplog.records[0].levelno == logging.DEBUG
  113. assert (
  114. caplog.records[0].message == f"connected to MQTT broker {mqtt_host}:{mqtt_port}"
  115. )
  116. assert caplog.records[1].levelno == logging.DEBUG
  117. assert caplog.records[1].message == "acquired shutdown inhibitor lock"
  118. assert caplog.records[2].levelno == logging.INFO
  119. assert (
  120. caplog.records[2].message
  121. == f"publishing 'false' on {mqtt_topic_prefix}/preparing-for-shutdown"
  122. )
  123. assert caplog.records[3].levelno == logging.DEBUG
  124. assert (
  125. caplog.records[3].message
  126. == "publishing home assistant config on "
  127. + homeassistant_discovery_prefix
  128. + "/binary_sensor/"
  129. + homeassistant_node_id
  130. + "/preparing-for-shutdown/config"
  131. )
  132. assert all(r.levelno == logging.INFO for r in caplog.records[4::2])
  133. assert {r.message for r in caplog.records[4::2]} == {
  134. f"subscribing to {mqtt_topic_prefix}/{s}"
  135. for s in ("poweroff", "lock-all-sessions")
  136. }
  137. assert all(r.levelno == logging.DEBUG for r in caplog.records[5::2])
  138. assert {r.message for r in caplog.records[5::2]} == {
  139. f"registered MQTT callback for topic {mqtt_topic_prefix}/{s}"
  140. f" triggering {systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[s]}"
  141. for s in ("poweroff", "lock-all-sessions")
  142. }
  143. # dbus loop started?
  144. glib_loop_mock.assert_called_once_with()
  145. # waited for mqtt loop to stop?
  146. assert mqtt_client._thread_terminate
  147. assert mqtt_client._thread is None
  148. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  149. @pytest.mark.parametrize("mqtt_port", [1833])
  150. @pytest.mark.parametrize("mqtt_disable_tls", [True, False])
  151. def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
  152. caplog.set_level(logging.INFO)
  153. with unittest.mock.patch(
  154. "paho.mqtt.client.Client"
  155. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  156. systemctl_mqtt._run(
  157. mqtt_host=mqtt_host,
  158. mqtt_port=mqtt_port,
  159. mqtt_disable_tls=mqtt_disable_tls,
  160. mqtt_username=None,
  161. mqtt_password=None,
  162. mqtt_topic_prefix="systemctl/hosts",
  163. homeassistant_discovery_prefix="homeassistant",
  164. homeassistant_node_id="host",
  165. poweroff_delay=datetime.timedelta(),
  166. )
  167. assert caplog.records[0].levelno == logging.INFO
  168. assert caplog.records[0].message == (
  169. f"connecting to MQTT broker {mqtt_host}:{mqtt_port}"
  170. f" (TLS {'disabled' if mqtt_disable_tls else 'enabled'})"
  171. )
  172. if mqtt_disable_tls:
  173. mqtt_client_class().tls_set.assert_not_called()
  174. else:
  175. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  176. def test__run_tls_default():
  177. with unittest.mock.patch(
  178. "paho.mqtt.client.Client"
  179. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  180. systemctl_mqtt._run(
  181. mqtt_host="mqtt-broker.local",
  182. mqtt_port=1833,
  183. # mqtt_disable_tls default,
  184. mqtt_username=None,
  185. mqtt_password=None,
  186. mqtt_topic_prefix="systemctl/hosts",
  187. homeassistant_discovery_prefix="homeassistant",
  188. homeassistant_node_id="host",
  189. poweroff_delay=datetime.timedelta(),
  190. )
  191. # enabled by default
  192. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  193. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  194. @pytest.mark.parametrize("mqtt_port", [1833])
  195. @pytest.mark.parametrize("mqtt_username", ["me"])
  196. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  197. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  198. def test__run_authentication(
  199. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  200. ):
  201. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  202. "ssl.SSLContext.wrap_socket"
  203. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  204. "paho.mqtt.client.Client.loop_forever", autospec=True
  205. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  206. "gi.repository.GLib.MainLoop.run"
  207. ), unittest.mock.patch(
  208. "systemctl_mqtt._dbus.get_login_manager"
  209. ):
  210. ssl_wrap_socket_mock.return_value.send = len
  211. systemctl_mqtt._run(
  212. mqtt_host=mqtt_host,
  213. mqtt_port=mqtt_port,
  214. mqtt_username=mqtt_username,
  215. mqtt_password=mqtt_password,
  216. mqtt_topic_prefix=mqtt_topic_prefix,
  217. homeassistant_discovery_prefix="discovery-prefix",
  218. homeassistant_node_id="node-id",
  219. poweroff_delay=datetime.timedelta(),
  220. )
  221. mqtt_loop_forever_mock.assert_called_once()
  222. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  223. assert mqtt_client._username.decode() == mqtt_username
  224. if mqtt_password:
  225. assert mqtt_client._password.decode() == mqtt_password
  226. else:
  227. assert mqtt_client._password is None
  228. def _initialize_mqtt_client(
  229. mqtt_host, mqtt_port, mqtt_topic_prefix
  230. ) -> paho.mqtt.client.Client:
  231. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  232. "ssl.SSLContext.wrap_socket"
  233. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  234. "paho.mqtt.client.Client.loop_forever", autospec=True
  235. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  236. "gi.repository.GLib.MainLoop.run"
  237. ), unittest.mock.patch(
  238. "systemctl_mqtt._dbus.get_login_manager"
  239. ) as get_login_manager_mock:
  240. ssl_wrap_socket_mock.return_value.send = len
  241. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  242. systemctl_mqtt._run(
  243. mqtt_host=mqtt_host,
  244. mqtt_port=mqtt_port,
  245. mqtt_username=None,
  246. mqtt_password=None,
  247. mqtt_topic_prefix=mqtt_topic_prefix,
  248. homeassistant_discovery_prefix="discovery-prefix",
  249. homeassistant_node_id="node-id",
  250. poweroff_delay=datetime.timedelta(),
  251. )
  252. while threading.active_count() > 1:
  253. time.sleep(0.01)
  254. mqtt_loop_forever_mock.assert_called_once()
  255. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  256. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  257. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  258. return mqtt_client
  259. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  260. @pytest.mark.parametrize("mqtt_port", [1833])
  261. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  262. def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  263. mqtt_client = _initialize_mqtt_client(
  264. mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_topic_prefix=mqtt_topic_prefix
  265. )
  266. caplog.clear()
  267. caplog.set_level(logging.DEBUG)
  268. poweroff_message = MQTTMessage(topic=mqtt_topic_prefix.encode() + b"/poweroff")
  269. with unittest.mock.patch.object(
  270. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  271. ) as poweroff_trigger_mock:
  272. mqtt_client._handle_on_message(poweroff_message)
  273. poweroff_trigger_mock.assert_called_once_with(state=mqtt_client._userdata)
  274. assert all(r.levelno == logging.DEBUG for r in caplog.records)
  275. assert (
  276. caplog.records[0].message
  277. == f"received topic={poweroff_message.topic} payload=b''"
  278. )
  279. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  280. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  281. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  282. @pytest.mark.parametrize("mqtt_port", [1833])
  283. @pytest.mark.parametrize("mqtt_password", ["secret"])
  284. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  285. with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
  286. "systemctl_mqtt._dbus.get_login_manager"
  287. ):
  288. with pytest.raises(ValueError, match=r"^Missing MQTT username$"):
  289. systemctl_mqtt._run(
  290. mqtt_host=mqtt_host,
  291. mqtt_port=mqtt_port,
  292. mqtt_username=None,
  293. mqtt_password=mqtt_password,
  294. mqtt_topic_prefix="prefix",
  295. homeassistant_discovery_prefix="discovery-prefix",
  296. homeassistant_node_id="node-id",
  297. poweroff_delay=datetime.timedelta(),
  298. )
  299. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  300. @pytest.mark.parametrize("payload", [b"", b"junk"])
  301. def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
  302. message = MQTTMessage(topic=mqtt_topic.encode())
  303. message.payload = payload
  304. with unittest.mock.patch.object(
  305. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  306. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  307. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  308. "poweroff"
  309. ].mqtt_message_callback(
  310. None, "state_dummy", message # type: ignore
  311. )
  312. trigger_mock.assert_called_once_with(state="state_dummy")
  313. assert len(caplog.records) == 3
  314. assert caplog.records[0].levelno == logging.DEBUG
  315. assert caplog.records[0].message == (
  316. f"received topic={mqtt_topic} payload={payload!r}"
  317. )
  318. assert caplog.records[1].levelno == logging.DEBUG
  319. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  320. assert caplog.records[2].levelno == logging.DEBUG
  321. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  322. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  323. @pytest.mark.parametrize("payload", [b"", b"junk"])
  324. def test_mqtt_message_callback_poweroff_retained(
  325. caplog, mqtt_topic: str, payload: bytes
  326. ):
  327. message = MQTTMessage(topic=mqtt_topic.encode())
  328. message.payload = payload
  329. message.retain = True
  330. with unittest.mock.patch.object(
  331. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  332. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  333. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  334. "poweroff"
  335. ].mqtt_message_callback(
  336. None, None, message # type: ignore
  337. )
  338. trigger_mock.assert_not_called()
  339. assert len(caplog.records) == 2
  340. assert caplog.records[0].levelno == logging.DEBUG
  341. assert caplog.records[0].message == (
  342. f"received topic={mqtt_topic} payload={payload!r}"
  343. )
  344. assert caplog.records[1].levelno == logging.INFO
  345. assert caplog.records[1].message == "ignoring retained message"