|
6 | 6 | from dataclasses import dataclass |
7 | 7 | from typing import AsyncIterator, Final, Literal, Protocol, overload |
8 | 8 |
|
| 9 | +from zmqtt._compat import defer_cancellation |
9 | 10 | from zmqtt.errors import MQTTDisconnectedError, MQTTTimeoutError |
10 | 11 | from zmqtt.log import get_logger |
11 | 12 | from zmqtt.packets.auth import Auth |
@@ -142,8 +143,7 @@ async def __aenter__(self) -> "Subscription": |
142 | 143 | async def __aexit__(self, *exc: object) -> None: |
143 | 144 | self._client._subscriptions.remove(self) |
144 | 145 | await self._cancel_relays() |
145 | | - task = asyncio.current_task() |
146 | | - being_cancelled = task is not None and task.cancelling() > 0 |
| 146 | + being_cancelled = isinstance(exc[1], asyncio.CancelledError) |
147 | 147 | if ( |
148 | 148 | not being_cancelled |
149 | 149 | and self._registered_filters |
@@ -243,21 +243,14 @@ async def __aenter__(self) -> "MQTTClient": |
243 | 243 | return self |
244 | 244 |
|
245 | 245 | async def __aexit__(self, *exc: object) -> None: |
246 | | - task = asyncio.current_task() |
247 | | - cancels = task.cancelling() if task else 0 |
248 | | - for _ in range(cancels): |
249 | | - task.uncancel() # type: ignore[union-attr] |
250 | | - try: |
| 246 | + async with defer_cancellation(): |
251 | 247 | if self._run_task is not None: |
252 | 248 | self._run_task.cancel() |
253 | 249 | await asyncio.gather(self._run_task, return_exceptions=True) |
254 | 250 | self._run_task = None |
255 | 251 | if self._protocol is not None: |
256 | 252 | await self._protocol.disconnect() |
257 | 253 | self._protocol = None |
258 | | - finally: |
259 | | - for _ in range(cancels): |
260 | | - task.cancel() # type: ignore[union-attr] |
261 | 254 |
|
262 | 255 | async def publish( |
263 | 256 | self, |
@@ -356,18 +349,19 @@ async def _run_loop(self) -> None: |
356 | 349 |
|
357 | 350 | while True: |
358 | 351 | assert self._protocol is not None |
| 352 | + protocol_run_task = asyncio.create_task(self._protocol.run()) |
359 | 353 | try: |
360 | 354 | # Run the protocol as a sub-task so _read_loop is live while we |
361 | 355 | # re-subscribe. For the first connection subs_to_restore is empty, |
362 | 356 | # so this collapses to the original "await protocol.run()" pattern. |
363 | | - async with asyncio.TaskGroup() as tg: |
364 | | - tg.create_task(self._protocol.run()) |
365 | | - if subs_to_restore: |
366 | | - await asyncio.sleep(0) # let _read_loop start |
367 | | - for sub in subs_to_restore: |
368 | | - await sub._reconnect(self._protocol) |
369 | | - subs_to_restore = [] |
370 | | - except* (MQTTDisconnectedError, MQTTTimeoutError): |
| 357 | + if subs_to_restore: |
| 358 | + await self._protocol.started_event.wait() |
| 359 | + for sub in subs_to_restore: |
| 360 | + await sub._reconnect(self._protocol) |
| 361 | + subs_to_restore = [] |
| 362 | + await protocol_run_task |
| 363 | + |
| 364 | + except (MQTTDisconnectedError, MQTTTimeoutError): |
371 | 365 | if not self._reconnect.enabled: |
372 | 366 | raise |
373 | 367 | # Close the dead transport to release the file descriptor. |
|
0 commit comments