Browse Source

fix: guard relay_switch get_basic_info against short responses (#509)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Bluetooth Devices Bot 1 ngày trước cách đây
mục cha
commit
7b7b8f1cac
2 tập tin đã thay đổi với 130 bổ sung0 xóa
  1. 24 0
      switchbot/devices/relay_switch.py
  2. 106 0
      tests/test_relay_switch.py

+ 24 - 0
switchbot/devices/relay_switch.py

@@ -136,6 +136,14 @@ class SwitchbotRelaySwitch(SwitchbotSequenceDevice, SwitchbotEncryptedDevice):
 
         if not (_data := await self._get_basic_info(COMMAND_GET_BASIC_INFO)):
             return None
+        if len(_data) < 17:
+            _LOGGER.warning(
+                "%s: Short basic-info response (%d bytes): %s",
+                self.name,
+                len(_data),
+                _data.hex(),
+            )
+            return None
         if not (
             _channel1_data := await self._get_basic_info(
                 COMMAND_GET_CHANNEL1_INFO.format(
@@ -144,6 +152,14 @@ class SwitchbotRelaySwitch(SwitchbotSequenceDevice, SwitchbotEncryptedDevice):
             )
         ):
             return None
+        if len(_channel1_data) < 15:
+            _LOGGER.warning(
+                "%s: Short channel1 response (%d bytes): %s",
+                self.name,
+                len(_channel1_data),
+                _channel1_data.hex(),
+            )
+            return None
 
         _LOGGER.debug(
             "on-off hex: %s, channel1_hex_data: %s", _data.hex(), _channel1_data.hex()
@@ -223,6 +239,14 @@ class SwitchbotRelaySwitch2PM(SwitchbotRelaySwitch):
             )
         ):
             return None
+        if len(_channel2_data) < 15:
+            _LOGGER.warning(
+                "%s: Short channel2 response (%d bytes): %s",
+                self.name,
+                len(_channel2_data),
+                _channel2_data.hex(),
+            )
+            return None
 
         _LOGGER.debug("channel2_hex_data: %s", _channel2_data.hex())
 

+ 106 - 0
tests/test_relay_switch.py

@@ -305,6 +305,112 @@ async def test_basic_info_exceptions_2PM(common_parametrize_2pm, info_data):
     assert info is None
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "info_data",
+    [
+        # Truncated basic_info (single byte from the wire — repro of issue #369)
+        {
+            "basic_info": b"\x02",
+            "channel1_info": b"\x01\x00\x00\x00\x00\x00\x00\x02\x99\x00\xe9\x00\x03\x00\x00",
+            "channel2_info": b"\x01\x00\x055\x00'<\x02\x9f\x00\xe9\x01,\x00F",
+        },
+        # Basic_info just below the 17-byte minimum
+        {
+            "basic_info": b"\x01\x98A\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+            "channel1_info": b"\x01\x00\x00\x00\x00\x00\x00\x02\x99\x00\xe9\x00\x03\x00\x00",
+            "channel2_info": b"\x01\x00\x055\x00'<\x02\x9f\x00\xe9\x01,\x00F",
+        },
+        # Truncated channel1_info (single byte — repro of issue #369 user_data crash)
+        {
+            "basic_info": b"\x01\x98A\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10",
+            "channel1_info": b"\x01",
+            "channel2_info": b"\x01\x00\x055\x00'<\x02\x9f\x00\xe9\x01,\x00F",
+        },
+        # Channel1_info just below the 15-byte minimum
+        {
+            "basic_info": b"\x01\x98A\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10",
+            "channel1_info": b"\x01\x00\x00\x00\x00\x00\x00\x02\x99\x00\xe9\x00\x03\x00",
+            "channel2_info": b"\x01\x00\x055\x00'<\x02\x9f\x00\xe9\x01,\x00F",
+        },
+        # Truncated channel2_info
+        {
+            "basic_info": b"\x01\x98A\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10",
+            "channel1_info": b"\x01\x00\x00\x00\x00\x00\x00\x02\x99\x00\xe9\x00\x03\x00\x00",
+            "channel2_info": b"\x01",
+        },
+    ],
+)
+async def test_get_basic_info_2PM_short_response(common_parametrize_2pm, info_data):
+    """
+    Truncated BLE responses must yield None instead of crashing.
+
+    Regression coverage for issue #369: a single-byte payload reaches
+    `_parse_common_data`/`_parse_user_data` and raises IndexError/ValueError.
+    """
+    device = create_device_for_command_testing(
+        common_parametrize_2pm["rawAdvData"], common_parametrize_2pm["model"]
+    )
+
+    device.get_current_time_and_start_time = MagicMock(
+        return_value=("683074d6", "682fba80")
+    )
+
+    async def mock_get_basic_info(arg):
+        if arg == relay_switch.COMMAND_GET_BASIC_INFO:
+            return info_data["basic_info"]
+        if arg == relay_switch.COMMAND_GET_CHANNEL1_INFO.format("683074d6", "682fba80"):
+            return info_data["channel1_info"]
+        if arg == relay_switch.COMMAND_GET_CHANNEL2_INFO.format("683074d6", "682fba80"):
+            return info_data["channel2_info"]
+        return None
+
+    device._get_basic_info = AsyncMock(side_effect=mock_get_basic_info)
+
+    info = await device.get_basic_info()
+
+    assert info is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("rawAdvData", "model"),
+    common_params,
+)
+@pytest.mark.parametrize(
+    "info_data",
+    [
+        {
+            "basic_info": b"\x02",
+            "channel1_info": b"\x01\x00\x00\x00\x00\x00\x00\x02\x99\x00\xe9\x00\x03\x00\x00",
+        },
+        {
+            "basic_info": b"\x01\x98A\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10",
+            "channel1_info": b"\x01",
+        },
+    ],
+)
+async def test_get_basic_info_short_response(rawAdvData, model, info_data):
+    """Truncated BLE responses on single-channel relay/garage/plug must yield None."""
+    device = create_device_for_command_testing(rawAdvData, model)
+    device.get_current_time_and_start_time = MagicMock(
+        return_value=("683074d6", "682fba80")
+    )
+
+    async def mock_get_basic_info(arg):
+        if arg == relay_switch.COMMAND_GET_BASIC_INFO:
+            return info_data["basic_info"]
+        if arg == relay_switch.COMMAND_GET_CHANNEL1_INFO.format("683074d6", "682fba80"):
+            return info_data["channel1_info"]
+        return None
+
+    device._get_basic_info = AsyncMock(side_effect=mock_get_basic_info)
+
+    info = await device.get_basic_info()
+
+    assert info is None
+
+
 @pytest.mark.asyncio
 async def test_get_parsed_data_2PM(common_parametrize_2pm):
     """Test get_parsed_data for 2PM devices."""