Skip to content

Commit 3208b94

Browse files
authored
add snapshot_id to get_sampler interface on engine object (quantumlib#7005)
* add snapshot to get_sampler interface on engine object * backwards compat
1 parent 08b1efb commit 3208b94

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

cirq-google/cirq_google/engine/engine.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,11 @@ def get_processor(self, processor_id: str) -> engine_processor.EngineProcessor:
575575
return engine_processor.EngineProcessor(self.project_id, processor_id, self.context)
576576

577577
def get_sampler(
578-
self, processor_id: Union[str, List[str]], run_name: str = "", device_config_name: str = ""
578+
self,
579+
processor_id: Union[str, List[str]],
580+
run_name: str = "",
581+
device_config_name: str = "",
582+
snapshot_id: str = "",
579583
) -> 'cirq_google.ProcessorSampler':
580584
"""Returns a sampler backed by the engine.
581585
@@ -587,6 +591,8 @@ def get_sampler(
587591
device_config_name: An identifier used to select the processor configuration
588592
utilized to run the job. A configuration identifies the set of
589593
available qubits, couplers, and supported gates in the processor.
594+
snapshot_id: A unique identifier for an immutable snapshot reference. A
595+
snapshot contains a collection of device configurations for the processor.
590596
591597
Returns:
592598
A `cirq.Sampler` instance (specifically a `engine_sampler.ProcessorSampler`
@@ -603,7 +609,7 @@ def get_sampler(
603609
'you need to specify a list.'
604610
)
605611
return self.get_processor(processor_id).get_sampler(
606-
run_name=run_name, device_config_name=device_config_name
612+
run_name=run_name, device_config_name=device_config_name, snapshot_id=snapshot_id
607613
)
608614

609615

cirq-google/cirq_google/engine/engine_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,24 @@ def test_run_circuit_with_unary_rpcs(client):
281281
client().get_job_results_async.assert_called_once_with('proj', 'prog', 'job-id')
282282

283283

284+
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
285+
def test_engine_get_sampler_with_snapshot_id_passes_to_unary_rpc(client):
286+
setup_run_circuit_with_result_(client, _A_RESULT)
287+
engine = cg.Engine(
288+
project_id='proj',
289+
context=EngineContext(service_args={'client_info': 1}, enable_streaming=False),
290+
)
291+
sampler = engine.get_sampler('mysim', device_config_name="config", snapshot_id="123")
292+
_ = sampler.run_sweep(_CIRCUIT, params=[cirq.ParamResolver({'a': 1})])
293+
294+
kwargs = client().create_job_async.call_args_list[0].kwargs
295+
296+
# We care about asserting that the snapshot_id is correctly passed.
297+
assert kwargs["snapshot_id"] == "123"
298+
assert kwargs["run_name"] == ""
299+
assert kwargs["device_config_name"] == "config"
300+
301+
284302
@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
285303
def test_run_circuit_with_stream_rpcs_passes(client):
286304
setup_run_circuit_with_result_(client, _A_RESULT)

0 commit comments

Comments
 (0)