test_mqtt.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import logging
  2. import unittest.mock
  3. import pytest
  4. from paho.mqtt.client import MQTTMessage
  5. import systemctl_mqtt
  6. # pylint: disable=protected-access
  7. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  8. @pytest.mark.parametrize("mqtt_port", [1833])
  9. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host", "system/command"])
  10. def test__run(mqtt_host, mqtt_port, mqtt_topic_prefix):
  11. with unittest.mock.patch(
  12. "paho.mqtt.client.Client"
  13. ) as mqtt_client_mock, unittest.mock.patch(
  14. "systemctl_mqtt._mqtt_on_message"
  15. ) as message_handler_mock:
  16. systemctl_mqtt._run(
  17. mqtt_host=mqtt_host,
  18. mqtt_port=mqtt_port,
  19. mqtt_username=None,
  20. mqtt_password=None,
  21. mqtt_topic_prefix=mqtt_topic_prefix,
  22. )
  23. mqtt_client_mock.assert_called_once()
  24. init_args, init_kwargs = mqtt_client_mock.call_args
  25. assert not init_args
  26. assert len(init_kwargs) == 1
  27. settings = init_kwargs["userdata"]
  28. assert isinstance(settings, systemctl_mqtt._Settings)
  29. assert mqtt_topic_prefix + "/poweroff" in settings.mqtt_topic_action_mapping
  30. assert not mqtt_client_mock().username_pw_set.called
  31. mqtt_client_mock().connect.assert_called_once_with(host=mqtt_host, port=mqtt_port)
  32. mqtt_client_mock().socket().getpeername.return_value = (mqtt_host, mqtt_port)
  33. mqtt_client_mock().on_connect(mqtt_client_mock(), settings, {}, 0)
  34. mqtt_client_mock().subscribe.assert_called_once_with(
  35. mqtt_topic_prefix + "/poweroff"
  36. )
  37. mqtt_client_mock().on_message(mqtt_client_mock(), settings, "message")
  38. message_handler_mock.assert_called_once()
  39. mqtt_client_mock().loop_forever.assert_called_once_with()
  40. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  41. @pytest.mark.parametrize("mqtt_port", [1833])
  42. @pytest.mark.parametrize("mqtt_username", ["me"])
  43. @pytest.mark.parametrize("mqtt_password", [None, "secret"])
  44. @pytest.mark.parametrize("mqtt_topic_prefix", ["systemctl/host"])
  45. def test__run_authentication(
  46. mqtt_host, mqtt_port, mqtt_username, mqtt_password, mqtt_topic_prefix
  47. ):
  48. with unittest.mock.patch("paho.mqtt.client.Client") as mqtt_client_mock:
  49. systemctl_mqtt._run(
  50. mqtt_host=mqtt_host,
  51. mqtt_port=mqtt_port,
  52. mqtt_username=mqtt_username,
  53. mqtt_password=mqtt_password,
  54. mqtt_topic_prefix=mqtt_topic_prefix,
  55. )
  56. mqtt_client_mock.assert_called_once()
  57. init_args, init_kwargs = mqtt_client_mock.call_args
  58. assert not init_args
  59. assert set(init_kwargs.keys()) == {"userdata"}
  60. mqtt_client_mock().username_pw_set.assert_called_once_with(
  61. username=mqtt_username, password=mqtt_password,
  62. )
  63. @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
  64. @pytest.mark.parametrize("mqtt_port", [1833])
  65. @pytest.mark.parametrize("mqtt_password", ["secret"])
  66. def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
  67. with unittest.mock.patch("paho.mqtt.client.Client"):
  68. with pytest.raises(ValueError):
  69. systemctl_mqtt._run(
  70. mqtt_host=mqtt_host,
  71. mqtt_port=mqtt_port,
  72. mqtt_username=None,
  73. mqtt_password=mqtt_password,
  74. mqtt_topic_prefix="prefix",
  75. )
  76. @pytest.mark.parametrize("mqtt_topic_prefix", ["system/command"])
  77. @pytest.mark.parametrize("payload", [b"", b"junk"])
  78. def test__mqtt_on_message_poweroff(caplog, mqtt_topic_prefix: str, payload: bytes):
  79. mqtt_topic = mqtt_topic_prefix + "/poweroff"
  80. message = MQTTMessage(topic=mqtt_topic.encode())
  81. message.payload = payload
  82. settings = systemctl_mqtt._Settings(mqtt_topic_prefix=mqtt_topic_prefix)
  83. action_mock = unittest.mock.MagicMock()
  84. settings.mqtt_topic_action_mapping[mqtt_topic] = action_mock # functools.partial
  85. with caplog.at_level(logging.DEBUG):
  86. systemctl_mqtt._mqtt_on_message(
  87. None, settings, message,
  88. )
  89. assert len(caplog.records) == 1
  90. assert caplog.records[0].levelno == logging.DEBUG
  91. assert caplog.records[0].message == (
  92. "received topic={} payload={!r}".format(mqtt_topic, payload)
  93. )
  94. action_mock.assert_called_once_with()
  95. @pytest.mark.parametrize(
  96. ("topic", "payload"), [("system/poweroff", b""), ("system/poweroff", "payload"),],
  97. )
  98. def test__mqtt_on_message_ignored(
  99. caplog, topic: str, payload: bytes,
  100. ):
  101. message = MQTTMessage(topic=topic.encode())
  102. message.payload = payload
  103. settings = systemctl_mqtt._Settings(mqtt_topic_prefix="system/command")
  104. settings.mqtt_topic_action_mapping = {} # provoke KeyError on access
  105. with caplog.at_level(logging.DEBUG):
  106. systemctl_mqtt._mqtt_on_message(
  107. None, settings, message,
  108. )
  109. assert len(caplog.records) == 2
  110. assert caplog.records[0].levelno == logging.DEBUG
  111. assert caplog.records[0].message == (
  112. "received topic={} payload={!r}".format(topic, payload)
  113. )
  114. assert caplog.records[1].levelno == logging.WARNING
  115. assert caplog.records[1].message == "unexpected topic {}".format(topic)
  116. @pytest.mark.parametrize(
  117. ("topic", "payload"), [("system/command/poweroff", b"")],
  118. )
  119. def test__mqtt_on_message_ignored_retained(
  120. caplog, topic: str, payload: bytes,
  121. ):
  122. message = MQTTMessage(topic=topic.encode())
  123. message.payload = payload
  124. message.retain = True
  125. settings = systemctl_mqtt._Settings(mqtt_topic_prefix="system/command")
  126. settings.mqtt_topic_action_mapping = {} # provoke KeyError on access
  127. with caplog.at_level(logging.DEBUG):
  128. systemctl_mqtt._mqtt_on_message(
  129. None, settings, message,
  130. )
  131. assert len(caplog.records) == 2
  132. assert caplog.records[0].levelno == logging.DEBUG
  133. assert caplog.records[0].message == (
  134. "received topic={} payload={!r}".format(topic, payload)
  135. )
  136. assert caplog.records[1].levelno == logging.INFO
  137. assert caplog.records[1].message == "ignoring retained message"