Browse Source

added command-line parameter `--retries` to adjust maximum number of attempts to send command

https://github.com/fphammerle/switchbot-mqtt/pull/40
Fabian Peter Hammerle 2 years ago
parent
commit
cefa755ad0

+ 4 - 0
CHANGELOG.md

@@ -5,6 +5,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+### Added
+- Command-line argument `--retries` to alter maximum number of attempts to send a command
+  to a SwitchBot device (default unchanged)
+
 ### Fixed
 - dockerfile: split `pipenv install` into two stages to speed up image builds
 - dockerfile: `chmod` files copied from host to no longer require `o=rX` perms on host

+ 31 - 11
switchbot_mqtt/__init__.py

@@ -18,6 +18,7 @@
 
 import abc
 import argparse
+import collections
 import enum
 import logging
 import pathlib
@@ -45,11 +46,15 @@ def _mac_address_valid(mac_address: str) -> bool:
     return _MAC_ADDRESS_REGEX.match(mac_address.lower()) is not None
 
 
+_MQTTCallbackUserdata = collections.namedtuple("_MQTTCallbackUserdata", ["retry_count"])
+
+
 class _MQTTControlledActor(abc.ABC):
     MQTT_COMMAND_TOPIC_LEVELS = NotImplemented  # type: typing.List[_MQTTTopicLevel]
     MQTT_STATE_TOPIC_LEVELS = NotImplemented  # type: typing.List[_MQTTTopicLevel]
 
-    def __init__(self, mac_address: str) -> None:
+    @abc.abstractmethod
+    def __init__(self, mac_address: str, retry_count: int) -> None:
         self._mac_address = mac_address
 
     @abc.abstractmethod
@@ -62,7 +67,7 @@ class _MQTTControlledActor(abc.ABC):
     def _mqtt_command_callback(
         cls,
         mqtt_client: paho.mqtt.client.Client,
-        userdata: None,
+        userdata: _MQTTCallbackUserdata,
         message: paho.mqtt.client.MQTTMessage,
     ) -> None:
         # pylint: disable=unused-argument; callback
@@ -88,7 +93,7 @@ class _MQTTControlledActor(abc.ABC):
         if not _mac_address_valid(mac_address):
             _LOGGER.warning("invalid mac address %s", mac_address)
             return
-        cls(mac_address=mac_address).execute_command(
+        cls(mac_address=mac_address, retry_count=userdata.retry_count).execute_command(
             mqtt_message_payload=message.payload, mqtt_client=mqtt_client
         )
 
@@ -101,7 +106,8 @@ class _MQTTControlledActor(abc.ABC):
         _LOGGER.info("subscribing to MQTT topic %r", command_topic)
         mqtt_client.subscribe(command_topic)
         mqtt_client.message_callback_add(
-            sub=command_topic, callback=cls._mqtt_command_callback
+            sub=command_topic,
+            callback=cls._mqtt_command_callback,
         )
 
     def _mqtt_publish(
@@ -154,9 +160,9 @@ class _ButtonAutomator(_MQTTControlledActor):
         "state",
     ]
 
-    def __init__(self, mac_address) -> None:
-        self._device = switchbot.Switchbot(mac=mac_address)
-        super().__init__(mac_address=mac_address)
+    def __init__(self, mac_address: str, retry_count: int) -> None:
+        self._device = switchbot.Switchbot(mac=mac_address, retry_count=retry_count)
+        super().__init__(mac_address=mac_address, retry_count=retry_count)
 
     def execute_command(
         self, mqtt_message_payload: bytes, mqtt_client: paho.mqtt.client.Client
@@ -199,9 +205,11 @@ class _CurtainMotor(_MQTTControlledActor):
         "state",
     ]
 
-    def __init__(self, mac_address) -> None:
-        self._device = switchbot.SwitchbotCurtain(mac=mac_address)
-        super().__init__(mac_address=mac_address)
+    def __init__(self, mac_address: str, retry_count: int) -> None:
+        self._device = switchbot.SwitchbotCurtain(
+            mac=mac_address, retry_count=retry_count
+        )
+        super().__init__(mac_address=mac_address, retry_count=retry_count)
 
     def execute_command(
         self, mqtt_message_payload: bytes, mqtt_client: paho.mqtt.client.Client
@@ -258,9 +266,12 @@ def _run(
     mqtt_port: int,
     mqtt_username: typing.Optional[str],
     mqtt_password: typing.Optional[str],
+    retry_count: int,
 ) -> None:
     # https://pypi.org/project/paho-mqtt/
-    mqtt_client = paho.mqtt.client.Client()
+    mqtt_client = paho.mqtt.client.Client(
+        userdata=_MQTTCallbackUserdata(retry_count=retry_count)
+    )
     mqtt_client.on_connect = _mqtt_on_connect
     _LOGGER.info("connecting to MQTT broker %s:%d", mqtt_host, mqtt_port)
     if mqtt_username:
@@ -294,6 +305,14 @@ def _main() -> None:
         dest="mqtt_password_path",
         help="stripping trailing newline",
     )
+    argparser.add_argument(
+        "--retries",
+        dest="retry_count",
+        type=int,
+        default=switchbot.DEFAULT_RETRY_COUNT,
+        help="Maximum number of attempts to send a command to a SwitchBot device"
+        " (default: %(default)d)",
+    )
     args = argparser.parse_args()
     if args.mqtt_password_path:
         # .read_text() replaces \r\n with \n
@@ -309,4 +328,5 @@ def _main() -> None:
         mqtt_port=args.mqtt_port,
         mqtt_username=args.mqtt_username,
         mqtt_password=mqtt_password,
+        retry_count=args.retry_count,
     )

