test_mqtt.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. # systemctl-mqtt - MQTT client triggering & reporting shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import datetime
  18. import logging
  19. import threading
  20. import time
  21. import unittest.mock
  22. import dbus
  23. import paho.mqtt.client
  24. import pytest
  25. from paho.mqtt.client import MQTTMessage
  26. import systemctl_mqtt
  27. # pylint: disable=protected-access
  28. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  29. @pytest.mark.parametrize("mqtt_port", [1833])
  30. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  31. @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
  32. @pytest.mark.parametrize("homeassistant_node_id", ["host", "node"])
  33. def test__run(
  34. caplog,
  35. mqtt_host,
  36. mqtt_port,
  37. mqtt_topic_prefix,
  38. homeassistant_discovery_prefix,
  39. homeassistant_node_id,
  40. ):
  41. # pylint: disable=too-many-locals,too-many-arguments
  42. caplog.set_level(logging.DEBUG)
  43. with unittest.mock.patch(
  44. "socket.create_connection"
  45. ) as create_socket_mock, unittest.mock.patch(
  46. "ssl.SSLContext.wrap_socket", autospec=True,
  47. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  48. "paho.mqtt.client.Client.loop_forever", autospec=True,
  49. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  50. "gi.repository.GLib.MainLoop.run"
  51. ) as glib_loop_mock, unittest.mock.patch(
  52. "systemctl_mqtt._dbus.get_login_manager"
  53. ) as get_login_manager_mock:
  54. ssl_wrap_socket_mock.return_value.send = len
  55. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  56. systemctl_mqtt._run(
  57. mqtt_host=mqtt_host,
  58. mqtt_port=mqtt_port,
  59. mqtt_username=None,
  60. mqtt_password=None,
  61. mqtt_topic_prefix=mqtt_topic_prefix,
  62. homeassistant_discovery_prefix=homeassistant_discovery_prefix,
  63. homeassistant_node_id=homeassistant_node_id,
  64. poweroff_delay=datetime.timedelta(),
  65. )
  66. assert caplog.records[0].levelno == logging.INFO
  67. assert caplog.records[0].message == (
  68. "connecting to MQTT broker {}:{} (TLS enabled)".format(mqtt_host, mqtt_port)
  69. )
  70. # correct remote?
  71. assert create_socket_mock.call_count == 1
  72. create_socket_args, _ = create_socket_mock.call_args
  73. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  74. # ssl enabled?
  75. assert ssl_wrap_socket_mock.call_count == 1
  76. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  77. assert ssl_context.check_hostname is True
  78. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  79. # loop started?
  80. while threading.active_count() > 1:
  81. time.sleep(0.01)
  82. assert mqtt_loop_forever_mock.call_count == 1
  83. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  84. assert mqtt_client._tls_insecure is False
  85. # credentials
  86. assert mqtt_client._username is None
  87. assert mqtt_client._password is None
  88. # connect callback
  89. caplog.clear()
  90. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  91. with unittest.mock.patch(
  92. "paho.mqtt.client.Client.subscribe"
  93. ) as mqtt_subscribe_mock:
  94. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  95. state = mqtt_client._userdata
  96. assert (
  97. state._login_manager.connect_to_signal.call_args[1]["signal_name"]
  98. == "PrepareForShutdown"
  99. )
  100. assert mqtt_subscribe_mock.call_args_list == [
  101. unittest.mock.call(mqtt_topic_prefix + "/poweroff"),
  102. unittest.mock.call(mqtt_topic_prefix + "/lock-all-sessions"),
  103. ]
  104. assert mqtt_client.on_message is None
  105. for suffix in ("poweroff", "lock-all-sessions"):
  106. assert ( # pylint: disable=comparison-with-callable
  107. mqtt_client._on_message_filtered[mqtt_topic_prefix + "/" + suffix]
  108. == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  109. suffix
  110. ].mqtt_message_callback
  111. )
  112. assert caplog.records[0].levelno == logging.DEBUG
  113. assert caplog.records[0].message == "connected to MQTT broker {}:{}".format(
  114. mqtt_host, mqtt_port
  115. )
  116. assert caplog.records[1].levelno == logging.DEBUG
  117. assert caplog.records[1].message == "acquired shutdown inhibitor lock"
  118. assert caplog.records[2].levelno == logging.INFO
  119. assert caplog.records[2].message == "publishing 'false' on {}".format(
  120. mqtt_topic_prefix + "/preparing-for-shutdown"
  121. )
  122. assert caplog.records[3].levelno == logging.DEBUG
  123. assert (
  124. caplog.records[3].message
  125. == "publishing home assistant config on "
  126. + homeassistant_discovery_prefix
  127. + "/binary_sensor/"
  128. + homeassistant_node_id
  129. + "/preparing-for-shutdown/config"
  130. )
  131. assert caplog.records[4].levelno == logging.INFO
  132. assert caplog.records[4].message == "subscribing to {}".format(
  133. mqtt_topic_prefix + "/poweroff"
  134. )
  135. assert caplog.records[5].levelno == logging.DEBUG
  136. assert caplog.records[5].message == "registered MQTT callback for topic {}".format(
  137. mqtt_topic_prefix + "/poweroff"
  138. ) + " triggering {}".format(
  139. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"]
  140. )
  141. # dbus loop started?
  142. glib_loop_mock.assert_called_once_with()
  143. # waited for mqtt loop to stop?
  144. assert mqtt_client._thread_terminate
  145. assert mqtt_client._thread is None
  146. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  147. @pytest.mark.parametrize("mqtt_port", [1833])
  148. @pytest.mark.parametrize("mqtt_disable_tls", [True, False])
  149. def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
  150. caplog.set_level(logging.INFO)
  151. with unittest.mock.patch(
  152. "paho.mqtt.client.Client"
  153. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  154. systemctl_mqtt._run(
  155. mqtt_host=mqtt_host,
  156. mqtt_port=mqtt_port,
  157. mqtt_disable_tls=mqtt_disable_tls,
  158. mqtt_username=None,
  159. mqtt_password=None,
  160. mqtt_topic_prefix="systemctl/hosts",
  161. homeassistant_discovery_prefix="homeassistant",
  162. homeassistant_node_id="host",
  163. poweroff_delay=datetime.timedelta(),
  164. )
  165. assert caplog.records[0].levelno == logging.INFO
  166. assert caplog.records[0].message == (
  167. "connecting to MQTT broker {}:{} (TLS {})".format(
  168. mqtt_host, mqtt_port, "disabled" if mqtt_disable_tls else "enabled"
  169. )
  170. )
  171. if mqtt_disable_tls:
  172. mqtt_client_class().tls_set.assert_not_called()
  173. else:
  174. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  175. def test__run_tls_default():
  176. with unittest.mock.patch(
  177. "paho.mqtt.client.Client"
  178. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  179. systemctl_mqtt._run(
  180. mqtt_host="mqtt-broker.local",
  181. mqtt_port=1833,
  182. # mqtt_disable_tls default,
  183. mqtt_username=None,
  184. mqtt_password=None,
  185. mqtt_topic_prefix="systemctl/hosts",
  186. homeassistant_discovery_prefix="homeassistant",
  187. homeassistant_node_id="host",
  188. poweroff_delay=datetime.timedelta(),
  189. )
  190. # enabled by default
  191. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  192. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  193. @pytest.mark.parametrize("mqtt_port", [1833])
  194. @pytest.mark.parametrize("mqtt_username", ["me"])
  195. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  196. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  197. def test__run_authentication(
  198. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  199. ):
  200. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  201. "ssl.SSLContext.wrap_socket"
  202. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  203. "paho.mqtt.client.Client.loop_forever", autospec=True,
  204. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  205. "gi.repository.GLib.MainLoop.run"
  206. ), unittest.mock.patch(
  207. "systemctl_mqtt._dbus.get_login_manager"
  208. ):
  209. ssl_wrap_socket_mock.return_value.send = len
  210. systemctl_mqtt._run(
  211. mqtt_host=mqtt_host,
  212. mqtt_port=mqtt_port,
  213. mqtt_username=mqtt_username,
  214. mqtt_password=mqtt_password,
  215. mqtt_topic_prefix=mqtt_topic_prefix,
  216. homeassistant_discovery_prefix="discovery-prefix",
  217. homeassistant_node_id="node-id",
  218. poweroff_delay=datetime.timedelta(),
  219. )
  220. assert mqtt_loop_forever_mock.call_count == 1
  221. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  222. assert mqtt_client._username.decode() == mqtt_username
  223. if mqtt_password:
  224. assert mqtt_client._password.decode() == mqtt_password
  225. else:
  226. assert mqtt_client._password is None
  227. def _initialize_mqtt_client(
  228. mqtt_host, mqtt_port, mqtt_topic_prefix
  229. ) -> paho.mqtt.client.Client:
  230. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  231. "ssl.SSLContext.wrap_socket",
  232. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  233. "paho.mqtt.client.Client.loop_forever", autospec=True,
  234. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  235. "gi.repository.GLib.MainLoop.run"
  236. ), unittest.mock.patch(
  237. "systemctl_mqtt._dbus.get_login_manager"
  238. ) as get_login_manager_mock:
  239. ssl_wrap_socket_mock.return_value.send = len
  240. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  241. systemctl_mqtt._run(
  242. mqtt_host=mqtt_host,
  243. mqtt_port=mqtt_port,
  244. mqtt_username=None,
  245. mqtt_password=None,
  246. mqtt_topic_prefix=mqtt_topic_prefix,
  247. homeassistant_discovery_prefix="discovery-prefix",
  248. homeassistant_node_id="node-id",
  249. poweroff_delay=datetime.timedelta(),
  250. )
  251. while threading.active_count() > 1:
  252. time.sleep(0.01)
  253. assert mqtt_loop_forever_mock.call_count == 1
  254. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  255. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  256. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  257. return mqtt_client
  258. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  259. @pytest.mark.parametrize("mqtt_port", [1833])
  260. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  261. def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  262. mqtt_client = _initialize_mqtt_client(
  263. mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_topic_prefix=mqtt_topic_prefix
  264. )
  265. caplog.clear()
  266. caplog.set_level(logging.DEBUG)
  267. poweroff_message = MQTTMessage(topic=mqtt_topic_prefix.encode() + b"/poweroff")
  268. with unittest.mock.patch.object(
  269. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  270. ) as poweroff_trigger_mock:
  271. mqtt_client._handle_on_message(poweroff_message)
  272. poweroff_trigger_mock.assert_called_once_with(state=mqtt_client._userdata)
  273. assert all(r.levelno == logging.DEBUG for r in caplog.records)
  274. assert caplog.records[0].message == "received topic={} payload=b''".format(
  275. poweroff_message.topic
  276. )
  277. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  278. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  279. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  280. @pytest.mark.parametrize("mqtt_port", [1833])
  281. @pytest.mark.parametrize("mqtt_password", ["secret"])
  282. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  283. with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
  284. "systemctl_mqtt._dbus.get_login_manager"
  285. ):
  286. with pytest.raises(ValueError, match=r"^Missing MQTT username$"):
  287. systemctl_mqtt._run(
  288. mqtt_host=mqtt_host,
  289. mqtt_port=mqtt_port,
  290. mqtt_username=None,
  291. mqtt_password=mqtt_password,
  292. mqtt_topic_prefix="prefix",
  293. homeassistant_discovery_prefix="discovery-prefix",
  294. homeassistant_node_id="node-id",
  295. poweroff_delay=datetime.timedelta(),
  296. )
  297. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  298. @pytest.mark.parametrize("payload", [b"", b"junk"])
  299. def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
  300. message = MQTTMessage(topic=mqtt_topic.encode())
  301. message.payload = payload
  302. with unittest.mock.patch.object(
  303. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  304. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  305. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  306. "poweroff"
  307. ].mqtt_message_callback(
  308. None, "state_dummy", message # type: ignore
  309. )
  310. trigger_mock.assert_called_once_with(state="state_dummy")
  311. assert len(caplog.records) == 3
  312. assert caplog.records[0].levelno == logging.DEBUG
  313. assert caplog.records[0].message == (
  314. "received topic={} payload={!r}".format(mqtt_topic, payload)
  315. )
  316. assert caplog.records[1].levelno == logging.DEBUG
  317. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  318. assert caplog.records[2].levelno == logging.DEBUG
  319. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  320. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  321. @pytest.mark.parametrize("payload", [b"", b"junk"])
  322. def test_mqtt_message_callback_poweroff_retained(
  323. caplog, mqtt_topic: str, payload: bytes
  324. ):
  325. message = MQTTMessage(topic=mqtt_topic.encode())
  326. message.payload = payload
  327. message.retain = True
  328. with unittest.mock.patch.object(
  329. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  330. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  331. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  332. "poweroff"
  333. ].mqtt_message_callback(
  334. None, None, message # type: ignore
  335. )
  336. trigger_mock.assert_not_called()
  337. assert len(caplog.records) == 2
  338. assert caplog.records[0].levelno == logging.DEBUG
  339. assert caplog.records[0].message == (
  340. "received topic={} payload={!r}".format(mqtt_topic, payload)
  341. )
  342. assert caplog.records[1].levelno == logging.INFO
  343. assert caplog.records[1].message == "ignoring retained message"