test_mqtt.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # systemctl-mqtt - MQTT client triggering shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import logging
  18. import threading
  19. import time
  20. import unittest.mock
  21. import paho.mqtt.client
  22. import pytest
  23. from paho.mqtt.client import MQTTMessage
  24. import systemctl_mqtt
  25. # pylint: disable=protected-access
  26. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  27. @pytest.mark.parametrize("mqtt_port", [1833])
  28. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  29. def test__run(caplog, mqtt_host, mqtt_port, mqtt_topic_prefix):
  30. caplog.set_level(logging.DEBUG)
  31. with unittest.mock.patch(
  32. "socket.create_connection"
  33. ) as create_socket_mock, unittest.mock.patch(
  34. "ssl.SSLContext.wrap_socket", autospec=True,
  35. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  36. "paho.mqtt.client.Client.loop_forever", autospec=True,
  37. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  38. "gi.repository.GLib.MainLoop.run"
  39. ) as glib_loop_mock:
  40. ssl_wrap_socket_mock.return_value.send = len
  41. systemctl_mqtt._run(
  42. mqtt_host=mqtt_host,
  43. mqtt_port=mqtt_port,
  44. mqtt_username=None,
  45. mqtt_password=None,
  46. mqtt_topic_prefix=mqtt_topic_prefix,
  47. )
  48. assert caplog.records[0].levelno == logging.INFO
  49. assert caplog.records[0].message == "connecting to MQTT broker {}:{}".format(
  50. mqtt_host, mqtt_port
  51. )
  52. # correct remote?
  53. assert create_socket_mock.call_count == 1
  54. create_socket_args, _ = create_socket_mock.call_args
  55. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  56. # ssl enabled?
  57. assert ssl_wrap_socket_mock.call_count == 1
  58. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  59. assert ssl_context.check_hostname is True
  60. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  61. # loop started?
  62. while threading.active_count() > 1:
  63. time.sleep(0.01)
  64. assert mqtt_loop_forever_mock.call_count == 1
  65. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  66. assert mqtt_client._tls_insecure is False
  67. # credentials
  68. assert mqtt_client._username is None
  69. assert mqtt_client._password is None
  70. # connect callback
  71. caplog.clear()
  72. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  73. with unittest.mock.patch(
  74. "paho.mqtt.client.Client.subscribe"
  75. ) as mqtt_subscribe_mock, unittest.mock.patch.object(
  76. mqtt_client._userdata, "acquire_shutdown_lock"
  77. ) as acquire_shutdown_lock_mock:
  78. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  79. acquire_shutdown_lock_mock.assert_called_once_with()
  80. mqtt_subscribe_mock.assert_called_once_with(mqtt_topic_prefix + "/poweroff")
  81. assert mqtt_client.on_message is None
  82. assert ( # pylint: disable=comparison-with-callable
  83. mqtt_client._on_message_filtered[mqtt_topic_prefix + "/poweroff"]
  84. == systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  85. "poweroff"
  86. ].mqtt_message_callback
  87. )
  88. assert caplog.records[0].levelno == logging.DEBUG
  89. assert caplog.records[0].message == "connected to MQTT broker {}:{}".format(
  90. mqtt_host, mqtt_port
  91. )
  92. assert caplog.records[1].levelno == logging.INFO
  93. assert caplog.records[1].message == "subscribing to {}".format(
  94. mqtt_topic_prefix + "/poweroff"
  95. )
  96. assert caplog.records[2].levelno == logging.DEBUG
  97. assert caplog.records[2].message == "registered MQTT callback for topic {}".format(
  98. mqtt_topic_prefix + "/poweroff"
  99. ) + " triggering {}".format(
  100. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"].action
  101. )
  102. # dbus loop started?
  103. glib_loop_mock.assert_called_once_with()
  104. # waited for mqtt loop to stop?
  105. assert mqtt_client._thread_terminate
  106. assert mqtt_client._thread is None
  107. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  108. @pytest.mark.parametrize("mqtt_port", [1833])
  109. @pytest.mark.parametrize("mqtt_username", ["me"])
  110. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  111. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  112. def test__run_authentication(
  113. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  114. ):
  115. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  116. "ssl.SSLContext.wrap_socket"
  117. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  118. "paho.mqtt.client.Client.loop_forever", autospec=True,
  119. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  120. "gi.repository.GLib.MainLoop.run"
  121. ):
  122. ssl_wrap_socket_mock.return_value.send = len
  123. systemctl_mqtt._run(
  124. mqtt_host=mqtt_host,
  125. mqtt_port=mqtt_port,
  126. mqtt_username=mqtt_username,
  127. mqtt_password=mqtt_password,
  128. mqtt_topic_prefix=mqtt_topic_prefix,
  129. )
  130. assert mqtt_loop_forever_mock.call_count == 1
  131. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  132. assert mqtt_client._username.decode() == mqtt_username
  133. if mqtt_password:
  134. assert mqtt_client._password.decode() == mqtt_password
  135. else:
  136. assert mqtt_client._password is None
  137. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  138. @pytest.mark.parametrize("mqtt_port", [1833])
  139. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  140. @pytest.fixture
  141. def initialized_mqtt_client(
  142. mqtt_host, mqtt_port, mqtt_topic_prefix
  143. ) -> paho.mqtt.client.Client:
  144. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  145. "ssl.SSLContext.wrap_socket",
  146. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  147. "paho.mqtt.client.Client.loop_forever", autospec=True,
  148. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  149. "gi.repository.GLib.MainLoop.run"
  150. ):
  151. ssl_wrap_socket_mock.return_value.send = len
  152. systemctl_mqtt._run(
  153. mqtt_host=mqtt_host,
  154. mqtt_port=mqtt_port,
  155. mqtt_username=None,
  156. mqtt_password=None,
  157. mqtt_topic_prefix=mqtt_topic_prefix,
  158. )
  159. while threading.active_count() > 1:
  160. time.sleep(0.01)
  161. assert mqtt_loop_forever_mock.call_count == 1
  162. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  163. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  164. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  165. return mqtt_client
  166. # pylint: disable=redefined-outer-name
  167. def test__client_handle_message(
  168. caplog, initialized_mqtt_client: paho.mqtt.client.Client
  169. ):
  170. caplog.set_level(logging.DEBUG)
  171. settings = initialized_mqtt_client._userdata # type: systemctl_mqtt._Settings
  172. poweroff_message = MQTTMessage(
  173. topic=settings.mqtt_topic_prefix.encode() + b"/poweroff"
  174. )
  175. with unittest.mock.patch.object(
  176. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
  177. ) as poweroff_action_mock:
  178. initialized_mqtt_client._handle_on_message(poweroff_message)
  179. poweroff_action_mock.assert_called_once_with()
  180. assert all(r.levelno == logging.DEBUG for r in caplog.records)
  181. assert caplog.records[0].message == "received topic={} payload=b''".format(
  182. poweroff_message.topic
  183. )
  184. assert caplog.records[1].message.startswith("executing action poweroff")
  185. assert caplog.records[2].message.startswith("completed action poweroff")
  186. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  187. @pytest.mark.parametrize("mqtt_port", [1833])
  188. @pytest.mark.parametrize("mqtt_password", ["secret"])
  189. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  190. with unittest.mock.patch("paho.mqtt.client.Client"):
  191. with pytest.raises(ValueError):
  192. systemctl_mqtt._run(
  193. mqtt_host=mqtt_host,
  194. mqtt_port=mqtt_port,
  195. mqtt_username=None,
  196. mqtt_password=mqtt_password,
  197. mqtt_topic_prefix="prefix",
  198. )
  199. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  200. @pytest.mark.parametrize("payload", [b"", b"junk"])
  201. def test_mqtt_message_callback_poweroff(caplog, mqtt_topic: str, payload: bytes):
  202. message = MQTTMessage(topic=mqtt_topic.encode())
  203. message.payload = payload
  204. with unittest.mock.patch.object(
  205. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
  206. ) as action_mock, caplog.at_level(logging.DEBUG):
  207. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  208. "poweroff"
  209. ].mqtt_message_callback(
  210. None, None, message # type: ignore
  211. )
  212. action_mock.assert_called_once_with()
  213. assert len(caplog.records) == 3
  214. assert caplog.records[0].levelno == logging.DEBUG
  215. assert caplog.records[0].message == (
  216. "received topic={} payload={!r}".format(mqtt_topic, payload)
  217. )
  218. assert caplog.records[1].levelno == logging.DEBUG
  219. assert caplog.records[1].message.startswith(
  220. "executing action {} ({!r})".format("poweroff", action_mock)
  221. )
  222. assert caplog.records[2].levelno == logging.DEBUG
  223. assert caplog.records[2].message.startswith(
  224. "completed action {} ({!r})".format("poweroff", action_mock)
  225. )
  226. @pytest.mark.parametrize("mqtt_topic", ["system/command/poweroff"])
  227. @pytest.mark.parametrize("payload", [b"", b"junk"])
  228. def test_mqtt_message_callback_poweroff_retained(
  229. caplog, mqtt_topic: str, payload: bytes
  230. ):
  231. message = MQTTMessage(topic=mqtt_topic.encode())
  232. message.payload = payload
  233. message.retain = True
  234. with unittest.mock.patch.object(
  235. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING["poweroff"], "action",
  236. ) as action_mock, caplog.at_level(logging.DEBUG):
  237. systemctl_mqtt._MQTT_TOPIC_SUFFIX_ACTION_MAPPING[
  238. "poweroff"
  239. ].mqtt_message_callback(
  240. None, None, message # type: ignore
  241. )
  242. action_mock.assert_not_called()
  243. assert len(caplog.records) == 2
  244. assert caplog.records[0].levelno == logging.DEBUG
  245. assert caplog.records[0].message == (
  246. "received topic={} payload={!r}".format(mqtt_topic, payload)
  247. )
  248. assert caplog.records[1].levelno == logging.INFO
  249. assert caplog.records[1].message == "ignoring retained message"
  250. def test_shutdown_lock():
  251. settings = systemctl_mqtt._Settings(mqtt_topic_prefix="any")
  252. lock_fd = unittest.mock.MagicMock()
  253. with unittest.mock.patch(
  254. "systemctl_mqtt._get_login_manager"
  255. ) as get_login_manager_mock:
  256. get_login_manager_mock.return_value.Inhibit.return_value = lock_fd
  257. settings.acquire_shutdown_lock()
  258. get_login_manager_mock.return_value.Inhibit.assert_called_once_with(
  259. "shutdown", "systemctl-mqtt", "Report shutdown via MQTT", "delay",
  260. )
  261. assert settings._shutdown_lock == lock_fd
  262. # https://dbus.freedesktop.org/doc/dbus-python/dbus.types.html#dbus.types.UnixFd.take
  263. lock_fd.take.return_value = "fdnum"
  264. with unittest.mock.patch("os.close") as close_mock:
  265. settings.release_shutdown_lock()
  266. close_mock.assert_called_once_with("fdnum")