+ 5 - 2
tests/test_actor_base.py

@@ -27,11 +27,14 @@ import switchbot_mqtt
 def test_abstract():
     with pytest.raises(TypeError, match=r"\babstract class\b"):
         # pylint: disable=abstract-class-instantiated
-        switchbot_mqtt._MQTTControlledActor(mac_address=None)
+        switchbot_mqtt._MQTTControlledActor(mac_address=None, retry_count=21)
 
 
 def test_execute_command_abstract():
     class _ActorMock(switchbot_mqtt._MQTTControlledActor):
+        def __init__(self, mac_address: str, retry_count: int) -> None:
+            super().__init__(mac_address=mac_address, retry_count=retry_count)
+
         def execute_command(
             self, mqtt_message_payload: bytes, mqtt_client: paho.mqtt.client.Client
         ) -> None:
@@ -39,6 +42,6 @@ def test_execute_command_abstract():
                 mqtt_message_payload=mqtt_message_payload, mqtt_client=mqtt_client
             )
 
-    actor = _ActorMock(mac_address=None)
+    actor = _ActorMock(mac_address=None, retry_count=42)
     with pytest.raises(NotImplementedError):
         actor.execute_command(mqtt_message_payload=b"dummy", mqtt_client="dummy")

+ 17 - 1
tests/test_cli.py

@@ -22,6 +22,8 @@ import pytest
 
 import switchbot_mqtt
 
+# pylint: disable=too-many-arguments; these are tests, no API
+
 
 @pytest.mark.parametrize(
     (
@@ -30,6 +32,7 @@ import switchbot_mqtt
         "expected_mqtt_port",
         "expected_username",
         "expected_password",
+        "expected_retry_count",
     ),
     [
         (
@@ -38,6 +41,7 @@ import switchbot_mqtt
             1883,
             None,
             None,
+            3,
         ),
         (
             ["", "--mqtt-host", "mqtt-broker.local", "--mqtt-port", "8883"],
@@ -45,6 +49,7 @@ import switchbot_mqtt
             8883,
             None,
             None,
+            3,
         ),
         (
             ["", "--mqtt-host", "mqtt-broker.local", "--mqtt-username", "me"],
@@ -52,6 +57,7 @@ import switchbot_mqtt
             1883,
             "me",
             None,
+            3,
         ),
         (
             [
@@ -62,16 +68,24 @@ import switchbot_mqtt
                 "me",
                 "--mqtt-password",
                 "secret",
+                "--retries",
+                "21",
             ],
             "mqtt-broker.local",
             1883,
             "me",
             "secret",
+            21,
         ),
     ],
 )
 def test__main(
-    argv, expected_mqtt_host, expected_mqtt_port, expected_username, expected_password
+    argv,
+    expected_mqtt_host,
+    expected_mqtt_port,
+    expected_username,
+    expected_password,
+    expected_retry_count,
 ):
     with unittest.mock.patch("switchbot_mqtt._run") as run_mock, unittest.mock.patch(
         "sys.argv", argv
@@ -83,6 +97,7 @@ def test__main(
         mqtt_port=expected_mqtt_port,
         mqtt_username=expected_username,
         mqtt_password=expected_password,
+        retry_count=expected_retry_count,
     )
 
 
@@ -123,6 +138,7 @@ def test__main_password_file(tmpdir, password_file_content, expected_password):
         mqtt_port=1883,
         mqtt_username="me",
         mqtt_password=expected_password,
+        retry_count=3,
     )
 
 

+ 43 - 19
tests/test_mqtt.py

