Skip to content

Commit 4bc068c

Browse files
authored
Add throttle for jobs run with ProcessorSampler (#6786)
* Limit the number of concurrent jobs when running run_batch. * Add throttle for concurrent jobs run by ProcessorSampler. * Add tests * Remove unused imports. * lint * format * moar lint
1 parent 170b20b commit 4bc068c

File tree

5 files changed

+88
-60
lines changed

5 files changed

+88
-60
lines changed

cirq-core/cirq/work/sampler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,15 +291,14 @@ async def run_batch_async(
291291
programs: Sequence['cirq.AbstractCircuit'],
292292
params_list: Optional[Sequence['cirq.Sweepable']] = None,
293293
repetitions: Union[int, Sequence[int]] = 1,
294-
limiter: duet.Limiter = duet.Limiter(10),
295294
) -> Sequence[Sequence['cirq.Result']]:
296295
"""Runs the supplied circuits asynchronously.
297296
298297
See docs for `cirq.Sampler.run_batch`.
299298
"""
300299
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
301300
return await duet.pstarmap_async(
302-
self.run_sweep_async, zip(programs, params_list, repetitions, [limiter] * len(programs))
301+
self.run_sweep_async, zip(programs, params_list, repetitions)
303302
)
304303

305304
def _normalize_batch_args(

cirq-google/cirq_google/engine/engine_job.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,12 @@ def delete(self) -> None:
262262
"""Deletes the job and result, if any."""
263263
self.context.client.delete_job(self.project_id, self.program_id, self.job_id)
264264

265-
async def results_async(
266-
self, limiter: duet.Limiter = duet.Limiter(None)
267-
) -> Sequence[EngineResult]:
265+
async def results_async(self) -> Sequence[EngineResult]:
268266
"""Returns the job results, blocking until the job is complete."""
269267
import cirq_google.engine.engine as engine_base
270268

271269
if self._results is None:
272-
result_response = await self._await_result_async(limiter)
270+
result_response = await self._await_result_async()
273271
result = result_response.result
274272
result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
275273
if (
@@ -288,9 +286,7 @@ async def results_async(
288286
raise ValueError(f'invalid result proto version: {result_type}')
289287
return self._results
290288

291-
async def _await_result_async(
292-
self, limiter: duet.Limiter = duet.Limiter(None)
293-
) -> quantum.QuantumResult:
289+
async def _await_result_async(self) -> quantum.QuantumResult:
294290
if self._job_result_future is not None:
295291
response = await self._job_result_future
296292
if isinstance(response, quantum.QuantumResult):
@@ -303,13 +299,12 @@ async def _await_result_async(
303299
'Internal error: The job response type is not recognized.'
304300
) # pragma: no cover
305301

306-
async with limiter:
307-
async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type]
308-
while True:
309-
job = await self._refresh_job_async()
310-
if job.execution_status.state in TERMINAL_STATES:
311-
break
312-
await duet.sleep(1)
302+
async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type]
303+
while True:
304+
job = await self._refresh_job_async()
305+
if job.execution_status.state in TERMINAL_STATES:
306+
break
307+
await duet.sleep(1)
313308
_raise_on_failure(job)
314309
response = await self.context.client.get_job_results_async(
315310
self.project_id, self.program_id, self.job_id

cirq-google/cirq_google/engine/engine_processor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@ def engine(self) -> 'engine_base.Engine':
8787
return engine_base.Engine(self.project_id, context=self.context)
8888

8989
def get_sampler(
90-
self, run_name: str = "", device_config_name: str = "", snapshot_id: str = ""
90+
self,
91+
run_name: str = "",
92+
device_config_name: str = "",
93+
snapshot_id: str = "",
94+
max_concurrent_jobs: int = 10,
9195
) -> 'cg.engine.ProcessorSampler':
9296
"""Returns a sampler backed by the engine.
9397
Args:
@@ -100,6 +104,11 @@ def get_sampler(
100104
snapshot_id: A unique identifier for an immutable snapshot reference.
101105
A snapshot contains a collection of device configurations for the
102106
processor.
107+
max_concurrent_jobs: The maximum number of jobs to be sent
108+
simultaneously to the Engine. This client-side throttle can be
109+
used to proactively reduce load to the backends and avoid quota
110+
violations when pipelining circuit executions.
111+
103112
Returns:
104113
A `cirq.Sampler` instance (specifically a `engine_sampler.ProcessorSampler`
105114
that will send circuits to the Quantum Computing Service
@@ -127,6 +136,7 @@ def get_sampler(
127136
run_name=run_name,
128137
snapshot_id=snapshot_id,
129138
device_config_name=device_config_name,
139+
max_concurrent_jobs=max_concurrent_jobs,
130140
)
131141

132142
async def run_sweep_async(

cirq-google/cirq_google/engine/processor_sampler.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import cirq
1818
import duet
19-
from cirq_google.engine.engine_job import EngineJob
2019

2120
if TYPE_CHECKING:
2221
import cirq_google as cg
@@ -32,6 +31,7 @@ def __init__(
3231
run_name: str = "",
3332
snapshot_id: str = "",
3433
device_config_name: str = "",
34+
max_concurrent_jobs: int = 10,
3535
):
3636
"""Inits ProcessorSampler.
3737
@@ -48,6 +48,10 @@ def __init__(
4848
device_config_name: An identifier used to select the processor configuration
4949
utilized to run the job. A configuration identifies the set of
5050
available qubits, couplers, and supported gates in the processor.
51+
max_concurrent_jobs: The maximum number of jobs to be sent
52+
concurrently to the Engine. This client-side throttle can be
53+
used to proactively reduce load to the backends and avoid quota
54+
violations when pipelining circuit executions.
5155
5256
Raises:
5357
ValueError: If only one of `run_name` and `device_config_name` are specified.
@@ -59,28 +63,22 @@ def __init__(
5963
self._run_name = run_name
6064
self._snapshot_id = snapshot_id
6165
self._device_config_name = device_config_name
62-
self._result_limiter = duet.Limiter(None)
66+
self._concurrent_job_limiter = duet.Limiter(max_concurrent_jobs)
6367

6468
async def run_sweep_async(
65-
self,
66-
program: 'cirq.AbstractCircuit',
67-
params: cirq.Sweepable,
68-
repetitions: int = 1,
69-
limiter: duet.Limiter = duet.Limiter(None),
69+
self, program: 'cirq.AbstractCircuit', params: cirq.Sweepable, repetitions: int = 1
7070
) -> Sequence['cg.EngineResult']:
71-
job = await self._processor.run_sweep_async(
72-
program=program,
73-
params=params,
74-
repetitions=repetitions,
75-
run_name=self._run_name,
76-
snapshot_id=self._snapshot_id,
77-
device_config_name=self._device_config_name,
78-
)
79-
80-
if isinstance(job, EngineJob):
81-
return await job.results_async(limiter)
82-
83-
return await job.results_async()
71+
async with self._concurrent_job_limiter:
72+
job = await self._processor.run_sweep_async(
73+
program=program,
74+
params=params,
75+
repetitions=repetitions,
76+
run_name=self._run_name,
77+
snapshot_id=self._snapshot_id,
78+
device_config_name=self._device_config_name,
79+
)
80+
81+
return await job.results_async()
8482

8583
run_sweep = duet.sync(run_sweep_async)
8684

@@ -89,12 +87,10 @@ async def run_batch_async(
8987
programs: Sequence[cirq.AbstractCircuit],
9088
params_list: Optional[Sequence[cirq.Sweepable]] = None,
9189
repetitions: Union[int, Sequence[int]] = 1,
92-
limiter: duet.Limiter = duet.Limiter(10),
9390
) -> Sequence[Sequence['cg.EngineResult']]:
94-
self._result_limiter = limiter
9591
return cast(
9692
Sequence[Sequence['cg.EngineResult']],
97-
await super().run_batch_async(programs, params_list, repetitions, self._result_limiter),
93+
await super().run_batch_async(programs, params_list, repetitions),
9894
)
9995

10096
run_batch = duet.sync(run_batch_async)
@@ -114,7 +110,3 @@ def snapshot_id(self) -> str:
114110
@property
115111
def device_config_name(self) -> str:
116112
return self._device_config_name
117-
118-
@property
119-
def result_limiter(self) -> duet.Limiter:
120-
return self._result_limiter

cirq-google/cirq_google/engine/processor_sampler_test.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from unittest import mock
1616

1717
import pytest
18+
import duet
1819

1920
import cirq
2021
import cirq_google as cg
@@ -170,29 +171,60 @@ def test_run_batch_differing_repetitions():
170171
)
171172

172173

173-
def test_run_batch_receives_results_using_limiter():
174+
@duet.sync
175+
async def test_sampler_with_full_job_queue_blocks():
174176
processor = mock.create_autospec(AbstractProcessor)
175-
run_name = "RUN_NAME"
176-
device_config_name = "DEVICE_CONFIG_NAME"
177-
sampler = cg.ProcessorSampler(
178-
processor=processor, run_name=run_name, device_config_name=device_config_name
179-
)
177+
sampler = cg.ProcessorSampler(processor=processor, max_concurrent_jobs=2)
180178

181-
job = mock.AsyncMock(EngineJob)
179+
async def wait_forever(**kwargs):
180+
await duet.AwaitableFuture[None]()
181+
182+
processor.run_sweep_async.side_effect = wait_forever
183+
184+
a = cirq.LineQubit(0)
185+
circuit = cirq.Circuit(cirq.X(a))
182186

187+
with pytest.raises(TimeoutError):
188+
async with duet.timeout_scope(0.01):
189+
await sampler.run_batch_async([circuit] * 3)
190+
191+
assert processor.run_sweep_async.call_count == 2
192+
193+
194+
@duet.sync
195+
async def test_sampler_with_job_queue_availability_runs_all():
196+
processor = mock.create_autospec(AbstractProcessor)
197+
sampler = cg.ProcessorSampler(processor=processor, max_concurrent_jobs=3)
198+
199+
async def wait_forever(**kwargs):
200+
await duet.AwaitableFuture[None]()
201+
202+
processor.run_sweep_async.side_effect = wait_forever
203+
204+
a = cirq.LineQubit(0)
205+
circuit = cirq.Circuit(cirq.X(a))
206+
207+
with pytest.raises(TimeoutError):
208+
async with duet.timeout_scope(0.01):
209+
await sampler.run_batch_async([circuit] * 3)
210+
211+
assert processor.run_sweep_async.call_count == 3
212+
213+
214+
@duet.sync
215+
async def test_sampler_with_full_job_queue_unblocks_when_available():
216+
processor = mock.create_autospec(AbstractProcessor)
217+
sampler = cg.ProcessorSampler(processor=processor, max_concurrent_jobs=2)
218+
219+
job = mock.AsyncMock(EngineJob)
183220
processor.run_sweep_async.return_value = job
221+
184222
a = cirq.LineQubit(0)
185-
circuit1 = cirq.Circuit(cirq.X(a))
186-
circuit2 = cirq.Circuit(cirq.Y(a))
187-
params1 = [cirq.ParamResolver({'t': 1})]
188-
params2 = [cirq.ParamResolver({'t': 2})]
189-
circuits = [circuit1, circuit2]
190-
params_list = [params1, params2]
191-
repetitions = [1, 2]
223+
circuit = cirq.Circuit(cirq.X(a))
192224

193-
sampler.run_batch(circuits, params_list, repetitions)
225+
await sampler.run_batch_async([circuit] * 3)
194226

195-
job.results_async.assert_called_with(sampler.result_limiter)
227+
assert processor.run_sweep_async.call_count == 3
196228

197229

198230
def test_processor_sampler_processor_property():

0 commit comments

Comments
 (0)