diff --git a/pybricksdev/connections/pybricks.py b/pybricksdev/connections/pybricks.py index b32fbf6..6b392db 100644 --- a/pybricksdev/connections/pybricks.py +++ b/pybricksdev/connections/pybricks.py @@ -837,10 +837,20 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None: raise ValueError("Response is required for USB") self._ep_out.write(bytes([PybricksUsbOutEpMessageType.COMMAND]) + data) - # FIXME: This needs to race with hub disconnect, and could also use a - # timeout, otherwise it blocks forever. Pyusb doesn't currently seem to - # have any disconnect callback. - reply = await self._response_queue.get() + + try: + # FIXME: race_disconnect() doesn't work properly for USB connections since + # pyusb doesn't provide a reliable way to detect disconnects. The disconnect + # event from the USB stack may not be received in time to cancel the wait + # operation. We should implement a proper USB disconnect detection mechanism. + reply = await asyncio.wait_for( + self.race_disconnect(self._response_queue.get()), + timeout=1.0, + ) + except asyncio.TimeoutError: + if self.connection_state_observable.value != ConnectionState.CONNECTED: + raise RuntimeError("Hub disconnected while waiting for response") + raise asyncio.TimeoutError("Timeout waiting for USB response") # REVISIT: could look up status error code and convert to string, # although BLE doesn't do that either. diff --git a/tests/connections/test_pybricks.py b/tests/connections/test_pybricks.py index bf4a391..a1f27ce 100644 --- a/tests/connections/test_pybricks.py +++ b/tests/connections/test_pybricks.py @@ -4,18 +4,21 @@ import contextlib import os import tempfile -from unittest.mock import AsyncMock, PropertyMock, patch +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch import pytest -from reactivex.subject import Subject +from reactivex.subject import BehaviorSubject, Subject +from pybricksdev.ble.pybricks import PYBRICKS_COMMAND_EVENT_UUID from pybricksdev.connections.pybricks import ( ConnectionState, HubCapabilityFlag, HubKind, PybricksHubBLE, + PybricksHubUSB, StatusFlag, ) +from pybricksdev.usb.pybricks import PybricksUsbOutEpMessageType class TestPybricksHub: @@ -180,3 +183,116 @@ async def test_run_modern_protocol(self): # Verify the expected calls were made hub.download_user_program.assert_called_once() hub.start_user_program.assert_called_once() + + +class TestPybricksHubUSB: + """Tests for the PybricksHubUSB class functionality.""" + + @pytest.mark.asyncio + async def test_pybricks_hub_usb_write_gatt_char_disconnect(self): + """Test write_gatt_char when a disconnect event occurs.""" + hub = PybricksHubUSB(MagicMock()) + + hub._ep_out = MagicMock() + # Simulate _response_queue.get() blocking indefinitely + hub._response_queue = AsyncMock() + hub._response_queue.get = AsyncMock(side_effect=asyncio.Event().wait) + + mock_observable = MagicMock( + spec=Subject + ) # Using Subject as a base for mock spec + disconnect_callback_handler = None + + def mock_subscribe_side_effect(on_next_callback, *args, **kwargs): + nonlocal disconnect_callback_handler + disconnect_callback_handler = on_next_callback + mock_subscription = MagicMock() + mock_subscription.dispose = MagicMock() + return mock_subscription + + mock_observable.subscribe = MagicMock(side_effect=mock_subscribe_side_effect) + type(hub.connection_state_observable).value = PropertyMock( + return_value=ConnectionState.CONNECTED + ) + hub.connection_state_observable = mock_observable + + async def trigger_disconnect_event(): + await asyncio.sleep(0.05) + assert ( + disconnect_callback_handler is not None + ), "Subscribe was not called by race_disconnect" + disconnect_callback_handler(ConnectionState.DISCONNECTED) + + with pytest.raises(RuntimeError, match="disconnected during operation"): + await asyncio.gather( + hub.write_gatt_char(PYBRICKS_COMMAND_EVENT_UUID, b"test_data", True), + trigger_disconnect_event(), + ) + + hub._ep_out.write.assert_called_once_with( + bytes([PybricksUsbOutEpMessageType.COMMAND]) + b"test_data" + ) + + @pytest.mark.asyncio + async def test_pybricks_hub_usb_write_gatt_char_timeout(self): + """Test write_gatt_char when a timeout occurs.""" + hub = PybricksHubUSB(MagicMock()) + + hub._ep_out = MagicMock() + hub._response_queue = AsyncMock() + # Make _response_queue.get() block indefinitely + hub._response_queue.get = AsyncMock(side_effect=asyncio.Event().wait) + + mock_observable = MagicMock(spec=BehaviorSubject) + mock_observable.value = ConnectionState.CONNECTED + hub.connection_state_observable = mock_observable + + # Simulate a timeout while the hub is still connected + with patch( + "asyncio.wait_for", side_effect=asyncio.TimeoutError("Test-induced timeout") + ): + with pytest.raises( + asyncio.TimeoutError, match="Timeout waiting for USB response" + ): + await hub.write_gatt_char( + PYBRICKS_COMMAND_EVENT_UUID, b"test_data", True + ) + + hub._ep_out.write.assert_called_once_with( + bytes([PybricksUsbOutEpMessageType.COMMAND]) + b"test_data" + ) + + @pytest.mark.asyncio + async def test_pybricks_hub_usb_write_gatt_char_timeout_disconnected(self): + """Test write_gatt_char when a timeout occurs and hub is already disconnected. + + This test documents the FIXME case where race_disconnect() doesn't work properly + for USB connections because pyusb doesn't provide reliable disconnect detection. + In this case, we might get a timeout while the hub is already disconnected, + but the disconnect event wasn't received in time to cancel the wait operation. + """ + hub = PybricksHubUSB(MagicMock()) + + hub._ep_out = MagicMock() + hub._response_queue = AsyncMock() + # Make _response_queue.get() block indefinitely + hub._response_queue.get = AsyncMock(side_effect=asyncio.Event().wait) + + mock_observable = MagicMock(spec=BehaviorSubject) + mock_observable.value = ConnectionState.DISCONNECTED + hub.connection_state_observable = mock_observable + + # Simulate a timeout while the hub is already disconnected + with patch( + "asyncio.wait_for", side_effect=asyncio.TimeoutError("Test-induced timeout") + ): + with pytest.raises( + RuntimeError, match="Hub disconnected while waiting for response" + ): + await hub.write_gatt_char( + PYBRICKS_COMMAND_EVENT_UUID, b"test_data", True + ) + + hub._ep_out.write.assert_called_once_with( + bytes([PybricksUsbOutEpMessageType.COMMAND]) + b"test_data" + )