@@ -26,11 +26,13 @@ from paho.mqtt.client import MQTT_ERR_QUEUE_SIZE, MQTT_ERR_SUCCESS, MQTTMessage,
 import switchbot_mqtt
 
 # pylint: disable=protected-access
+# pylint: disable=too-many-arguments; these are tests, no API
 
 
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
 @pytest.mark.parametrize("mqtt_port", [1833])
-def test__run(caplog, mqtt_host, mqtt_port):
+@pytest.mark.parametrize("retry_count", [3, 21])
+def test__run(caplog, mqtt_host, mqtt_port, retry_count):
     with unittest.mock.patch(
         "paho.mqtt.client.Client"
     ) as mqtt_client_mock, caplog.at_level(logging.DEBUG):
@@ -39,8 +41,11 @@ def test__run(caplog, mqtt_host, mqtt_port):
             mqtt_port=mqtt_port,
             mqtt_username=None,
             mqtt_password=None,
+            retry_count=retry_count,
         )
-    mqtt_client_mock.assert_called_once_with()
+    mqtt_client_mock.assert_called_once_with(
+        userdata=switchbot_mqtt._MQTTCallbackUserdata(retry_count=retry_count)
+    )
     assert not mqtt_client_mock().username_pw_set.called
     mqtt_client_mock().connect.assert_called_once_with(host=mqtt_host, port=mqtt_port)
     mqtt_client_mock().socket().getpeername.return_value = (mqtt_host, mqtt_port)
@@ -96,8 +101,11 @@ def test__run_authentication(mqtt_host, mqtt_port, mqtt_username, mqtt_password)
             mqtt_port=mqtt_port,
             mqtt_username=mqtt_username,
             mqtt_password=mqtt_password,
+            retry_count=7,
         )
-    mqtt_client_mock.assert_called_once_with()
+    mqtt_client_mock.assert_called_once_with(
+        userdata=switchbot_mqtt._MQTTCallbackUserdata(retry_count=7)
+    )
     mqtt_client_mock().username_pw_set.assert_called_once_with(
         username=mqtt_username, password=mqtt_password
     )
@@ -114,6 +122,7 @@ def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_passwor
                 mqtt_port=mqtt_port,
                 mqtt_username=None,
                 mqtt_password=mqtt_password,
+                retry_count=3,
             )
 
 
@@ -164,24 +173,27 @@ def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_passwor
         ),
     ],
 )
+@pytest.mark.parametrize("retry_count", (3, 42))
 def test__mqtt_command_callback(
     caplog,
     command_topic_levels: typing.List[switchbot_mqtt._MQTTTopicLevel],
     topic: bytes,
     payload: bytes,
     expected_mac_address: str,
+    retry_count: int,
 ):
     class _ActorMock(switchbot_mqtt._MQTTControlledActor):
         MQTT_COMMAND_TOPIC_LEVELS = command_topic_levels
 
-        def __init__(self, mac_address):
-            super().__init__(mac_address=mac_address)
+        def __init__(self, mac_address, retry_count):
+            super().__init__(mac_address=mac_address, retry_count=retry_count)
 
         def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
             pass
 
     message = MQTTMessage(topic=topic)
     message.payload = payload
