test_mqtt.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # systemctl-mqtt - MQTT client triggering & reporting shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import datetime
  18. import logging
  19. import threading
  20. import time
  21. import unittest.mock
  22. import dbus
  23. import paho.mqtt.client
  24. import pytest
  25. from paho.mqtt.client import MQTTMessage
  26. import systemctl_mqtt
  27. # pylint: disable=protected-access
  28. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  29. @pytest.mark.parametrize("mqtt_port", [1833])
  30. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  31. @pytest.mark.parametrize("homeassistant_discovery_prefix", ["homeassistant"])
  32. @pytest.mark.parametrize("homeassistant_node_id", ["host", "node"])
  33. def test__run(
  34. caplog,
  35. mqtt_host,
  36. mqtt_port,
  37. mqtt_topic_prefix,
  38. homeassistant_discovery_prefix,
  39. homeassistant_node_id,
  40. ):
  41. # pylint: disable=too-many-locals,too-many-arguments
  42. caplog.set_level(logging.DEBUG)
  43. with unittest.mock.patch(
  44. "socket.create_connection"
  45. ) as create_socket_mock, unittest.mock.patch(
  46. "ssl.SSLContext.wrap_socket", autospec=True,
  47. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  48. "paho.mqtt.client.Client.loop_forever", autospec=True,
  49. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  50. "gi.repository.GLib.MainLoop.run"
  51. ) as glib_loop_mock, unittest.mock.patch(
  52. "systemctl_mqtt._dbus.get_login_manager"
  53. ) as get_login_manager_mock:
  54. ssl_wrap_socket_mock.return_value.send = len
  55. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  56. systemctl_mqtt._run(
  57. mqtt_host=mqtt_host,
  58. mqtt_port=mqtt_port,
  59. mqtt_username=None,
  60. mqtt_password=None,
  61. mqtt_topic_prefix=mqtt_topic_prefix,
  62. homeassistant_discovery_prefix=homeassistant_discovery_prefix,
  63. homeassistant_node_id=homeassistant_node_id,
  64. poweroff_delay=datetime.timedelta(),
  65. )
  66. assert caplog.records[0].levelno == logging.INFO
  67. assert caplog.records[0].message == (
  68. "connecting to MQTT broker {}:{} (TLS enabled)".format(mqtt_host, mqtt_port)
  69. )
  70. # correct remote?
  71. assert create_socket_mock.call_count == 1
  72. create_socket_args, _ = create_socket_mock.call_args
  73. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  74. # ssl enabled?
  75. assert ssl_wrap_socket_mock.call_count == 1
  76. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  77. assert ssl_context.check_hostname is True
  78. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  79. # loop started?
  80. while threading.active_count() > 1:
  81. time.sleep(0.01)
  82. assert mqtt_loop_forever_mock.call_count == 1
  83. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  84. assert mqtt_client._tls_insecure is False
  85. # credentials
  86. assert mqtt_client._username is None
  87. assert mqtt_client._password is None
  88. # connect callback
  89. caplog.clear()
  90. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  91. with unittest.mock.patch(
  92. "paho.mqtt.client.Client.subscribe"
  93. ) as mqtt_subscribe_mock:
  94. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  95. state = mqtt_client._userdata
  96. assert (
  97. state._login_manager.connect_to_signal.call_args[1]["signal_name"]
  98. == "PrepareForShutdown"
  99. )
  100. mqtt_subscribe_mock.assert_called_once_with(mqtt_topic_prefix + "/poweroff")
  101. assert mqtt_client.on_message is None
  102. assert ( # pylint: disable=comparison-with-callable
  103. mqtt_client._on_message_filtered[mqtt_topic_prefix + "/poweroff"]
  104. == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  105. "poweroff"
  106. ].mqtt_message_callback
  107. )
  108. assert caplog.records[0].levelno == logging.DEBUG
  109. assert caplog.records[0].message == "connected to MQTT broker {}:{}".format(
  110. mqtt_host, mqtt_port
  111. )
  112. assert caplog.records[1].levelno == logging.DEBUG
  113. assert caplog.records[1].message == "acquired shutdown inhibitor lock"
  114. assert caplog.records[2].levelno == logging.INFO
  115. assert caplog.records[2].message == "publishing 'false' on {}".format(
  116. mqtt_topic_prefix + "/preparing-for-shutdown"
  117. )
  118. assert caplog.records[3].levelno == logging.DEBUG
  119. assert (
  120. caplog.records[3].message
  121. == "publishing home assistant config on "
  122. + homeassistant_discovery_prefix
  123. + "/binary_sensor/"
  124. + homeassistant_node_id
  125. + "/preparing-for-shutdown/config"
  126. )
  127. assert caplog.records[4].levelno == logging.INFO
  128. assert caplog.records[4].message == "subscribing to {}".format(
  129. mqtt_topic_prefix + "/poweroff"
  130. )
  131. assert caplog.records[5].levelno == logging.DEBUG
  132. assert caplog.records[5].message == "registered MQTT callback for topic {}".format(
  133. mqtt_topic_prefix + "/poweroff"
  134. ) + " triggering {}".format(
  135. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"]
  136. )
  137. # dbus loop started?
  138. glib_loop_mock.assert_called_once_with()
  139. # waited for mqtt loop to stop?
  140. assert mqtt_client._thread_terminate
  141. assert mqtt_client._thread is None
  142. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  143. @pytest.mark.parametrize("mqtt_port", [1833])
  144. @pytest.mark.parametrize("mqtt_disable_tls", [True, False])
  145. def test__run_tls(caplog, mqtt_host, mqtt_port, mqtt_disable_tls):
  146. caplog.set_level(logging.INFO)
  147. with unittest.mock.patch(
  148. "paho.mqtt.client.Client"
  149. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  150. systemctl_mqtt._run(
  151. mqtt_host=mqtt_host,
  152. mqtt_port=mqtt_port,
  153. mqtt_disable_tls=mqtt_disable_tls,
  154. mqtt_username=None,
  155. mqtt_password=None,
  156. mqtt_topic_prefix="systemctl/hosts",
  157. homeassistant_discovery_prefix="homeassistant",
  158. homeassistant_node_id="host",
  159. poweroff_delay=datetime.timedelta(),
  160. )
  161. assert caplog.records[0].levelno == logging.INFO
  162. assert caplog.records[0].message == (
  163. "connecting to MQTT broker {}:{} (TLS {})".format(
  164. mqtt_host, mqtt_port, "disabled" if mqtt_disable_tls else "enabled"
  165. )
  166. )
  167. if mqtt_disable_tls:
  168. mqtt_client_class().tls_set.assert_not_called()
  169. else:
  170. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  171. def test__run_tls_default():
  172. with unittest.mock.patch(
  173. "paho.mqtt.client.Client"
  174. ) as mqtt_client_class, unittest.mock.patch("gi.repository.GLib.MainLoop.run"):
  175. systemctl_mqtt._run(
  176. mqtt_host="mqtt-broker.local",
  177. mqtt_port=1833,
  178. # mqtt_disable_tls default,
  179. mqtt_username=None,
  180. mqtt_password=None,
  181. mqtt_topic_prefix="systemctl/hosts",
  182. homeassistant_discovery_prefix="homeassistant",
  183. homeassistant_node_id="host",
  184. poweroff_delay=datetime.timedelta(),
  185. )
  186. # enabled by default
  187. mqtt_client_class().tls_set.assert_called_once_with(ca_certs=None)
  188. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  189. @pytest.mark.parametrize("mqtt_port", [1833])
  190. @pytest.mark.parametrize("mqtt_username", ["me"])
  191. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  192. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  193. def test__run_authentication(
  194. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  195. ):
  196. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  197. "ssl.SSLContext.wrap_socket"
  198. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  199. "paho.mqtt.client.Client.loop_forever", autospec=True,
  200. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  201. "gi.repository.GLib.MainLoop.run"
  202. ), unittest.mock.patch(
  203. "systemctl_mqtt._dbus.get_login_manager"
  204. ):
  205. ssl_wrap_socket_mock.return_value.send = len
  206. systemctl_mqtt._run(
  207. mqtt_host=mqtt_host,
  208. mqtt_port=mqtt_port,
  209. mqtt_username=mqtt_username,
  210. mqtt_password=mqtt_password,
  211. mqtt_topic_prefix=mqtt_topic_prefix,
  212. homeassistant_discovery_prefix="discovery-prefix",
  213. homeassistant_node_id="node-id",
  214. poweroff_delay=datetime.timedelta(),
  215. )
  216. assert mqtt_loop_forever_mock.call_count == 1
  217. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  218. assert mqtt_client._username.decode() == mqtt_username
  219. if mqtt_password:
  220. assert mqtt_client._password.decode() == mqtt_password
  221. else:
  222. assert mqtt_client._password is None
  223. def _initialize_mqtt_client(
  224. mqtt_host, mqtt_port, mqtt_topic_prefix
  225. ) -> paho.mqtt.client.Client:
  226. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  227. "ssl.SSLContext.wrap_socket",
  228. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  229. "paho.mqtt.client.Client.loop_forever", autospec=True,
  230. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  231. "gi.repository.GLib.MainLoop.run"
  232. ), unittest.mock.patch(
  233. "systemctl_mqtt._dbus.get_login_manager"
  234. ) as get_login_manager_mock:
  235. ssl_wrap_socket_mock.return_value.send = len
  236. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  237. systemctl_mqtt._run(
  238. mqtt_host=mqtt_host,
  239. mqtt_port=mqtt_port,
  240. mqtt_username=None,
  241. mqtt_password=None,
  242. mqtt_topic_prefix=mqtt_topic_prefix,
  243. homeassistant_discovery_prefix="discovery-prefix",
  244. homeassistant_node_id="node-id",
  245. poweroff_delay=datetime.timedelta(),
  246. )
  247. while threading.active_count() > 1:
  248. time.sleep(0.01)
  249. assert mqtt_loop_forever_mock.call_count == 1
  250. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  251. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  252. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  253. return mqtt_client
  254. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  255. @pytest.mark.parametrize("mqtt_port", [1833])
  256. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  257. def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  258. mqtt_client = _initialize_mqtt_client(
  259. mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_topic_prefix=mqtt_topic_prefix
  260. )
  261. caplog.clear()
  262. caplog.set_level(logging.DEBUG)
  263. poweroff_message = MQTTMessage(topic=mqtt_topic_prefix.encode() + b"/poweroff")
  264. with unittest.mock.patch.object(
  265. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  266. ) as poweroff_trigger_mock:
  267. mqtt_client._handle_on_message(poweroff_message)
  268. poweroff_trigger_mock.assert_called_once_with(state=mqtt_client._userdata)
  269. assert all(r.levelno == logging.DEBUG for r in caplog.records)
  270. assert caplog.records[0].message == "received topic={} payload=b''".format(
  271. poweroff_message.topic
  272. )
  273. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  274. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  275. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  276. @pytest.mark.parametrize("mqtt_port", [1833])
  277. @pytest.mark.parametrize("mqtt_password", ["secret"])
  278. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  279. with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
  280. "systemctl_mqtt._dbus.get_login_manager"
  281. ):
  282. with pytest.raises(ValueError, match=r"^Missing MQTT username$"):
  283. systemctl_mqtt._run(
  284. mqtt_host=mqtt_host,
  285. mqtt_port=mqtt_port,
  286. mqtt_username=None,
  287. mqtt_password=mqtt_password,
  288. mqtt_topic_prefix="prefix",
  289. homeassistant_discovery_prefix="discovery-prefix",
  290. homeassistant_node_id="node-id",
  291. poweroff_delay=datetime.timedelta(),
  292. )
  293. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  294. @pytest.mark.parametrize("payload", [b"", b"junk"])
  295. def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
  296. message = MQTTMessage(topic=mqtt_topic.encode())
  297. message.payload = payload
  298. with unittest.mock.patch.object(
  299. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  300. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  301. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  302. "poweroff"
  303. ].mqtt_message_callback(
  304. None, "state_dummy", message # type: ignore
  305. )
  306. trigger_mock.assert_called_once_with(state="state_dummy")
  307. assert len(caplog.records) == 3
  308. assert caplog.records[0].levelno == logging.DEBUG
  309. assert caplog.records[0].message == (
  310. "received topic={} payload={!r}".format(mqtt_topic, payload)
  311. )
  312. assert caplog.records[1].levelno == logging.DEBUG
  313. assert caplog.records[1].message == "executing action _MQTTActionSchedulePoweroff"
  314. assert caplog.records[2].levelno == logging.DEBUG
  315. assert caplog.records[2].message == "completed action _MQTTActionSchedulePoweroff"
  316. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  317. @pytest.mark.parametrize("payload", [b"", b"junk"])
  318. def test_mqtt_message_callback_poweroff_retained(
  319. caplog, mqtt_topic: str, payload: bytes
  320. ):
  321. message = MQTTMessage(topic=mqtt_topic.encode())
  322. message.payload = payload
  323. message.retain = True
  324. with unittest.mock.patch.object(
  325. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "trigger"
  326. ) as trigger_mock, caplog.at_level(logging.DEBUG):
  327. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  328. "poweroff"
  329. ].mqtt_message_callback(
  330. None, None, message # type: ignore
  331. )
  332. trigger_mock.assert_not_called()
  333. assert len(caplog.records) == 2
  334. assert caplog.records[0].levelno == logging.DEBUG
  335. assert caplog.records[0].message == (
  336. "received topic={} payload={!r}".format(mqtt_topic, payload)
  337. )
  338. assert caplog.records[1].levelno == logging.INFO
  339. assert caplog.records[1].message == "ignoring retained message"