Browse Source

internal function `_run`: add parameter to enable/disable TLS

https://github.com/fphammerle/systemctl-mqtt/commit/292046b88d4042c472817d72add8736c49e22146

https://github.com/fphammerle/switchbot-mqtt/issues/76
Fabian Peter Hammerle 2 years ago
parent
commit
22ef4b6299
4 changed files with 31 additions and 0 deletions
  1. 3 0
      switchbot_mqtt/__init__.py
  2. 1 0
      switchbot_mqtt/_cli.py
  3. 4 0
      tests/test_cli.py
  4. 23 0
      tests/test_mqtt.py

+ 3 - 0
switchbot_mqtt/__init__.py

@@ -54,6 +54,7 @@ def _run(
     mqtt_port: int,
     mqtt_username: typing.Optional[str],
     mqtt_password: typing.Optional[str],
+    mqtt_disable_tls: bool,
     retry_count: int,
     device_passwords: typing.Dict[str, str],
     fetch_device_info: bool,
@@ -68,6 +69,8 @@ def _run(
     )
     mqtt_client.on_connect = _mqtt_on_connect
     _LOGGER.info("connecting to MQTT broker %s:%d", mqtt_host, mqtt_port)
+    if not mqtt_disable_tls:
+        mqtt_client.tls_set(ca_certs=None)  # enable tls trusting default system certs
     if mqtt_username:
         mqtt_client.username_pw_set(username=mqtt_username, password=mqtt_password)
     elif mqtt_password:

+ 1 - 0
switchbot_mqtt/_cli.py

@@ -110,6 +110,7 @@ def _main() -> None:
         mqtt_port=args.mqtt_port,
         mqtt_username=args.mqtt_username,
         mqtt_password=mqtt_password,
+        mqtt_disable_tls=True,
         retry_count=args.retry_count,
         device_passwords=device_passwords,
         fetch_device_info=args.fetch_device_info

+ 4 - 0
tests/test_cli.py

@@ -110,6 +110,7 @@ def test__main(
         mqtt_port=expected_mqtt_port,
         mqtt_username=expected_username,
         mqtt_password=expected_password,
+        mqtt_disable_tls=True,
         retry_count=expected_retry_count,
         device_passwords={},
         fetch_device_info=False,
@@ -153,6 +154,7 @@ def test__main_mqtt_password_file(
         mqtt_port=1883,
         mqtt_username="me",
         mqtt_password=expected_password,
+        mqtt_disable_tls=True,
         retry_count=3,
         device_passwords={},
         fetch_device_info=False,
@@ -214,6 +216,7 @@ def test__main_device_password_file(
         mqtt_port=1883,
         mqtt_username=None,
         mqtt_password=None,
+        mqtt_disable_tls=True,
         retry_count=3,
         device_passwords=device_passwords,
         fetch_device_info=False,
@@ -235,6 +238,7 @@ def test__main_fetch_device_info() -> None:
         mqtt_port=1883,
         mqtt_username=None,
         mqtt_password=None,
+        mqtt_disable_tls=True,
         retry_count=3,
         device_passwords={},
     )

+ 23 - 0
tests/test_mqtt.py

@@ -58,6 +58,7 @@ def test__run(
             mqtt_port=mqtt_port,
             mqtt_username=None,
             mqtt_password=None,
+            mqtt_disable_tls=False,
             retry_count=retry_count,
             device_passwords=device_passwords,
             fetch_device_info=fetch_device_info,
@@ -72,6 +73,7 @@ def test__run(
         fetch_device_info=fetch_device_info,
     )
     assert not mqtt_client_mock().username_pw_set.called
+    mqtt_client_mock().tls_set.assert_called_once_with(ca_certs=None)
     mqtt_client_mock().connect.assert_called_once_with(host=mqtt_host, port=mqtt_port)
     mqtt_client_mock().socket().getpeername.return_value = (mqtt_host, mqtt_port)
     with caplog.at_level(logging.DEBUG):
@@ -125,6 +127,25 @@ def test__run(
     ) in caplog.record_tuples
 
 
+@pytest.mark.parametrize("mqtt_disable_tls", [True, False])
+def test__run_tls(mqtt_disable_tls: bool) -> None:
+    with unittest.mock.patch("paho.mqtt.client.Client") as mqtt_client_mock:
+        switchbot_mqtt._run(
+            mqtt_host="mqtt.local",
+            mqtt_port=1234,
+            mqtt_username=None,
+            mqtt_password=None,
+            mqtt_disable_tls=mqtt_disable_tls,
+            retry_count=21,
+            device_passwords={},
+            fetch_device_info=True,
+        )
+    if mqtt_disable_tls:
+        mqtt_client_mock().tls_set.assert_not_called()
+    else:
+        mqtt_client_mock().tls_set.assert_called_once_with(ca_certs=None)
+
+
 @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"])
 @pytest.mark.parametrize("mqtt_port", [1833])
 @pytest.mark.parametrize("mqtt_username", ["me"])
@@ -141,6 +162,7 @@ def test__run_authentication(
             mqtt_port=mqtt_port,
             mqtt_username=mqtt_username,
             mqtt_password=mqtt_password,
+            mqtt_disable_tls=True,
             retry_count=7,
             device_passwords={},
             fetch_device_info=True,
@@ -168,6 +190,7 @@ def test__run_authentication_missing_username(
                 mqtt_port=mqtt_port,
                 mqtt_username=None,
                 mqtt_password=mqtt_password,
+                mqtt_disable_tls=True,
                 retry_count=3,
                 device_passwords={},
                 fetch_device_info=True,