瀏覽代碼

validate inputs received via mqtt

Fabian Peter Hammerle 4 年之前
父節點
當前提交
9459de8a2b
共有 2 個文件被更改,包括 236 次插入10 次删除
  1. 34 10
      intertechno_cc1101_mqtt/__init__.py
  2. 202 0
      tests/test_mqtt.py

+ 34 - 10
intertechno_cc1101_mqtt/__init__.py

@@ -21,18 +21,42 @@ def _mqtt_on_message(
         _LOGGER.warning("ignoring retained message")
         return
     topic_split = message.topic.split("/")
-    address = int(topic_split[1])
-    button_index = int(topic_split[2])
-    remote_control = intertechno_cc1101.RemoteControl(address=address)
-    # https://www.home-assistant.io/integrations/switch.mqtt/#payload_on
-    if message.payload.upper() == b"ON":
-        remote_control.turn_on(button_index=button_index)
-    elif message.payload.upper() == b"OFF":
-        remote_control.turn_off(button_index=button_index)
-    else:
+    try:
+        address = int(topic_split[1])
+    except ValueError:
+        _LOGGER.warning(
+            "failed to parse address %r, expected integer; ignoring message",
+            topic_split[1],
+        )
+        return
+    try:
+        button_index = int(topic_split[2])
+    except ValueError:
         _LOGGER.warning(
-            "unexpected payload %r; expected 'ON' or 'OFF'", message.payload
+            "failed to parse button index %r, expected integer; ignoring message",
+            topic_split[2],
         )
+        return
+    try:
+        remote_control = intertechno_cc1101.RemoteControl(address=address)
+    except AssertionError:
+        _LOGGER.warning(
+            "failed to initialize remote control, invalid address? ignoring message",
+            exc_info=True,
+        )
+        return
+    # https://www.home-assistant.io/integrations/switch.mqtt/#payload_on
+    try:
+        if message.payload.upper() == b"ON":
+            remote_control.turn_on(button_index=button_index)
+        elif message.payload.upper() == b"OFF":
+            remote_control.turn_off(button_index=button_index)
+        else:
+            _LOGGER.warning(
+                "unexpected payload %r; expected 'ON' or 'OFF'", message.payload
+            )
+    except Exception:  # pylint: disable=broad-except; invalid perms? spi error? invalid button index?
+        _LOGGER.error("failed to send signal", exc_info=True)
 
 
 def _mqtt_on_connect(

+ 202 - 0
tests/test_mqtt.py

@@ -0,0 +1,202 @@
+import logging
+import unittest.mock
+
+import pytest
+from paho.mqtt.client import MQTT_ERR_QUEUE_SIZE, MQTT_ERR_SUCCESS, MQTTMessage, Client
+
+import intertechno_cc1101_mqtt
+
+# pylint: disable=protected-access
+
+
+@pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
+@pytest.mark.parametrize("mqtt_port", [1833])
+def test__run(caplog, mqtt_host, mqtt_port):
+    with unittest.mock.patch(
+        "paho.mqtt.client.Client"
+    ) as mqtt_client_mock, caplog.at_level(logging.DEBUG):
+        intertechno_cc1101_mqtt._run(
+            mqtt_host=mqtt_host,
+            mqtt_port=mqtt_port,
+            mqtt_username=None,
+            mqtt_password=None,
+        )
+    mqtt_client_mock.assert_called_once_with()
+    assert not mqtt_client_mock().username_pw_set.called
+    mqtt_client_mock().connect.assert_called_once_with(host=mqtt_host, port=mqtt_port)
+    mqtt_client_mock().socket().getpeername.return_value = (mqtt_host, mqtt_port)
+    with caplog.at_level(logging.DEBUG):
+        mqtt_client_mock().on_connect(mqtt_client_mock(), None, {}, 0)
+    # pylint: disable=comparison-with-callable
+    assert mqtt_client_mock().on_message == intertechno_cc1101_mqtt._mqtt_on_message
+    mqtt_client_mock().subscribe.assert_called_once_with("intertechno-cc1101/+/+/set")
+    mqtt_client_mock().loop_forever.assert_called_once_with()
+    assert caplog.record_tuples == [
+        (
+            "intertechno_cc1101_mqtt",
+            logging.INFO,
+            "connecting to MQTT broker {}:{}".format(mqtt_host, mqtt_port),
+        ),
+        (
+            "intertechno_cc1101_mqtt",
+            logging.DEBUG,
+            "connected to MQTT broker {}:{}".format(mqtt_host, mqtt_port),
+        ),
+        (
+            "intertechno_cc1101_mqtt",
+            logging.INFO,
+            "subscribing to MQTT topic 'intertechno-cc1101/+/+/set'",
+        ),
+    ]
+
+
+@pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
+@pytest.mark.parametrize("mqtt_port", [1833])
+@pytest.mark.parametrize("mqtt_username", ["me"])
+@pytest.mark.parametrize("mqtt_password", [None, "secret"])
+def test__run_authentication(mqtt_host, mqtt_port, mqtt_username, mqtt_password):
+    with unittest.mock.patch("paho.mqtt.client.Client") as mqtt_client_mock:
+        intertechno_cc1101_mqtt._run(
+            mqtt_host=mqtt_host,
+            mqtt_port=mqtt_port,
+            mqtt_username=mqtt_username,
+            mqtt_password=mqtt_password,
+        )
+    mqtt_client_mock.assert_called_once_with()
+    mqtt_client_mock().username_pw_set.assert_called_once_with(
+        username=mqtt_username, password=mqtt_password
+    )
+
+
+@pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
+@pytest.mark.parametrize("mqtt_port", [1833])
+@pytest.mark.parametrize("mqtt_password", ["secret"])
+def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_password):
+    with unittest.mock.patch("paho.mqtt.client.Client"):
+        with pytest.raises(ValueError):
+            intertechno_cc1101_mqtt._run(
+                mqtt_host=mqtt_host,
+                mqtt_port=mqtt_port,
+                mqtt_username=None,
+                mqtt_password=mqtt_password,
+            )
+
+
+@pytest.mark.parametrize(
+    ("topic", "address"),
+    (
+        (b"intertechno-cc1101/12345678/0/set", 12345678),
+        (b"intertechno-cc1101/1234/0/set", 1234),
+    ),
+)
+def test__mqtt_on_message_address(topic, address):
+    message = MQTTMessage(topic=topic)
+    message.payload = b"ON"
+    with unittest.mock.patch("intertechno_cc1101.RemoteControl") as remote_control_mock:
+        intertechno_cc1101_mqtt._mqtt_on_message("dummy", None, message)
+    remote_control_mock.assert_called_once_with(address=address)
+
+
+@pytest.mark.parametrize(
+    ("topic", "address_str"),
+    (
+        (b"intertechno-cc1101/abcdef/0/set", "abcdef"),
+        (b"intertechno-cc1101//0/set", ""),
+    ),
+)
+def test__mqtt_on_message_invalid_address(caplog, topic, address_str):
+    message = MQTTMessage(topic=topic)
+    message.payload = b"ON"
+    with unittest.mock.patch("intertechno_cc1101.RemoteControl") as remote_control_mock:
+        with caplog.at_level(logging.WARNING):
+            intertechno_cc1101_mqtt._mqtt_on_message("dummy", None, message)
+    remote_control_mock.assert_not_called()
+    assert caplog.record_tuples == [
+        (
+            "intertechno_cc1101_mqtt",
+            logging.WARNING,
+            "failed to parse address {!r}, expected integer; ignoring message".format(
+                address_str
+            ),
+        )
+    ]
+
+
+@pytest.mark.parametrize(
+    ("topic", "button_index"),
+    ((b"intertechno-cc1101/12345678/0/set", 0), (b"intertechno-cc1101/1234/7/set", 7)),
+)
+@pytest.mark.parametrize(
+    ("payload", "turn_on"),
+    ((b"ON", True), (b"On", True), (b"on", True), (b"OFF", False), (b"off", False)),
+)
+def test__mqtt_on_message_button_index_action(topic, button_index, payload, turn_on):
+    message = MQTTMessage(topic=topic)
+    message.payload = payload
+    with unittest.mock.patch("intertechno_cc1101.RemoteControl") as remote_control_mock:
+        intertechno_cc1101_mqtt._mqtt_on_message("dummy", None, message)
+    if turn_on:
+        remote_control_mock().turn_on.assert_called_once_with(button_index=button_index)
+        remote_control_mock().turn_off.assert_not_called()
+    else:
+        remote_control_mock().turn_off.assert_called_once_with(
+            button_index=button_index
+        )
+        remote_control_mock().turn_on.assert_not_called()
+
+
+@pytest.mark.parametrize(
+    ("topic", "button_index_str"),
+    (
+        (b"intertechno-cc1101/12345678/abc/set", "abc"),
+        (b"intertechno-cc1101/12345678//set", ""),
+    ),
+)
+def test__mqtt_on_message_invalid_button_index(caplog, topic, button_index_str):
+    message = MQTTMessage(topic=topic)
+    message.payload = b"ON"
+    with unittest.mock.patch("intertechno_cc1101.RemoteControl") as remote_control_mock:
+        with caplog.at_level(logging.WARNING):
+            intertechno_cc1101_mqtt._mqtt_on_message("dummy", None, message)
+    remote_control_mock().turn_on.assert_not_called()
+    remote_control_mock().turn_off.assert_not_called()
+    assert caplog.record_tuples == [
+        (
+            "intertechno_cc1101_mqtt",
+            logging.WARNING,
+            "failed to parse button index {!r}, expected integer; ignoring message".format(
+                button_index_str
+            ),
+        )
+    ]
+
+
+@pytest.mark.parametrize(
+    "topic", (b"intertechno-cc1101/123456789/0/set", b"intertechno-cc1101/-21/0/set")
+)
+def test__mqtt_on_message_remote_init_failed(caplog, topic):
+    message = MQTTMessage(topic=topic)
+    message.payload = b"ON"
+    with caplog.at_level(logging.WARNING):
+        intertechno_cc1101_mqtt._mqtt_on_message("dummy", None, message)
+    assert len(caplog.records) == 1
+    assert caplog.records[0].levelno == logging.WARNING
+    assert (
+        caplog.records[0].message
+        == "failed to initialize remote control, invalid address? ignoring message"
+    )
+    assert isinstance(caplog.records[0].exc_info[1], AssertionError)
+
+
+def test__mqtt_on_message_transmission_failed(caplog):
+    message = MQTTMessage(topic=b"intertechno-cc1101/12345678/3/set")
+    message.payload = b"ON"
+    with unittest.mock.patch(
+        "cc1101.CC1101.__enter__",
+        side_effect=FileNotFoundError("[Errno 2] No such file or directory"),
+    ), caplog.at_level(logging.ERROR):
+        intertechno_cc1101_mqtt._mqtt_on_message("dummy", None, message)
+    assert len(caplog.records) == 1
+    assert caplog.records[0].levelno == logging.ERROR
+    assert caplog.records[0].message == "failed to send signal"
+    assert isinstance(caplog.records[0].exc_info[1], FileNotFoundError)