test_mqtt.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # systemctl-mqtt - MQTT client triggering & reporting shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import logging
  18. import threading
  19. import time
  20. import unittest.mock
  21. import dbus
  22. import paho.mqtt.client
  23. import pytest
  24. from paho.mqtt.client import MQTTMessage
  25. import systemctl_mqtt
  26. # pylint: disable=protected-access
  27. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  28. @pytest.mark.parametrize("mqtt_port", [1833])
  29. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  30. def test__run(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  31. caplog.set_level(logging.DEBUG)
  32. with unittest.mock.patch(
  33. "socket.create_connection"
  34. ) as create_socket_mock, unittest.mock.patch(
  35. "ssl.SSLContext.wrap_socket", autospec=True,
  36. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  37. "paho.mqtt.client.Client.loop_forever", autospec=True,
  38. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  39. "gi.repository.GLib.MainLoop.run"
  40. ) as glib_loop_mock, unittest.mock.patch(
  41. "systemctl_mqtt._get_login_manager"
  42. ) as get_login_manager_mock:
  43. ssl_wrap_socket_mock.return_value.send = len
  44. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  45. systemctl_mqtt._run(
  46. mqtt_host=mqtt_host,
  47. mqtt_port=mqtt_port,
  48. mqtt_username=None,
  49. mqtt_password=None,
  50. mqtt_topic_prefix=mqtt_topic_prefix,
  51. )
  52. assert caplog.records[0].levelno == logging.INFO
  53. assert caplog.records[0].message == "connecting to MQTT broker {}:{}".format(
  54. mqtt_host, mqtt_port
  55. )
  56. # correct remote?
  57. assert create_socket_mock.call_count == 1
  58. create_socket_args, _ = create_socket_mock.call_args
  59. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  60. # ssl enabled?
  61. assert ssl_wrap_socket_mock.call_count == 1
  62. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  63. assert ssl_context.check_hostname is True
  64. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  65. # loop started?
  66. while threading.active_count() > 1:
  67. time.sleep(0.01)
  68. assert mqtt_loop_forever_mock.call_count == 1
  69. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  70. assert mqtt_client._tls_insecure is False
  71. # credentials
  72. assert mqtt_client._username is None
  73. assert mqtt_client._password is None
  74. # connect callback
  75. caplog.clear()
  76. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  77. with unittest.mock.patch(
  78. "paho.mqtt.client.Client.subscribe"
  79. ) as mqtt_subscribe_mock:
  80. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  81. state = mqtt_client._userdata
  82. assert (
  83. state._login_manager.connect_to_signal.call_args[1]["signal_name"]
  84. == "PrepareForShutdown"
  85. )
  86. mqtt_subscribe_mock.assert_called_once_with(mqtt_topic_prefix + "/poweroff")
  87. assert mqtt_client.on_message is None
  88. assert ( # pylint: disable=comparison-with-callable
  89. mqtt_client._on_message_filtered[mqtt_topic_prefix + "/poweroff"]
  90. == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  91. "poweroff"
  92. ].mqtt_message_callback
  93. )
  94. assert caplog.records[0].levelno == logging.DEBUG
  95. assert caplog.records[0].message == "connected to MQTT broker {}:{}".format(
  96. mqtt_host, mqtt_port
  97. )
  98. assert caplog.records[1].levelno == logging.DEBUG
  99. assert caplog.records[1].message == "acquired shutdown inhibitor lock"
  100. assert caplog.records[2].levelno == logging.INFO
  101. assert caplog.records[2].message == "publishing 'false' on {}".format(
  102. mqtt_topic_prefix + "/preparing-for-shutdown"
  103. )
  104. assert caplog.records[3].levelno == logging.INFO
  105. assert caplog.records[3].message == "subscribing to {}".format(
  106. mqtt_topic_prefix + "/poweroff"
  107. )
  108. assert caplog.records[4].levelno == logging.DEBUG
  109. assert caplog.records[4].message == "registered MQTT callback for topic {}".format(
  110. mqtt_topic_prefix + "/poweroff"
  111. ) + " triggering {}".format(
  112. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"].action
  113. )
  114. # dbus loop started?
  115. glib_loop_mock.assert_called_once_with()
  116. # waited for mqtt loop to stop?
  117. assert mqtt_client._thread_terminate
  118. assert mqtt_client._thread is None
  119. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  120. @pytest.mark.parametrize("mqtt_port", [1833])
  121. @pytest.mark.parametrize("mqtt_username", ["me"])
  122. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  123. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  124. def test__run_authentication(
  125. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  126. ):
  127. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  128. "ssl.SSLContext.wrap_socket"
  129. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  130. "paho.mqtt.client.Client.loop_forever", autospec=True,
  131. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  132. "gi.repository.GLib.MainLoop.run"
  133. ), unittest.mock.patch(
  134. "systemctl_mqtt._get_login_manager"
  135. ):
  136. ssl_wrap_socket_mock.return_value.send = len
  137. systemctl_mqtt._run(
  138. mqtt_host=mqtt_host,
  139. mqtt_port=mqtt_port,
  140. mqtt_username=mqtt_username,
  141. mqtt_password=mqtt_password,
  142. mqtt_topic_prefix=mqtt_topic_prefix,
  143. )
  144. assert mqtt_loop_forever_mock.call_count == 1
  145. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  146. assert mqtt_client._username.decode() == mqtt_username
  147. if mqtt_password:
  148. assert mqtt_client._password.decode() == mqtt_password
  149. else:
  150. assert mqtt_client._password is None
  151. def _initialize_mqtt_client(
  152. mqtt_host, mqtt_port, mqtt_topic_prefix
  153. ) -> paho.mqtt.client.Client:
  154. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  155. "ssl.SSLContext.wrap_socket",
  156. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  157. "paho.mqtt.client.Client.loop_forever", autospec=True,
  158. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  159. "gi.repository.GLib.MainLoop.run"
  160. ), unittest.mock.patch(
  161. "systemctl_mqtt._get_login_manager"
  162. ) as get_login_manager_mock:
  163. ssl_wrap_socket_mock.return_value.send = len
  164. get_login_manager_mock.return_value.Get.return_value = dbus.Boolean(False)
  165. systemctl_mqtt._run(
  166. mqtt_host=mqtt_host,
  167. mqtt_port=mqtt_port,
  168. mqtt_username=None,
  169. mqtt_password=None,
  170. mqtt_topic_prefix=mqtt_topic_prefix,
  171. )
  172. while threading.active_count() > 1:
  173. time.sleep(0.01)
  174. assert mqtt_loop_forever_mock.call_count == 1
  175. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  176. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  177. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  178. return mqtt_client
  179. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  180. @pytest.mark.parametrize("mqtt_port", [1833])
  181. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  182. def test__client_handle_message(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  183. mqtt_client = _initialize_mqtt_client(
  184. mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_topic_prefix=mqtt_topic_prefix
  185. )
  186. caplog.clear()
  187. caplog.set_level(logging.DEBUG)
  188. poweroff_message = MQTTMessage(topic=mqtt_topic_prefix.encode() + b"/poweroff")
  189. with unittest.mock.patch.object(
  190. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
  191. ) as poweroff_action_mock:
  192. mqtt_client._handle_on_message(poweroff_message)
  193. poweroff_action_mock.assert_called_once_with()
  194. assert all(r.levelno == logging.DEBUG for r in caplog.records)
  195. assert caplog.records[0].message == "received topic={} payload=b''".format(
  196. poweroff_message.topic
  197. )
  198. assert caplog.records[1].message.startswith("executing action poweroff")
  199. assert caplog.records[2].message.startswith("completed action poweroff")
  200. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  201. @pytest.mark.parametrize("mqtt_port", [1833])
  202. @pytest.mark.parametrize("mqtt_password", ["secret"])
  203. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  204. with unittest.mock.patch("paho.mqtt.client.Client"), unittest.mock.patch(
  205. "systemctl_mqtt._get_login_manager"
  206. ):
  207. with pytest.raises(ValueError):
  208. systemctl_mqtt._run(
  209. mqtt_host=mqtt_host,
  210. mqtt_port=mqtt_port,
  211. mqtt_username=None,
  212. mqtt_password=mqtt_password,
  213. mqtt_topic_prefix="prefix",
  214. )
  215. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  216. @pytest.mark.parametrize("payload", [b"", b"junk"])
  217. def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
  218. message = MQTTMessage(topic=mqtt_topic.encode())
  219. message.payload = payload
  220. with unittest.mock.patch.object(
  221. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
  222. ) as action_mock, caplog.at_level(logging.DEBUG):
  223. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  224. "poweroff"
  225. ].mqtt_message_callback(
  226. None, None, message # type: ignore
  227. )
  228. action_mock.assert_called_once_with()
  229. assert len(caplog.records) == 3
  230. assert caplog.records[0].levelno == logging.DEBUG
  231. assert caplog.records[0].message == (
  232. "received topic={} payload={!r}".format(mqtt_topic, payload)
  233. )
  234. assert caplog.records[1].levelno == logging.DEBUG
  235. assert caplog.records[1].message.startswith(
  236. "executing action {} ({!r})".format("poweroff", action_mock)
  237. )
  238. assert caplog.records[2].levelno == logging.DEBUG
  239. assert caplog.records[2].message.startswith(
  240. "completed action {} ({!r})".format("poweroff", action_mock)
  241. )
  242. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  243. @pytest.mark.parametrize("payload", [b"", b"junk"])
  244. def test_mqtt_message_callback_poweroff_retained(
  245. caplog, mqtt_topic: str, payload: bytes
  246. ):
  247. message = MQTTMessage(topic=mqtt_topic.encode())
  248. message.payload = payload
  249. message.retain = True
  250. with unittest.mock.patch.object(
  251. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
  252. ) as action_mock, caplog.at_level(logging.DEBUG):
  253. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  254. "poweroff"
  255. ].mqtt_message_callback(
  256. None, None, message # type: ignore
  257. )
  258. action_mock.assert_not_called()
  259. assert len(caplog.records) == 2
  260. assert caplog.records[0].levelno == logging.DEBUG
  261. assert caplog.records[0].message == (
  262. "received topic={} payload={!r}".format(mqtt_topic, payload)
  263. )
  264. assert caplog.records[1].levelno == logging.INFO
  265. assert caplog.records[1].message == "ignoring retained message"