Browse Source

callbacks: add support for device passwords

https://github.com/fphammerle/switchbot-mqtt/issues/37#issue-930973360
Fabian Peter Hammerle 2 years ago
parent
commit
f9cc568450

+ 37 - 11
switchbot_mqtt/__init__.py

@@ -46,7 +46,16 @@ def _mac_address_valid(mac_address: str) -> bool:
     return _MAC_ADDRESS_REGEX.match(mac_address.lower()) is not None
 
 
-_MQTTCallbackUserdata = collections.namedtuple("_MQTTCallbackUserdata", ["retry_count"])
+class _MQTTCallbackUserdata:
+    # pylint: disable=too-few-public-methods; @dataclasses.dataclass when python_requires>=3.7
+    def __init__(
+        self, retry_count: int, device_passwords: typing.Dict[str, str]
+    ) -> None:
+        self.retry_count = retry_count
+        self.device_passwords = device_passwords
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, type(self)) and vars(self) == vars(other)
 
 
 class _MQTTControlledActor(abc.ABC):
@@ -54,7 +63,9 @@ class _MQTTControlledActor(abc.ABC):
     MQTT_STATE_TOPIC_LEVELS = NotImplemented  # type: typing.List[_MQTTTopicLevel]
 
     @abc.abstractmethod
-    def __init__(self, mac_address: str, retry_count: int) -> None:
+    def __init__(
+        self, mac_address: str, retry_count: int, password: typing.Optional[str]
+    ) -> None:
         self._mac_address = mac_address
 
     @abc.abstractmethod
@@ -93,7 +104,12 @@ class _MQTTControlledActor(abc.ABC):
         if not _mac_address_valid(mac_address):
             _LOGGER.warning("invalid mac address %s", mac_address)
             return
-        cls(mac_address=mac_address, retry_count=userdata.retry_count).execute_command(
+        actor = cls(
+            mac_address=mac_address,
+            retry_count=userdata.retry_count,
+            password=userdata.device_passwords.get(mac_address, None),
+        )
+        actor.execute_command(
             mqtt_message_payload=message.payload, mqtt_client=mqtt_client
         )
 
@@ -160,9 +176,15 @@ class _ButtonAutomator(_MQTTControlledActor):
         "state",
     ]
 
-    def __init__(self, mac_address: str, retry_count: int) -> None:
-        self._device = switchbot.Switchbot(mac=mac_address, retry_count=retry_count)
-        super().__init__(mac_address=mac_address, retry_count=retry_count)
+    def __init__(
+        self, mac_address: str, retry_count: int, password: typing.Optional[None]
+    ) -> None:
+        self._device = switchbot.Switchbot(
+            mac=mac_address, password=password, retry_count=retry_count
+        )
+        super().__init__(
+            mac_address=mac_address, retry_count=retry_count, password=password
+        )
 
     def execute_command(
         self, mqtt_message_payload: bytes, mqtt_client: paho.mqtt.client.Client
@@ -205,11 +227,15 @@ class _CurtainMotor(_MQTTControlledActor):
         "state",
     ]
 
-    def __init__(self, mac_address: str, retry_count: int) -> None:
+    def __init__(
+        self, mac_address: str, retry_count: int, password: typing.Optional[None]
+    ) -> None:
         self._device = switchbot.SwitchbotCurtain(
-            mac=mac_address, retry_count=retry_count
+            mac=mac_address, password=password, retry_count=retry_count
+        )
+        super().__init__(
+            mac_address=mac_address, retry_count=retry_count, password=password
         )
