test_mqtt.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # systemctl-mqtt - MQTT client triggering shutdown on systemd-based systems
  2. #
  3. # Copyright (C) 2020 Fabian Peter Hammerle <fabian@hammerle.me>
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. import logging
  18. import unittest.mock
  19. import pytest
  20. from paho.mqtt.client import MQTTMessage
  21. import systemctl_mqtt
  22. # pylint: disable=protected-access
  23. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  24. @pytest.mark.parametrize("mqtt_port", [1833])
  25. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  26. def test__run(mqtt_host, mqtt_port, mqtt_topic_prefix):
  27. with unittest.mock.patch(
  28. "socket.create_connection"
  29. ) as create_socket_mock, unittest.mock.patch(
  30. "ssl.SSLContext.wrap_socket", autospec=True,
  31. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  32. "paho.mqtt.client.Client.loop_forever", autospec=True,
  33. ) as mqtt_loop_forever_mock, unittest.mock.patch(
  34. "systemctl_mqtt._mqtt_on_message"
  35. ) as message_handler_mock:
  36. ssl_wrap_socket_mock.return_value.send = len
  37. systemctl_mqtt._run(
  38. mqtt_host=mqtt_host,
  39. mqtt_port=mqtt_port,
  40. mqtt_username=None,
  41. mqtt_password=None,
  42. mqtt_topic_prefix=mqtt_topic_prefix,
  43. )
  44. # correct remote?
  45. assert create_socket_mock.call_count == 1
  46. create_socket_args, _ = create_socket_mock.call_args
  47. assert create_socket_args[0] == (mqtt_host, mqtt_port)
  48. # ssl enabled?
  49. assert ssl_wrap_socket_mock.call_count == 1
  50. ssl_context = ssl_wrap_socket_mock.call_args[0][0] # self
  51. assert ssl_context.check_hostname is True
  52. assert ssl_wrap_socket_mock.call_args[1]["server_hostname"] == mqtt_host
  53. # loop started?
  54. assert mqtt_loop_forever_mock.call_count == 1
  55. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  56. assert mqtt_client._tls_insecure is False
  57. # credentials
  58. assert mqtt_client._username is None
  59. assert mqtt_client._password is None
  60. # connect callback
  61. mqtt_client.socket().getpeername.return_value = (mqtt_host, mqtt_port)
  62. with unittest.mock.patch(
  63. "paho.mqtt.client.Client.subscribe"
  64. ) as mqtt_subscribe_mock:
  65. mqtt_client.on_connect(mqtt_client, mqtt_client._userdata, {}, 0)
  66. mqtt_subscribe_mock.assert_called_once_with(mqtt_topic_prefix + "/poweroff")
  67. # message callback
  68. test_message = MQTTMessage(topic=b"test")
  69. message_handler_mock.assert_not_called()
  70. mqtt_client._handle_on_message(test_message)
  71. message_handler_mock.assert_called_once_with(
  72. mqtt_client, mqtt_client._userdata, test_message
  73. )
  74. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  75. @pytest.mark.parametrize("mqtt_port", [1833])
  76. @pytest.mark.parametrize("mqtt_username", ["me"])
  77. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  78. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  79. def test__run_authentication(
  80. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  81. ):
  82. with unittest.mock.patch("socket.create_connection"), unittest.mock.patch(
  83. "ssl.SSLContext.wrap_socket"
  84. ) as ssl_wrap_socket_mock, unittest.mock.patch(
  85. "paho.mqtt.client.Client.loop_forever", autospec=True,
  86. ) as mqtt_loop_forever_mock:
  87. ssl_wrap_socket_mock.return_value.send = len
  88. systemctl_mqtt._run(
  89. mqtt_host=mqtt_host,
  90. mqtt_port=mqtt_port,
  91. mqtt_username=mqtt_username,
  92. mqtt_password=mqtt_password,
  93. mqtt_topic_prefix=mqtt_topic_prefix,
  94. )
  95. assert mqtt_loop_forever_mock.call_count == 1
  96. (mqtt_client,) = mqtt_loop_forever_mock.call_args[0]
  97. assert mqtt_client._username.decode() == mqtt_username
  98. if mqtt_password:
  99. assert mqtt_client._password.decode() == mqtt_password
  100. else:
  101. assert mqtt_client._password is None
  102. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  103. @pytest.mark.parametrize("mqtt_port", [1833])
  104. @pytest.mark.parametrize("mqtt_password", ["secret"])
  105. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  106. with unittest.mock.patch("paho.mqtt.client.Client"):
  107. with pytest.raises(ValueError):
  108. systemctl_mqtt._run(
  109. mqtt_host=mqtt_host,
  110. mqtt_port=mqtt_port,
  111. mqtt_username=None,
  112. mqtt_password=mqtt_password,
  113. mqtt_topic_prefix="prefix",
  114. )
  115. @pytest.mark.parametrize("mqtt_topic_prefix", ["system/command"])
  116. @pytest.mark.parametrize("payload", [b"", b"junk"])
  117. def test__mqtt_on_message_poweroff(caplog, mqtt_topic_prefix: str, payload: bytes):
  118. mqtt_topic = mqtt_topic_prefix + "/poweroff"
  119. message = MQTTMessage(topic=mqtt_topic.encode())
  120. message.payload = payload
  121. settings = systemctl_mqtt._Settings(mqtt_topic_prefix=mqtt_topic_prefix)
  122. action_mock = unittest.mock.MagicMock()
  123. settings.mqtt_topic_action_mapping[mqtt_topic] = action_mock # functools.partial
  124. with caplog.at_level(logging.DEBUG):
  125. systemctl_mqtt._mqtt_on_message(
  126. None, settings, message,
  127. )
  128. assert len(caplog.records) == 3
  129. assert caplog.records[0].levelno == logging.DEBUG
  130. assert caplog.records[0].message == (
  131. "received topic={} payload={!r}".format(mqtt_topic, payload)
  132. )
  133. assert caplog.records[1].levelno == logging.DEBUG
  134. assert caplog.records[1].message.startswith(
  135. "executing action {!r}".format(action_mock)
  136. )
  137. assert caplog.records[2].levelno == logging.DEBUG
  138. assert caplog.records[2].message.startswith(
  139. "completed action {!r}".format(action_mock)
  140. )
  141. action_mock.assert_called_once_with()
  142. @pytest.mark.parametrize(
  143. ("topic", "payload"), [("system/poweroff", b""), ("system/poweroff", "payload"),],
  144. )
  145. def test__mqtt_on_message_ignored(
  146. caplog, topic: str, payload: bytes,
  147. ):
  148. message = MQTTMessage(topic=topic.encode())
  149. message.payload = payload
  150. settings = systemctl_mqtt._Settings(mqtt_topic_prefix="system/command")
  151. settings.mqtt_topic_action_mapping = {} # provoke KeyError on access
  152. with caplog.at_level(logging.DEBUG):
  153. systemctl_mqtt._mqtt_on_message(
  154. None, settings, message,
  155. )
  156. assert len(caplog.records) == 2
  157. assert caplog.records[0].levelno == logging.DEBUG
  158. assert caplog.records[0].message == (
  159. "received topic={} payload={!r}".format(topic, payload)
  160. )
  161. assert caplog.records[1].levelno == logging.WARNING
  162. assert caplog.records[1].message == "unexpected topic {}".format(topic)
  163. @pytest.mark.parametrize(
  164. ("topic", "payload"), [("system/command/poweroff", b"")],
  165. )
  166. def test__mqtt_on_message_ignored_retained(
  167. caplog, topic: str, payload: bytes,
  168. ):
  169. message = MQTTMessage(topic=topic.encode())
  170. message.payload = payload
  171. message.retain = True
  172. settings = systemctl_mqtt._Settings(mqtt_topic_prefix="system/command")
  173. settings.mqtt_topic_action_mapping = {} # provoke KeyError on access
  174. with caplog.at_level(logging.DEBUG):
  175. systemctl_mqtt._mqtt_on_message(
  176. None, settings, message,
  177. )
  178. assert len(caplog.records) == 2
  179. assert caplog.records[0].levelno == logging.DEBUG
  180. assert caplog.records[0].message == (
  181. "received topic={} payload={!r}".format(topic, payload)
  182. )
  183. assert caplog.records[1].levelno == logging.INFO
  184. assert caplog.records[1].message == "ignoring retained message"