+    callback_userdata = switchbot_mqtt._MQTTCallbackUserdata(retry_count=retry_count)
     with unittest.mock.patch.object(
         _ActorMock, "__init__", return_value=None
     ) as init_mock, unittest.mock.patch.object(
@@ -189,8 +201,10 @@ def test__mqtt_command_callback(
     ) as execute_command_mock, caplog.at_level(
         logging.DEBUG
     ):
-        _ActorMock._mqtt_command_callback("client_dummy", None, message)
-    init_mock.assert_called_once_with(mac_address=expected_mac_address)
+        _ActorMock._mqtt_command_callback("client_dummy", callback_userdata, message)
+    init_mock.assert_called_once_with(
+        mac_address=expected_mac_address, retry_count=retry_count
+    )
     execute_command_mock.assert_called_once_with(
         mqtt_client="client_dummy", mqtt_message_payload=payload
     )
@@ -217,8 +231,8 @@ def test__mqtt_command_callback_unexpected_topic(caplog, topic: bytes, payload:
             switchbot_mqtt._ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS
         )
 
-        def __init__(self, mac_address):
-            super().__init__(mac_address=mac_address)
+        def __init__(self, mac_address, retry_count):
+            super().__init__(mac_address=mac_address, retry_count=retry_count)
 
         def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
             pass
@@ -232,7 +246,9 @@ def test__mqtt_command_callback_unexpected_topic(caplog, topic: bytes, payload:
     ) as execute_command_mock, caplog.at_level(
         logging.DEBUG
     ):
-        _ActorMock._mqtt_command_callback("client_dummy", None, message)
+        _ActorMock._mqtt_command_callback(
+            "client_dummy", switchbot_mqtt._MQTTCallbackUserdata(retry_count=3), message
+        )
     init_mock.assert_not_called()
     execute_command_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -258,8 +274,8 @@ def test__mqtt_command_callback_invalid_mac_address(
             switchbot_mqtt._ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS
         )
 
-        def __init__(self, mac_address):
-            super().__init__(mac_address=mac_address)
+        def __init__(self, mac_address, retry_count):
+            super().__init__(mac_address=mac_address, retry_count=retry_count)
 
         def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
             pass
@@ -274,7 +290,11 @@ def test__mqtt_command_callback_invalid_mac_address(
     ) as execute_command_mock, caplog.at_level(
         logging.DEBUG
     ):
-        _ActorMock._mqtt_command_callback("client_dummy", None, message)
+        _ActorMock._mqtt_command_callback(
+            "client_dummy",
+            switchbot_mqtt._MQTTCallbackUserdata(retry_count=None),
+            message,
+        )
     init_mock.assert_not_called()
     execute_command_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -301,8 +321,8 @@ def test__mqtt_command_callback_ignore_retained(caplog, topic: bytes, payload: b
             switchbot_mqtt._ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS
         )
 
-        def __init__(self, mac_address):
-            super().__init__(mac_address=mac_address)
+        def __init__(self, mac_address, retry_count):
+            super().__init__(mac_address=mac_address, retry_count=retry_count)
 
         def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
             pass
@@ -317,7 +337,11 @@ def test__mqtt_command_callback_ignore_retained(caplog, topic: bytes, payload: b
     ) as execute_command_mock, caplog.at_level(
         logging.DEBUG
     ):
-        _ActorMock._mqtt_command_callback("client_dummy", None, message)
+        _ActorMock._mqtt_command_callback(
+            "client_dummy",
+            switchbot_mqtt._MQTTCallbackUserdata(retry_count=None),
+            message,
+        )
     init_mock.assert_not_called()
     execute_command_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -360,8 +384,8 @@ def test__report_state(
     class _ActorMock(switchbot_mqtt._MQTTControlledActor):
         MQTT_STATE_TOPIC_LEVELS = state_topic_levels
 
-        def __init__(self, mac_address):
-            super().__init__(mac_address=mac_address)
+        def __init__(self, mac_address, retry_count):
+            super().__init__(mac_address=mac_address, retry_count=retry_count)
 
         def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
             pass
@@ -369,7 +393,7 @@ def test__report_state(
     mqtt_client_mock = unittest.mock.MagicMock()
     mqtt_client_mock.publish.return_value.rc = return_code
     with caplog.at_level(logging.DEBUG):
-        _ActorMock(mac_address=mac_address).report_state(
+        _ActorMock(mac_address=mac_address, retry_count=3).report_state(
             state=state, mqtt_client=mqtt_client_mock
         )
     mqtt_client_mock.publish.assert_called_once_with(

+ 12 - 8
tests/test_switchbot_button_automator.py

@@ -25,9 +25,11 @@ import pytest
 import switchbot_mqtt
 
 # pylint: disable=protected-access
+# pylint: disable=too-many-arguments; these are tests, no API
 
 
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"])
+@pytest.mark.parametrize("retry_count", (3, 21))
 @pytest.mark.parametrize(
     ("message_payload", "action_name"),
     [
@@ -41,12 +43,14 @@ import switchbot_mqtt
 )
 @pytest.mark.parametrize("command_successful", [True, False])
 def test_execute_command(
-    caplog, mac_address, message_payload, action_name, command_successful
+    caplog, mac_address, retry_count, message_payload, action_name, command_successful
 ):
     with unittest.mock.patch(
         "switchbot.Switchbot.__init__", return_value=None
     ) as device_init_mock, caplog.at_level(logging.INFO):
-        actor = switchbot_mqtt._ButtonAutomator(mac_address=mac_address)
+        actor = switchbot_mqtt._ButtonAutomator(
+            mac_address=mac_address, retry_count=retry_count
+        )
         with unittest.mock.patch.object(
             actor, "report_state"
         ) as report_mock, unittest.mock.patch(
@@ -55,7 +59,7 @@ def test_execute_command(
             actor.execute_command(
                 mqtt_client="dummy", mqtt_message_payload=message_payload
             )
-    device_init_mock.assert_called_once_with(mac=mac_address)
+    device_init_mock.assert_called_once_with(mac=mac_address, retry_count=retry_count)
     action_mock.assert_called_once_with()
     if command_successful:
         assert caplog.record_tuples == [
@@ -89,12 +93,12 @@ def test_execute_command_invalid_payload(caplog, mac_address, message_payload):
     with unittest.mock.patch("switchbot.Switchbot") as device_mock, caplog.at_level(
         logging.INFO
     ):
-        actor = switchbot_mqtt._ButtonAutomator(mac_address=mac_address)
+        actor = switchbot_mqtt._ButtonAutomator(mac_address=mac_address, retry_count=21)
         with unittest.mock.patch.object(actor, "report_state") as report_mock:
             actor.execute_command(
                 mqtt_client="dummy", mqtt_message_payload=message_payload
             )
-    device_mock.assert_called_once_with(mac=mac_address)
+    device_mock.assert_called_once_with(mac=mac_address, retry_count=21)
     assert not device_mock().mock_calls  # no methods called
     report_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -121,9 +125,9 @@ def test_execute_command_bluetooth_error(caplog, mac_address, message_payload):
             "Failed to connect to peripheral {}, addr type: random".format(mac_address)
         ),
     ), caplog.at_level(logging.ERROR):
-        switchbot_mqtt._ButtonAutomator(mac_address=mac_address).execute_command(
-            mqtt_client="dummy", mqtt_message_payload=message_payload
-        )
+        switchbot_mqtt._ButtonAutomator(
+            mac_address=mac_address, retry_count=3
+        ).execute_command(mqtt_client="dummy", mqtt_message_payload=message_payload)
     assert caplog.record_tuples == [
         (
             "switchbot",

+ 12 - 8
tests/test_switchbot_curtain_motor.py

@@ -25,9 +25,11 @@ import pytest
 import switchbot_mqtt
 
 # pylint: disable=protected-access,
+# pylint: disable=too-many-arguments; these are tests, no API
 
 
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"])
+@pytest.mark.parametrize("retry_count", (2, 3))
 @pytest.mark.parametrize(
     ("message_payload", "action_name"),
     [
@@ -44,12 +46,14 @@ import switchbot_mqtt
 )
 @pytest.mark.parametrize("command_successful", [True, False])
 def test_execute_command(
-    caplog, mac_address, message_payload, action_name, command_successful
+    caplog, mac_address, retry_count, message_payload, action_name, command_successful
 ):
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain.__init__", return_value=None
     ) as device_init_mock, caplog.at_level(logging.INFO):
-        actor = switchbot_mqtt._CurtainMotor(mac_address=mac_address)
+        actor = switchbot_mqtt._CurtainMotor(
+            mac_address=mac_address, retry_count=retry_count
+        )
         with unittest.mock.patch.object(
             actor, "report_state"
         ) as report_mock, unittest.mock.patch(
@@ -58,7 +62,7 @@ def test_execute_command(
             actor.execute_command(
                 mqtt_client="dummy", mqtt_message_payload=message_payload
             )
-    device_init_mock.assert_called_once_with(mac=mac_address)
+    device_init_mock.assert_called_once_with(mac=mac_address, retry_count=retry_count)
     action_mock.assert_called_once_with()
     if command_successful:
         assert caplog.record_tuples == [
@@ -99,12 +103,12 @@ def test_execute_command_invalid_payload(caplog, mac_address, message_payload):
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain"
     ) as device_mock, caplog.at_level(logging.INFO):
-        actor = switchbot_mqtt._CurtainMotor(mac_address=mac_address)
+        actor = switchbot_mqtt._CurtainMotor(mac_address=mac_address, retry_count=7)
         with unittest.mock.patch.object(actor, "report_state") as report_mock:
             actor.execute_command(
                 mqtt_client="dummy", mqtt_message_payload=message_payload
             )
-    device_mock.assert_called_once_with(mac=mac_address)
+    device_mock.assert_called_once_with(mac=mac_address, retry_count=7)
     assert not device_mock().mock_calls  # no methods called
     report_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -133,9 +137,9 @@ def test_execute_command_bluetooth_error(caplog, mac_address, message_payload):
             "Failed to connect to peripheral {}, addr type: random".format(mac_address)
         ),
     ), caplog.at_level(logging.ERROR):
-        switchbot_mqtt._CurtainMotor(mac_address=mac_address).execute_command(
-            mqtt_client="dummy", mqtt_message_payload=message_payload
-        )
+        switchbot_mqtt._CurtainMotor(
+            mac_address=mac_address, retry_count=10
+        ).execute_command(mqtt_client="dummy", mqtt_message_payload=message_payload)
     assert caplog.record_tuples == [
         (
             "switchbot",