-        super().__init__(mac_address=mac_address, retry_count=retry_count)
 
     def execute_command(
         self, mqtt_message_payload: bytes, mqtt_client: paho.mqtt.client.Client
@@ -248,7 +274,7 @@ class _CurtainMotor(_MQTTControlledActor):
 
 def _mqtt_on_connect(
     mqtt_client: paho.mqtt.client.Client,
-    user_data: typing.Any,
+    userdata: _MQTTCallbackUserdata,
     flags: typing.Dict,
     return_code: int,
 ) -> None:
@@ -270,7 +296,7 @@ def _run(
 ) -> None:
     # https://pypi.org/project/paho-mqtt/
     mqtt_client = paho.mqtt.client.Client(
-        userdata=_MQTTCallbackUserdata(retry_count=retry_count)
+        userdata=_MQTTCallbackUserdata(retry_count=retry_count, device_passwords={})
     )
     mqtt_client.on_connect = _mqtt_on_connect
     _LOGGER.info("connecting to MQTT broker %s:%d", mqtt_host, mqtt_port)

+ 12 - 4
tests/test_actor_base.py

@@ -16,6 +16,8 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
+import typing
+
 import paho.mqtt.client
 import pytest
 
@@ -27,13 +29,19 @@ import switchbot_mqtt
 def test_abstract():
     with pytest.raises(TypeError, match=r"\babstract class\b"):
         # pylint: disable=abstract-class-instantiated
-        switchbot_mqtt._MQTTControlledActor(mac_address=None, retry_count=21)
+        switchbot_mqtt._MQTTControlledActor(
+            mac_address=None, retry_count=21, password=None
+        )
 
 
 def test_execute_command_abstract():
     class _ActorMock(switchbot_mqtt._MQTTControlledActor):
-        def __init__(self, mac_address: str, retry_count: int) -> None:
-            super().__init__(mac_address=mac_address, retry_count=retry_count)
+        def __init__(
+            self, mac_address: str, retry_count: int, password: typing.Optional[str]
+        ) -> None:
+            super().__init__(
+                mac_address=mac_address, retry_count=retry_count, password=password
+            )
 
         def execute_command(
             self, mqtt_message_payload: bytes, mqtt_client: paho.mqtt.client.Client
@@ -42,6 +50,6 @@ def test_execute_command_abstract():
                 mqtt_message_payload=mqtt_message_payload, mqtt_client=mqtt_client
             )
 
-    actor = _ActorMock(mac_address=None, retry_count=42)
+    actor = _ActorMock(mac_address=None, retry_count=42, password=None)
     with pytest.raises(NotImplementedError):
         actor.execute_command(mqtt_message_payload=b"dummy", mqtt_client="dummy")

+ 99 - 66
tests/test_mqtt.py

@@ -44,7 +44,9 @@ def test__run(caplog, mqtt_host, mqtt_port, retry_count):
             retry_count=retry_count,
         )
     mqtt_client_mock.assert_called_once_with(
-        userdata=switchbot_mqtt._MQTTCallbackUserdata(retry_count=retry_count)
+        userdata=switchbot_mqtt._MQTTCallbackUserdata(
+            retry_count=retry_count, device_passwords={}
+        )
     )
     assert not mqtt_client_mock().username_pw_set.called
     mqtt_client_mock().connect.assert_called_once_with(host=mqtt_host, port=mqtt_port)
@@ -104,7 +106,9 @@ def test__run_authentication(mqtt_host, mqtt_port, mqtt_username, mqtt_password)
             retry_count=7,
         )
     mqtt_client_mock.assert_called_once_with(
-        userdata=switchbot_mqtt._MQTTCallbackUserdata(retry_count=7)
+        userdata=switchbot_mqtt._MQTTCallbackUserdata(
+            retry_count=7, device_passwords={}
+        )
     )
     mqtt_client_mock().username_pw_set.assert_called_once_with(
         username=mqtt_username, password=mqtt_password
@@ -126,6 +130,23 @@ def test__run_authentication_missing_username(mqtt_host, mqtt_port, mqtt_passwor
             )
 
 
+def _mock_actor_class(
+    command_topic_levels: typing.List[switchbot_mqtt._MQTTTopicLevel],
+) -> typing.Type:
+    class _ActorMock(switchbot_mqtt._MQTTControlledActor):
+        MQTT_COMMAND_TOPIC_LEVELS = command_topic_levels
+
+        def __init__(self, mac_address, retry_count, password):
+            super().__init__(
+                mac_address=mac_address, retry_count=retry_count, password=password
+            )
+
+        def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
+            pass
+
+    return _ActorMock
+
+
 @pytest.mark.parametrize(
     ("command_topic_levels", "topic", "payload", "expected_mac_address"),
     [
@@ -182,28 +203,22 @@ def test__mqtt_command_callback(
     expected_mac_address: str,
     retry_count: int,
 ):
-    class _ActorMock(switchbot_mqtt._MQTTControlledActor):
-        MQTT_COMMAND_TOPIC_LEVELS = command_topic_levels
-
-        def __init__(self, mac_address, retry_count):
-            super().__init__(mac_address=mac_address, retry_count=retry_count)
-
-        def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
-            pass
-
+    ActorMock = _mock_actor_class(command_topic_levels)
     message = MQTTMessage(topic=topic)
     message.payload = payload
-    callback_userdata = switchbot_mqtt._MQTTCallbackUserdata(retry_count=retry_count)
+    callback_userdata = switchbot_mqtt._MQTTCallbackUserdata(
+        retry_count=retry_count, device_passwords={}
+    )
     with unittest.mock.patch.object(
-        _ActorMock, "__init__", return_value=None
+        ActorMock, "__init__", return_value=None
     ) as init_mock, unittest.mock.patch.object(
-        _ActorMock, "execute_command"
+        ActorMock, "execute_command"
     ) as execute_command_mock, caplog.at_level(
         logging.DEBUG
     ):
-        _ActorMock._mqtt_command_callback("client_dummy", callback_userdata, message)
+        ActorMock._mqtt_command_callback("client_dummy", callback_userdata, message)
     init_mock.assert_called_once_with(
-        mac_address=expected_mac_address, retry_count=retry_count
+        mac_address=expected_mac_address, retry_count=retry_count, password=None
     )
     execute_command_mock.assert_called_once_with(
         mqtt_client="client_dummy", mqtt_message_payload=payload
@@ -217,6 +232,45 @@ def test__mqtt_command_callback(
     ]
 
 
+@pytest.mark.parametrize(
+    ("mac_address", "expected_password"),
+    [
+        ("11:22:33:44:55:66", None),
+        ("aa:bb:cc:dd:ee:ff", "secret"),
+        ("11:22:33:dd:ee:ff", "äöü"),
+    ],
+)
+def test__mqtt_command_callback_password(mac_address, expected_password):
+    ActorMock = _mock_actor_class(
+        [
+            "switchbot",
+            switchbot_mqtt._MQTTTopicPlaceholder.MAC_ADDRESS,
+        ]
+    )
+    message = MQTTMessage(topic=b"switchbot/" + mac_address.encode())
+    message.payload = b"whatever"
+    callback_userdata = switchbot_mqtt._MQTTCallbackUserdata(
+        retry_count=3,
+        device_passwords={
+            "11:22:33:44:55:77": "test",
+            "aa:bb:cc:dd:ee:ff": "secret",
+            "11:22:33:dd:ee:ff": "äöü",
+        },
+    )
+    with unittest.mock.patch.object(
+        ActorMock, "__init__", return_value=None
+    ) as init_mock, unittest.mock.patch.object(
+        ActorMock, "execute_command"
+    ) as execute_command_mock:
+        ActorMock._mqtt_command_callback("client_dummy", callback_userdata, message)
+    init_mock.assert_called_once_with(
+        mac_address=mac_address, retry_count=3, password=expected_password
+    )
+    execute_command_mock.assert_called_once_with(
+        mqtt_client="client_dummy", mqtt_message_payload=b"whatever"
+    )
+
+
 @pytest.mark.parametrize(
     ("topic", "payload"),
     [
@@ -226,28 +280,22 @@ def test__mqtt_command_callback(
     ],
 )
 def test__mqtt_command_callback_unexpected_topic(caplog, topic: bytes, payload: bytes):
-    class _ActorMock(switchbot_mqtt._MQTTControlledActor):
-        MQTT_COMMAND_TOPIC_LEVELS = (
-            switchbot_mqtt._ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS
-        )
-
-        def __init__(self, mac_address, retry_count):
-            super().__init__(mac_address=mac_address, retry_count=retry_count)
-
-        def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
-            pass
-
+    ActorMock = _mock_actor_class(
+        switchbot_mqtt._ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS
+    )
     message = MQTTMessage(topic=topic)
     message.payload = payload
     with unittest.mock.patch.object(
-        _ActorMock, "__init__", return_value=None
+        ActorMock, "__init__", return_value=None
     ) as init_mock, unittest.mock.patch.object(
-        _ActorMock, "execute_command"
+        ActorMock, "execute_command"
     ) as execute_command_mock, caplog.at_level(
         logging.DEBUG
     ):
-        _ActorMock._mqtt_command_callback(
-            "client_dummy", switchbot_mqtt._MQTTCallbackUserdata(retry_count=3), message
+        ActorMock._mqtt_command_callback(
+            "client_dummy",
+            switchbot_mqtt._MQTTCallbackUserdata(retry_count=3, device_passwords={}),
+            message,
         )
     init_mock.assert_not_called()
     execute_command_mock.assert_not_called()
@@ -269,30 +317,22 @@ def test__mqtt_command_callback_unexpected_topic(caplog, topic: bytes, payload:
 def test__mqtt_command_callback_invalid_mac_address(
     caplog, mac_address: str, payload: bytes
 ):
-    class _ActorMock(switchbot_mqtt._MQTTControlledActor):
-        MQTT_COMMAND_TOPIC_LEVELS = (
-            switchbot_mqtt._ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS
-        )
-
-        def __init__(self, mac_address, retry_count):
-            super().__init__(mac_address=mac_address, retry_count=retry_count)
-
-        def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
-            pass
-
+    ActorMock = _mock_actor_class(
+        switchbot_mqtt._ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS
+    )
     topic = "homeassistant/switch/switchbot/{}/set".format(mac_address).encode()
     message = MQTTMessage(topic=topic)
     message.payload = payload
     with unittest.mock.patch.object(
-        _ActorMock, "__init__", return_value=None
+        ActorMock, "__init__", return_value=None
     ) as init_mock, unittest.mock.patch.object(
-        _ActorMock, "execute_command"
+        ActorMock, "execute_command"
     ) as execute_command_mock, caplog.at_level(
         logging.DEBUG
     ):
-        _ActorMock._mqtt_command_callback(
+        ActorMock._mqtt_command_callback(
             "client_dummy",
-            switchbot_mqtt._MQTTCallbackUserdata(retry_count=None),
+            switchbot_mqtt._MQTTCallbackUserdata(retry_count=3, device_passwords={}),
             message,
         )
     init_mock.assert_not_called()
@@ -316,30 +356,22 @@ def test__mqtt_command_callback_invalid_mac_address(
     [(b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", b"ON")],
 )
 def test__mqtt_command_callback_ignore_retained(caplog, topic: bytes, payload: bytes):
-    class _ActorMock(switchbot_mqtt._MQTTControlledActor):
-        MQTT_COMMAND_TOPIC_LEVELS = (
-            switchbot_mqtt._ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS
-        )
-
-        def __init__(self, mac_address, retry_count):
-            super().__init__(mac_address=mac_address, retry_count=retry_count)
-
-        def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
-            pass
-
+    ActorMock = _mock_actor_class(
+        switchbot_mqtt._ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS
+    )
     message = MQTTMessage(topic=topic)
     message.payload = payload
     message.retain = True
     with unittest.mock.patch.object(
-        _ActorMock, "__init__", return_value=None
+        ActorMock, "__init__", return_value=None
     ) as init_mock, unittest.mock.patch.object(
-        _ActorMock, "execute_command"
+        ActorMock, "execute_command"
     ) as execute_command_mock, caplog.at_level(
         logging.DEBUG
     ):
-        _ActorMock._mqtt_command_callback(
+        ActorMock._mqtt_command_callback(
             "client_dummy",
-            switchbot_mqtt._MQTTCallbackUserdata(retry_count=None),
+            switchbot_mqtt._MQTTCallbackUserdata(retry_count=4, device_passwords={}),
             message,
         )
     init_mock.assert_not_called()
@@ -384,8 +416,10 @@ def test__report_state(
     class _ActorMock(switchbot_mqtt._MQTTControlledActor):
         MQTT_STATE_TOPIC_LEVELS = state_topic_levels
 
-        def __init__(self, mac_address, retry_count):
-            super().__init__(mac_address=mac_address, retry_count=retry_count)
+        def __init__(self, mac_address, retry_count, password):
+            super().__init__(
+                mac_address=mac_address, retry_count=retry_count, password=password
+            )
 
         def execute_command(self, mqtt_message_payload: bytes, mqtt_client: Client):
             pass
@@ -393,9 +427,8 @@ def test__report_state(
     mqtt_client_mock = unittest.mock.MagicMock()
     mqtt_client_mock.publish.return_value.rc = return_code
     with caplog.at_level(logging.DEBUG):
-        _ActorMock(mac_address=mac_address, retry_count=3).report_state(
-            state=state, mqtt_client=mqtt_client_mock
-        )
+        actor = _ActorMock(mac_address=mac_address, retry_count=3, password=None)
+        actor.report_state(state=state, mqtt_client=mqtt_client_mock)
     mqtt_client_mock.publish.assert_called_once_with(
         topic=expected_topic, payload=state, retain=True
     )

+ 17 - 6
tests/test_switchbot_button_automator.py

@@ -29,6 +29,7 @@ import switchbot_mqtt
 
 
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"])
+@pytest.mark.parametrize("password", (None, "secret"))
 @pytest.mark.parametrize("retry_count", (3, 21))
 @pytest.mark.parametrize(
     ("message_payload", "action_name"),
@@ -43,13 +44,19 @@ import switchbot_mqtt
 )
 @pytest.mark.parametrize("command_successful", [True, False])
 def test_execute_command(
-    caplog, mac_address, retry_count, message_payload, action_name, command_successful
+    caplog,
+    mac_address,
+    password,
+    retry_count,
+    message_payload,
+    action_name,
+    command_successful,
 ):
     with unittest.mock.patch(
         "switchbot.Switchbot.__init__", return_value=None
     ) as device_init_mock, caplog.at_level(logging.INFO):
         actor = switchbot_mqtt._ButtonAutomator(
-            mac_address=mac_address, retry_count=retry_count
+            mac_address=mac_address, retry_count=retry_count, password=password
         )
         with unittest.mock.patch.object(
             actor, "report_state"
@@ -59,7 +66,9 @@ def test_execute_command(
             actor.execute_command(
                 mqtt_client="dummy", mqtt_message_payload=message_payload
             )
-    device_init_mock.assert_called_once_with(mac=mac_address, retry_count=retry_count)
+    device_init_mock.assert_called_once_with(
+        mac=mac_address, password=password, retry_count=retry_count
+    )
     action_mock.assert_called_once_with()
     if command_successful:
         assert caplog.record_tuples == [
@@ -93,12 +102,14 @@ def test_execute_command_invalid_payload(caplog, mac_address, message_payload):
     with unittest.mock.patch("switchbot.Switchbot") as device_mock, caplog.at_level(
         logging.INFO
     ):
-        actor = switchbot_mqtt._ButtonAutomator(mac_address=mac_address, retry_count=21)
+        actor = switchbot_mqtt._ButtonAutomator(
+            mac_address=mac_address, retry_count=21, password=None
+        )
         with unittest.mock.patch.object(actor, "report_state") as report_mock:
             actor.execute_command(
                 mqtt_client="dummy", mqtt_message_payload=message_payload
             )
-    device_mock.assert_called_once_with(mac=mac_address, retry_count=21)
+    device_mock.assert_called_once_with(mac=mac_address, retry_count=21, password=None)
     assert not device_mock().mock_calls  # no methods called
     report_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -126,7 +137,7 @@ def test_execute_command_bluetooth_error(caplog, mac_address, message_payload):
         ),
     ), caplog.at_level(logging.ERROR):
         switchbot_mqtt._ButtonAutomator(
-            mac_address=mac_address, retry_count=3
+            mac_address=mac_address, retry_count=3, password=None
         ).execute_command(mqtt_client="dummy", mqtt_message_payload=message_payload)
     assert caplog.record_tuples == [
         (

+ 25 - 7
tests/test_switchbot_curtain_motor.py

@@ -29,6 +29,7 @@ import switchbot_mqtt
 
 
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"])
+@pytest.mark.parametrize("password", ["pa$$word", None])
 @pytest.mark.parametrize("retry_count", (2, 3))
 @pytest.mark.parametrize(
     ("message_payload", "action_name"),
@@ -46,13 +47,19 @@ import switchbot_mqtt
 )
 @pytest.mark.parametrize("command_successful", [True, False])
 def test_execute_command(
-    caplog, mac_address, retry_count, message_payload, action_name, command_successful
+    caplog,
+    mac_address,
+    password,
+    retry_count,
+    message_payload,
+    action_name,
+    command_successful,
 ):
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain.__init__", return_value=None
     ) as device_init_mock, caplog.at_level(logging.INFO):
         actor = switchbot_mqtt._CurtainMotor(
-            mac_address=mac_address, retry_count=retry_count
+            mac_address=mac_address, retry_count=retry_count, password=password
         )
         with unittest.mock.patch.object(
             actor, "report_state"
@@ -62,7 +69,11 @@ def test_execute_command(
             actor.execute_command(
                 mqtt_client="dummy", mqtt_message_payload=message_payload
             )
-    device_init_mock.assert_called_once_with(mac=mac_address, retry_count=retry_count)
+    device_init_mock.assert_called_once_with(
+        mac=mac_address,
+        password=password,
+        retry_count=retry_count,
+    )
     action_mock.assert_called_once_with()
     if command_successful:
         assert caplog.record_tuples == [
@@ -98,17 +109,24 @@ def test_execute_command(
 
 
 @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff"])
+@pytest.mark.parametrize("password", ["secret"])
 @pytest.mark.parametrize("message_payload", [b"OEFFNEN", b""])
-def test_execute_command_invalid_payload(caplog, mac_address, message_payload):
+def test_execute_command_invalid_payload(
+    caplog, mac_address, password, message_payload
+):
     with unittest.mock.patch(
         "switchbot.SwitchbotCurtain"
     ) as device_mock, caplog.at_level(logging.INFO):
-        actor = switchbot_mqtt._CurtainMotor(mac_address=mac_address, retry_count=7)
+        actor = switchbot_mqtt._CurtainMotor(
+            mac_address=mac_address, retry_count=7, password=password
+        )
         with unittest.mock.patch.object(actor, "report_state") as report_mock:
             actor.execute_command(
                 mqtt_client="dummy", mqtt_message_payload=message_payload
             )
-    device_mock.assert_called_once_with(mac=mac_address, retry_count=7)
+    device_mock.assert_called_once_with(
+        mac=mac_address, password=password, retry_count=7
+    )
     assert not device_mock().mock_calls  # no methods called
     report_mock.assert_not_called()
     assert caplog.record_tuples == [
@@ -138,7 +156,7 @@ def test_execute_command_bluetooth_error(caplog, mac_address, message_payload):
         ),
     ), caplog.at_level(logging.ERROR):
         switchbot_mqtt._CurtainMotor(
-            mac_address=mac_address, retry_count=10
+            mac_address=mac_address, retry_count=10, password="secret"
         ).execute_command(mqtt_client="dummy", mqtt_message_payload=message_payload)
     assert caplog.record_tuples == [
         (