diff --git a/cirq-google/cirq_google/engine/engine.py b/cirq-google/cirq_google/engine/engine.py index 2e0625b652d..f779c526b32 100644 --- a/cirq-google/cirq_google/engine/engine.py +++ b/cirq-google/cirq_google/engine/engine.py @@ -29,6 +29,7 @@ import string from typing import Dict, Iterable, List, Optional, Sequence, Set, TypeVar, Union, TYPE_CHECKING +import duet import google.auth from google.protobuf import any_pb2 @@ -493,7 +494,7 @@ def run_calibration( ) @util.deprecated_gate_set_parameter - def create_program( + async def create_program_async( self, program: cirq.AbstractCircuit, program_id: Optional[str] = None, @@ -524,7 +525,7 @@ def create_program( if not program_id: program_id = _make_random_id('prog-') - new_program_id, new_program = self.context.client.create_program( + new_program_id, new_program = await self.context.client.create_program_async( self.project_id, program_id, code=self.context._serialize_program(program, gate_set), @@ -536,8 +537,10 @@ def create_program( self.project_id, new_program_id, self.context, new_program ) + create_program = duet.sync(create_program_async) + @util.deprecated_gate_set_parameter - def create_batch_program( + async def create_batch_program_async( self, programs: Sequence[cirq.AbstractCircuit], program_id: Optional[str] = None, @@ -574,7 +577,7 @@ def create_batch_program( for program in programs: gate_set.serialize(program, msg=batch.programs.add()) - new_program_id, new_program = self.context.client.create_program( + new_program_id, new_program = await self.context.client.create_program_async( self.project_id, program_id, code=util.pack_any(batch), @@ -586,8 +589,10 @@ def create_batch_program( self.project_id, new_program_id, self.context, new_program, result_type=ResultType.Batch ) + create_batch_program = duet.sync(create_batch_program_async) + @util.deprecated_gate_set_parameter - def create_calibration_program( + async def create_calibration_program_async( self, layers: List['cirq_google.CalibrationLayer'], program_id: Optional[str] = None, @@ -632,7 +637,7 @@ def create_calibration_program( arg_to_proto(layer.args[arg], out=new_layer.args[arg]) gate_set.serialize(layer.program, msg=new_layer.layer) - new_program_id, new_program = self.context.client.create_program( + new_program_id, new_program = await self.context.client.create_program_async( self.project_id, program_id, code=util.pack_any(calibration), @@ -648,6 +653,8 @@ def create_calibration_program( result_type=ResultType.Calibration, ) + create_calibration_program = duet.sync(create_calibration_program_async) + def get_program(self, program_id: str) -> engine_program.EngineProgram: """Returns an EngineProgram for an existing Quantum Engine program. @@ -659,7 +666,7 @@ def get_program(self, program_id: str) -> engine_program.EngineProgram: """ return engine_program.EngineProgram(self.project_id, program_id, self.context) - def list_programs( + async def list_programs_async( self, created_before: Optional[Union[datetime.datetime, datetime.date]] = None, created_after: Optional[Union[datetime.datetime, datetime.date]] = None, @@ -681,7 +688,7 @@ def list_programs( """ client = self.context.client - response = client.list_programs( + response = await client.list_programs_async( self.project_id, created_before=created_before, created_after=created_after, @@ -697,7 +704,9 @@ def list_programs( for p in response ] - def list_jobs( + list_programs = duet.sync(list_programs_async) + + async def list_jobs_async( self, created_before: Optional[Union[datetime.datetime, datetime.date]] = None, created_after: Optional[Union[datetime.datetime, datetime.date]] = None, @@ -730,7 +739,7 @@ def list_jobs( `quantum.ExecutionStatus.State` enum for accepted values. """ client = self.context.client - response = client.list_jobs( + response = await client.list_jobs_async( self.project_id, None, created_before=created_before, @@ -749,7 +758,9 @@ def list_jobs( for j in response ] - def list_processors(self) -> List[engine_processor.EngineProcessor]: + list_jobs = duet.sync(list_jobs_async) + + async def list_processors_async(self) -> List[engine_processor.EngineProcessor]: """Returns a list of Processors that the user has visibility to in the current Engine project. The names of these processors are used to identify devices when scheduling jobs and gathering calibration metrics. @@ -758,7 +769,7 @@ def list_processors(self) -> List[engine_processor.EngineProcessor]: A list of EngineProcessors to access status, device and calibration information. """ - response = self.context.client.list_processors(self.project_id) + response = await self.context.client.list_processors_async(self.project_id) return [ engine_processor.EngineProcessor( self.project_id, engine_client._ids_from_processor_name(p.name)[1], self.context, p @@ -766,6 +777,8 @@ def list_processors(self) -> List[engine_processor.EngineProcessor]: for p in response ] + list_processors = duet.sync(list_processors_async) + def get_processor(self, processor_id: str) -> engine_processor.EngineProcessor: """Returns an EngineProcessor for a Quantum Engine processor. diff --git a/cirq-google/cirq_google/engine/engine_job.py b/cirq-google/cirq_google/engine/engine_job.py index dda063ec679..5ca191a84a4 100644 --- a/cirq-google/cirq_google/engine/engine_job.py +++ b/cirq-google/cirq_google/engine/engine_job.py @@ -13,10 +13,10 @@ # limitations under the License. """A helper for jobs that have been created on the Quantum Engine.""" import datetime -import time from typing import Dict, Iterator, List, Optional, overload, Sequence, Tuple, TYPE_CHECKING +import duet from google.protobuf import any_pb2 import cirq @@ -107,20 +107,25 @@ def program(self) -> 'engine_program.EngineProgram': return engine_program.EngineProgram(self.project_id, self.program_id, self.context) + async def _get_job_async(self, return_run_context: bool = False) -> quantum.QuantumJob: + return await self.context.client.get_job_async( + self.project_id, self.program_id, self.job_id, return_run_context + ) + + _get_job = duet.sync(_get_job_async) + def _inner_job(self) -> quantum.QuantumJob: if self._job is None: - self._job = self.context.client.get_job( - self.project_id, self.program_id, self.job_id, False - ) + self._job = self._get_job() return self._job - def _refresh_job(self) -> quantum.QuantumJob: + async def _refresh_job_async(self) -> quantum.QuantumJob: if self._job is None or self._job.execution_status.state not in TERMINAL_STATES: - self._job = self.context.client.get_job( - self.project_id, self.program_id, self.job_id, False - ) + self._job = await self._get_job_async() return self._job + _refresh_job = duet.sync(_refresh_job_async) + def create_time(self) -> 'datetime.datetime': """Returns when the job was created.""" return self._inner_job().create_time @@ -224,10 +229,7 @@ def get_repetitions_and_sweeps(self) -> Tuple[int, List[cirq.Sweep]]: A tuple of the repetition count and list of sweeps. """ if self._job is None or self._job.run_context is None: - self._job = self.context.client.get_job( - self.project_id, self.program_id, self.job_id, True - ) - + self._job = self._get_job(return_run_context=True) return _deserialize_run_context(self._job.run_context) def get_processor(self) -> 'Optional[engine_processor.EngineProcessor]': @@ -260,42 +262,26 @@ def delete(self) -> None: """Deletes the job and result, if any.""" self.context.client.delete_job(self.project_id, self.program_id, self.job_id) - def batched_results(self) -> Sequence[Sequence[EngineResult]]: + async def batched_results_async(self) -> Sequence[Sequence[EngineResult]]: """Returns the job results, blocking until the job is complete. This method is intended for batched jobs. Instead of flattening results into a single list, this will return a Sequence[Result] for each circuit in the batch. """ - self.results() + await self.results_async() if self._batched_results is None: raise ValueError('batched_results called for a non-batch result.') return self._batched_results - def _wait_for_result(self): - job = self._refresh_job() - total_seconds_waited = 0.0 - timeout = self.context.timeout - while True: - if timeout and total_seconds_waited >= timeout: - break - if job.execution_status.state in TERMINAL_STATES: - break - time.sleep(0.5) - total_seconds_waited += 0.5 - job = self._refresh_job() - _raise_on_failure(job) - response = self.context.client.get_job_results( - self.project_id, self.program_id, self.job_id - ) - return response.result + batched_results = duet.sync(batched_results_async) - def results(self) -> Sequence[EngineResult]: + async def results_async(self) -> Sequence[EngineResult]: """Returns the job results, blocking until the job is complete.""" import cirq_google.engine.engine as engine_base if self._results is None: - result = self._wait_for_result() + result = await self._await_result_async() result_type = result.type_url[len(engine_base.TYPE_PREFIX) :] if ( result_type == 'cirq.google.api.v1.Result' @@ -317,7 +303,22 @@ def results(self) -> Sequence[EngineResult]: raise ValueError(f'invalid result proto version: {result_type}') return self._results - def calibration_results(self) -> Sequence[CalibrationResult]: + results = duet.sync(results_async) + + async def _await_result_async(self) -> quantum.QuantumResult: + async with duet.timeout_scope(self.context.timeout): + while True: + job = await self._refresh_job_async() + if job.execution_status.state in TERMINAL_STATES: + break + await duet.sleep(0.5) + _raise_on_failure(job) + response = await self.context.client.get_job_results_async( + self.project_id, self.program_id, self.job_id + ) + return response.result + + async def calibration_results_async(self) -> Sequence[CalibrationResult]: """Returns the results of a run_calibration() call. This function will fail if any other type of results were returned @@ -326,7 +327,7 @@ def calibration_results(self) -> Sequence[CalibrationResult]: import cirq_google.engine.engine as engine_base if self._calibration_results is None: - result = self._wait_for_result() + result = await self._await_result_async() result_type = result.type_url[len(engine_base.TYPE_PREFIX) :] if result_type != 'cirq.google.api.v2.FocusedCalibrationResult': raise ValueError(f'Did not find calibration results, instead found: {result_type}') @@ -343,6 +344,8 @@ def calibration_results(self) -> Sequence[CalibrationResult]: self._calibration_results = cal_results return self._calibration_results + calibration_results = duet.sync(calibration_results_async) + def _get_job_results_v1(self, result: v1.program_pb2.Result) -> Sequence[EngineResult]: # coverage: ignore job_id = self.id() diff --git a/cirq-google/cirq_google/engine/engine_job_test.py b/cirq-google/cirq_google/engine/engine_job_test.py index 87301983491..8a6b45b3fda 100644 --- a/cirq-google/cirq_google/engine/engine_job_test.py +++ b/cirq-google/cirq_google/engine/engine_job_test.py @@ -25,6 +25,7 @@ from cirq_google.engine import util from cirq_google.cloud import quantum from cirq_google.engine.engine import EngineContext +from cirq_google.engine.test_utils import uses_async_mock @pytest.fixture(scope='session', autouse=True) @@ -70,7 +71,8 @@ def test_create_time(): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_update_time(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) get_job.return_value = quantum.QuantumJob( @@ -82,7 +84,8 @@ def test_update_time(get_job): get_job.assert_called_once_with('a', 'b', 'steve', False) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_description(get_job): job = cg.EngineJob( 'a', 'b', 'steve', EngineContext(), _job=quantum.QuantumJob(description='hello') @@ -93,7 +96,8 @@ def test_description(get_job): get_job.assert_called_once_with('a', 'b', 'steve', False) -@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_description') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_description_async') def test_set_description(set_job_description): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) set_job_description.return_value = quantum.QuantumJob(description='world') @@ -112,7 +116,8 @@ def test_labels(): assert job.labels() == {'t': '1'} -@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_labels') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_labels_async') def test_set_labels(set_job_labels): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) set_job_labels.return_value = quantum.QuantumJob(labels={'a': '1', 'b': '1'}) @@ -124,7 +129,8 @@ def test_set_labels(set_job_labels): set_job_labels.assert_called_with('a', 'b', 'steve', {}) -@mock.patch('cirq_google.engine.engine_client.EngineClient.add_job_labels') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.add_job_labels_async') def test_add_labels(add_job_labels): job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=quantum.QuantumJob(labels={})) assert job.labels() == {} @@ -138,7 +144,8 @@ def test_add_labels(add_job_labels): add_job_labels.assert_called_with('a', 'b', 'steve', {'a': '2', 'b': '1'}) -@mock.patch('cirq_google.engine.engine_client.EngineClient.remove_job_labels') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.remove_job_labels_async') def test_remove_labels(remove_job_labels): job = cg.EngineJob( 'a', 'b', 'steve', EngineContext(), _job=quantum.QuantumJob(labels={'a': '1', 'b': '1'}) @@ -171,7 +178,8 @@ def test_processor_ids(): assert job.processor_ids() == ['p'] -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_status(get_job): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.RUNNING) @@ -216,7 +224,8 @@ def test_failure_with_no_error(): assert not job.failure() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_get_repetitions_and_sweeps(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) get_job.return_value = quantum.QuantumJob( @@ -230,7 +239,8 @@ def test_get_repetitions_and_sweeps(get_job): get_job.assert_called_once_with('a', 'b', 'steve', True) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_get_repetitions_and_sweeps_v1(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) get_job.return_value = quantum.QuantumJob( @@ -244,7 +254,8 @@ def test_get_repetitions_and_sweeps_v1(get_job): job.get_repetitions_and_sweeps() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_get_repetitions_and_sweeps_unsupported(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) get_job.return_value = quantum.QuantumJob( @@ -312,7 +323,8 @@ def test_get_calibration(get_calibration): get_calibration.assert_called_once_with('a', 'p', 123) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration_async') def test_calibration__with_no_calibration(get_calibration): job = cg.EngineJob( 'a', @@ -329,14 +341,16 @@ def test_calibration__with_no_calibration(get_calibration): assert not get_calibration.called -@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_job') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_job_async') def test_cancel(cancel_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) job.cancel() cancel_job.assert_called_once_with('a', 'b', 'steve') -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job_async') def test_delete(delete_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) job.delete() @@ -504,7 +518,8 @@ def test_delete(delete_job): UPDATE_TIME = datetime.datetime.now(tz=datetime.timezone.utc) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -520,7 +535,8 @@ def test_results(get_job_results): get_job_results.assert_called_once_with('a', 'b', 'steve') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results_iter(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -535,7 +551,8 @@ def test_results_iter(get_job_results): assert results[1] == 'q=1010' -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results_getitem(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -550,7 +567,8 @@ def test_results_getitem(get_job_results): _ = job[2] -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_batched_results(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -577,7 +595,8 @@ def test_batched_results(get_job_results): assert str(data[1][1]) == 'q=1001' -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_batched_results_not_a_batch(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -589,7 +608,8 @@ def test_batched_results_not_a_batch(get_job_results): job.batched_results() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_calibration_results(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -608,7 +628,8 @@ def test_calibration_results(get_job_results): assert data[0].metrics['theta'] == {(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)): [0.9999]} -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_calibration_defaults(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -628,7 +649,8 @@ def test_calibration_defaults(get_job_results): assert len(data[0].metrics) == 0 -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_calibration_results_not_a_calibration(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -640,7 +662,8 @@ def test_calibration_results_not_a_calibration(get_job_results): job.calibration_results() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results_len(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -652,16 +675,16 @@ def test_results_len(get_job_results): assert len(job) == 2 -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') -@mock.patch('time.sleep', return_value=None) -def test_timeout(patched_time_sleep, get_job): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') +def test_timeout(get_job): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.RUNNING), update_time=UPDATE_TIME, ) get_job.return_value = qjob - job = cg.EngineJob('a', 'b', 'steve', EngineContext(timeout=500)) - with pytest.raises(RuntimeError, match='Timed out'): + job = cg.EngineJob('a', 'b', 'steve', EngineContext(timeout=0.1)) + with pytest.raises(TimeoutError): job.results() diff --git a/cirq-google/cirq_google/engine/engine_processor_test.py b/cirq-google/cirq_google/engine/engine_processor_test.py index 9cd45334222..bd832db0bf3 100644 --- a/cirq-google/cirq_google/engine/engine_processor_test.py +++ b/cirq-google/cirq_google/engine/engine_processor_test.py @@ -28,6 +28,7 @@ from cirq_google.api import v2 from cirq_google.engine import util from cirq_google.engine.engine import EngineContext +from cirq_google.engine.test_utils import uses_async_mock from cirq_google.cloud import quantum @@ -232,7 +233,8 @@ def test_engine_repr(): assert 'the-processor-id' in repr(processor) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') def test_health(get_processor): get_processor.return_value = quantum.QuantumProcessor(health=quantum.QuantumProcessor.Health.OK) processor = cg.EngineProcessor( @@ -244,7 +246,8 @@ def test_health(get_processor): assert processor.health() == 'OK' -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') def test_expected_down_time(get_processor): processor = cg.EngineProcessor('a', 'p', EngineContext(), _processor=quantum.QuantumProcessor()) assert not processor.expected_down_time() @@ -352,7 +355,8 @@ def test_get_missing_device(): _ = processor.get_device(gate_sets=[_GATE_SET]) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations_async') def test_list_calibrations(list_calibrations): list_calibrations.return_value = [_CALIBRATION] processor = cg.EngineProcessor('a', 'p', EngineContext()) @@ -391,7 +395,8 @@ def test_list_calibrations(list_calibrations): list_calibrations.assert_called_with('a', 'p', f'timestamp >= {today_midnight_timestamp}') -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations_async') def test_list_calibrations_old_params(list_calibrations): # Disable pylint warnings for use of deprecated parameters # pylint: disable=unexpected-keyword-arg @@ -409,7 +414,8 @@ def test_list_calibrations_old_params(list_calibrations): list_calibrations.assert_called_with('a', 'p', 'timestamp <= 1562600000') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration_async') def test_get_calibration(get_calibration): get_calibration.return_value = _CALIBRATION processor = cg.EngineProcessor('a', 'p', EngineContext()) @@ -419,7 +425,8 @@ def test_get_calibration(get_calibration): get_calibration.assert_called_once_with('a', 'p', 1562544000021) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration_async') def test_current_calibration(get_current_calibration): get_current_calibration.return_value = _CALIBRATION processor = cg.EngineProcessor('a', 'p', EngineContext()) @@ -429,7 +436,8 @@ def test_current_calibration(get_current_calibration): get_current_calibration.assert_called_once_with('a', 'p') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration_async') def test_missing_latest_calibration(get_current_calibration): get_current_calibration.return_value = None processor = cg.EngineProcessor('a', 'p', EngineContext()) @@ -437,7 +445,8 @@ def test_missing_latest_calibration(get_current_calibration): get_current_calibration.assert_called_once_with('a', 'p') -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_reservation') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_reservation_async') def test_create_reservation(create_reservation): name = 'projects/proj/processors/p0/reservations/psherman-wallaby-way' result = quantum.QuantumReservation( @@ -462,7 +471,8 @@ def test_create_reservation(create_reservation): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation_async') def test_delete_reservation(delete_reservation): name = 'projects/proj/processors/p0/reservations/rid' result = quantum.QuantumReservation( @@ -477,7 +487,8 @@ def test_delete_reservation(delete_reservation): delete_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation_async') def test_cancel_reservation(cancel_reservation): name = 'projects/proj/processors/p0/reservations/rid' result = quantum.QuantumReservation( @@ -492,8 +503,9 @@ def test_cancel_reservation(cancel_reservation): cancel_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation_async') def test_remove_reservation_delete(delete_reservation, get_reservation): name = 'projects/proj/processors/p0/reservations/rid' now = int(datetime.datetime.now().timestamp()) @@ -515,8 +527,9 @@ def test_remove_reservation_delete(delete_reservation, get_reservation): delete_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') -@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') +@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation_async') def test_remove_reservation_cancel(cancel_reservation, get_reservation): name = 'projects/proj/processors/p0/reservations/rid' now = int(datetime.datetime.now().timestamp()) @@ -538,7 +551,8 @@ def test_remove_reservation_cancel(cancel_reservation, get_reservation): cancel_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') def test_remove_reservation_not_found(get_reservation): get_reservation.return_value = None processor = cg.EngineProcessor( @@ -551,8 +565,9 @@ def test_remove_reservation_not_found(get_reservation): processor.remove_reservation('rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') def test_remove_reservation_failures(get_reservation, get_processor): name = 'projects/proj/processors/p0/reservations/rid' now = int(datetime.datetime.now().timestamp()) @@ -576,7 +591,8 @@ def test_remove_reservation_failures(get_reservation, get_processor): processor.remove_reservation('rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') def test_get_reservation(get_reservation): name = 'projects/proj/processors/p0/reservations/rid' result = quantum.QuantumReservation( @@ -591,7 +607,8 @@ def test_get_reservation(get_reservation): get_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.update_reservation') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.update_reservation_async') def test_update_reservation(update_reservation): name = 'projects/proj/processors/p0/reservations/rid' result = quantum.QuantumReservation( @@ -610,7 +627,8 @@ def test_update_reservation(update_reservation): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations_async') def test_list_reservation(list_reservations): name = 'projects/proj/processors/p0/reservations/rid' results = [ @@ -640,7 +658,8 @@ def test_list_reservation(list_reservations): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots_async') def test_get_schedule(list_time_slots): results = [ quantum.QuantumTimeSlot( @@ -675,7 +694,8 @@ def test_get_schedule(list_time_slots): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots_async') def test_get_schedule_filter_by_time_slot(list_time_slots): results = [ quantum.QuantumTimeSlot( @@ -735,9 +755,10 @@ def wrapper(*args, **kwargs): return wrapper +@uses_async_mock @_allow_deprecated_freezegun @freezegun.freeze_time() -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots') +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots_async') def test_get_schedule_time_filter_behavior(list_time_slots): list_time_slots.return_value = [] processor = cg.EngineProcessor('proj', 'p0', EngineContext()) @@ -779,9 +800,10 @@ def test_get_schedule_time_filter_behavior(list_time_slots): list_time_slots.assert_called_with('proj', 'p0', f'start_time < {utc_ts}') +@uses_async_mock @_allow_deprecated_freezegun @freezegun.freeze_time() -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations') +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations_async') def test_list_reservations_time_filter_behavior(list_reservations): list_reservations.return_value = [] processor = cg.EngineProcessor('proj', 'p0', EngineContext()) @@ -823,22 +845,25 @@ def test_list_reservations_time_filter_behavior(list_reservations): list_reservations.assert_called_with('proj', 'p0', f'start_time < {utc_ts}') -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_params(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( execution_status={'state': 'SUCCESS'}, update_time=_to_timestamp('2019-07-09T23:39:59Z') ) - client().get_job_results.return_value = quantum.QuantumResult(result=util.pack_any(_RESULTS_V2)) + client().get_job_results_async.return_value = quantum.QuantumResult( + result=util.pack_any(_RESULTS_V2) + ) processor = cg.EngineProcessor('a', 'p', EngineContext()) job = processor.run_sweep( @@ -855,36 +880,37 @@ def test_run_sweep_params(client): assert result.job_finished_time is not None assert results == cirq.read_json(json_text=cirq.to_json(results)) - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.run_context_pb2.RunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 2 for i, v in enumerate([1.0, 2.0]): assert sweeps[i].repetitions == 1 assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( execution_status={'state': 'SUCCESS'}, update_time=_to_timestamp('2019-07-09T23:39:59Z') ) - client().get_job_results.return_value = quantum.QuantumResult(result=_BATCH_RESULTS_V2) + client().get_job_results_async.return_value = quantum.QuantumResult(result=_BATCH_RESULTS_V2) processor = cg.EngineProcessor('a', 'p', EngineContext()) job = processor.run_batch( @@ -900,10 +926,10 @@ def test_run_batch(client): assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} for result in results: assert result.job_id == job.id() - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.batch_pb2.BatchRunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) assert len(run_context.run_contexts) == 2 for idx, rc in enumerate(run_context.run_contexts): sweeps = rc.parameter_sweeps @@ -913,24 +939,27 @@ def test_run_batch(client): assert sweeps[0].sweep.single_sweep.points.points == [1.0, 2.0] if idx == 1: assert sweeps[0].sweep.single_sweep.points.points == [3.0, 4.0] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_calibration(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob(execution_status={'state': 'SUCCESS'}) - client().get_job_results.return_value = quantum.QuantumResult(result=_CALIBRATION_RESULTS_V2) + client().get_job_async.return_value = quantum.QuantumJob(execution_status={'state': 'SUCCESS'}) + client().get_job_results_async.return_value = quantum.QuantumResult( + result=_CALIBRATION_RESULTS_V2 + ) q1 = cirq.GridQubit(2, 3) q2 = cirq.GridQubit(2, 4) @@ -952,7 +981,7 @@ def test_run_calibration(client): assert results[1].error_message == 'Second success' # assert label is correct - client().create_job.assert_called_once_with( + client().create_job_async.assert_called_once_with( project_id='proj', program_id='prog', job_id='job-id', @@ -963,22 +992,25 @@ def test_run_calibration(client): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_sampler(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( execution_status={'state': 'SUCCESS'}, update_time=_to_timestamp('2019-07-09T23:39:59Z') ) - client().get_job_results.return_value = quantum.QuantumResult(result=util.pack_any(_RESULTS_V2)) + client().get_job_results_async.return_value = quantum.QuantumResult( + result=util.pack_any(_RESULTS_V2) + ) processor = cg.EngineProcessor('proj', 'mysim', EngineContext()) sampler = processor.get_sampler() results = sampler.run_sweep( @@ -989,7 +1021,7 @@ def test_sampler(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - assert client().create_program.call_args[0][0] == 'proj' + assert client().create_program_async.call_args[0][0] == 'proj' def test_str(): diff --git a/cirq-google/cirq_google/engine/engine_program.py b/cirq-google/cirq_google/engine/engine_program.py index b7d2f28fd79..289f69d8df4 100644 --- a/cirq-google/cirq_google/engine/engine_program.py +++ b/cirq-google/cirq_google/engine/engine_program.py @@ -15,6 +15,7 @@ import datetime from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union +import duet from google.protobuf import any_pb2 import cirq @@ -63,7 +64,7 @@ def __init__( self._program = _program self.result_type = result_type - def run_sweep( + async def run_sweep_async( self, job_id: Optional[str] = None, params: cirq.Sweepable = None, @@ -105,7 +106,7 @@ def run_sweep( job_id = engine_base._make_random_id('job-') run_context = self.context._serialize_run_context(params, repetitions) - created_job_id, job = self.context.client.create_job( + created_job_id, job = await self.context.client.create_job_async( project_id=self.project_id, program_id=self.program_id, job_id=job_id, @@ -118,7 +119,9 @@ def run_sweep( self.project_id, self.program_id, created_job_id, self.context, job ) - def run_batch( + run_sweep = duet.sync(run_sweep_async) + + async def run_batch_async( self, job_id: Optional[str] = None, params_list: List[cirq.Sweepable] = None, @@ -179,7 +182,7 @@ def run_batch( (params, repetitions) for params in params_list ) - created_job_id, job = self.context.client.create_job( + created_job_id, job = await self.context.client.create_job_async( project_id=self.project_id, program_id=self.program_id, job_id=job_id, @@ -197,7 +200,9 @@ def run_batch( result_type=ResultType.Batch, ) - def run_calibration( + run_batch = duet.sync(run_batch_async) + + async def run_calibration_async( self, job_id: Optional[str] = None, processor_ids: Sequence[str] = (), @@ -241,7 +246,7 @@ def run_calibration( # on a run context in order to succeed validation. run_context = v2.run_context_pb2.RunContext() - created_job_id, job = self.context.client.create_job( + created_job_id, job = await self.context.client.create_job_async( project_id=self.project_id, program_id=self.program_id, job_id=job_id, @@ -259,7 +264,9 @@ def run_calibration( result_type=ResultType.Batch, ) - def run( + run_calibration = duet.sync(run_calibration_async) + + async def run_async( self, job_id: Optional[str] = None, param_resolver: cirq.ParamResolver = cirq.ParamResolver({}), @@ -286,16 +293,18 @@ def run( Returns: A single Result for this run. """ - return list( - self.run_sweep( - job_id=job_id, - params=[param_resolver], - repetitions=repetitions, - processor_ids=processor_ids, - description=description, - labels=labels, - ) - )[0] + job = await self.run_sweep_async( + job_id=job_id, + params=[param_resolver], + repetitions=repetitions, + processor_ids=processor_ids, + description=description, + labels=labels, + ) + results = await job.results_async() + return results[0] + + run = duet.sync(run_async) def engine(self) -> 'engine_base.Engine': """Returns the parent Engine object. @@ -318,13 +327,13 @@ def get_job(self, job_id: str) -> engine_job.EngineJob: """ return engine_job.EngineJob(self.project_id, self.program_id, job_id, self.context) - def list_jobs( + async def list_jobs_async( self, created_before: Optional[Union[datetime.datetime, datetime.date]] = None, created_after: Optional[Union[datetime.datetime, datetime.date]] = None, has_labels: Optional[Dict[str, str]] = None, execution_states: Optional[Set[quantum.ExecutionStatus.State]] = None, - ): + ) -> Sequence[engine_job.EngineJob]: """Returns the list of jobs for this program. Args: @@ -347,7 +356,7 @@ def list_jobs( `quantum.ExecutionStatus.State` enum for accepted values. """ client = self.context.client - response = client.list_jobs( + response = await client.list_jobs_async( self.project_id, self.program_id, created_before=created_before, @@ -366,6 +375,8 @@ def list_jobs( for j in response ] + list_jobs = duet.sync(list_jobs_async) + def _inner_program(self) -> quantum.QuantumProgram: if self._program is None: self._program = self.context.client.get_program(self.project_id, self.program_id, False) @@ -384,7 +395,7 @@ def description(self) -> str: """Returns the description of the program.""" return self._inner_program().description - def set_description(self, description: str) -> 'EngineProgram': + async def set_description_async(self, description: str) -> 'EngineProgram': """Sets the description of the program. Params: @@ -393,16 +404,18 @@ def set_description(self, description: str) -> 'EngineProgram': Returns: This EngineProgram. """ - self._program = self.context.client.set_program_description( + self._program = await self.context.client.set_program_description_async( self.project_id, self.program_id, description ) return self + set_description = duet.sync(set_description_async) + def labels(self) -> Dict[str, str]: """Returns the labels of the program.""" return self._inner_program().labels - def set_labels(self, labels: Dict[str, str]) -> 'EngineProgram': + async def set_labels_async(self, labels: Dict[str, str]) -> 'EngineProgram': """Sets (overwriting) the labels for a previously created quantum program. @@ -412,12 +425,14 @@ def set_labels(self, labels: Dict[str, str]) -> 'EngineProgram': Returns: This EngineProgram. """ - self._program = self.context.client.set_program_labels( + self._program = await self.context.client.set_program_labels_async( self.project_id, self.program_id, labels ) return self - def add_labels(self, labels: Dict[str, str]) -> 'EngineProgram': + set_labels = duet.sync(set_labels_async) + + async def add_labels_async(self, labels: Dict[str, str]) -> 'EngineProgram': """Adds new labels to a previously created quantum program. Params: @@ -426,12 +441,14 @@ def add_labels(self, labels: Dict[str, str]) -> 'EngineProgram': Returns: This EngineProgram. """ - self._program = self.context.client.add_program_labels( + self._program = await self.context.client.add_program_labels_async( self.project_id, self.program_id, labels ) return self - def remove_labels(self, keys: List[str]) -> 'EngineProgram': + add_labels = duet.sync(add_labels_async) + + async def remove_labels_async(self, keys: List[str]) -> 'EngineProgram': """Removes labels with given keys from the labels of a previously created quantum program. @@ -441,12 +458,14 @@ def remove_labels(self, keys: List[str]) -> 'EngineProgram': Returns: This EngineProgram. """ - self._program = self.context.client.remove_program_labels( + self._program = await self.context.client.remove_program_labels_async( self.project_id, self.program_id, keys ) return self - def get_circuit(self, program_num: Optional[int] = None) -> cirq.Circuit: + remove_labels = duet.sync(remove_labels_async) + + async def get_circuit_async(self, program_num: Optional[int] = None) -> cirq.Circuit: """Returns the cirq Circuit for the Quantum Engine program. This is only supported if the program was created with the V2 protos. @@ -459,10 +478,14 @@ def get_circuit(self, program_num: Optional[int] = None) -> cirq.Circuit: The program's cirq Circuit. """ if self._program is None or self._program.code is None: - self._program = self.context.client.get_program(self.project_id, self.program_id, True) + self._program = await self.context.client.get_program_async( + self.project_id, self.program_id, True + ) return _deserialize_program(self._program.code, program_num) - def batch_size(self) -> int: + get_circuit = duet.sync(get_circuit_async) + + async def batch_size_async(self) -> int: """Returns the number of programs in a batch program. Raises: @@ -475,7 +498,9 @@ def batch_size(self) -> int: import cirq_google.engine.engine as engine_base if self._program is None or self._program.code is None: - self._program = self.context.client.get_program(self.project_id, self.program_id, True) + self._program = await self.context.client.get_program_async( + self.project_id, self.program_id, True + ) code = self._program.code code_type = code.type_url[len(engine_base.TYPE_PREFIX) :] if code_type == 'cirq.google.api.v2.BatchProgram': @@ -483,20 +508,26 @@ def batch_size(self) -> int: return len(batch.programs) raise ValueError(f'Program was not a batch program but instead was of type {code_type}.') - def delete(self, delete_jobs: bool = False) -> None: + batch_size = duet.sync(batch_size_async) + + async def delete_async(self, delete_jobs: bool = False) -> None: """Deletes a previously created quantum program. Params: delete_jobs: If True will delete all the program's jobs, other this will fail if the program contains any jobs. """ - self.context.client.delete_program( + await self.context.client.delete_program_async( self.project_id, self.program_id, delete_jobs=delete_jobs ) - def delete_job(self, job_id: str) -> None: + delete = duet.sync(delete_async) + + async def delete_job_async(self, job_id: str) -> None: """Deletes the job and result, if any.""" - self.context.client.delete_job(self.project_id, self.program_id, job_id) + await self.context.client.delete_job_async(self.project_id, self.program_id, job_id) + + delete_job = duet.sync(delete_job_async) def __str__(self) -> str: return f'EngineProgram(project_id=\'{self.project_id}\', program_id=\'{self.program_id}\')' diff --git a/cirq-google/cirq_google/engine/engine_program_test.py b/cirq-google/cirq_google/engine/engine_program_test.py index a232df279c3..8bc54ae8e6c 100644 --- a/cirq-google/cirq_google/engine/engine_program_test.py +++ b/cirq-google/cirq_google/engine/engine_program_test.py @@ -27,6 +27,7 @@ from cirq_google.cloud import quantum from cirq_google.engine.engine import EngineContext from cirq_google.engine.result_type import ResultType +from cirq_google.engine.test_utils import uses_async_mock _BATCH_PROGRAM_V2 = util.pack_any( @@ -163,9 +164,10 @@ ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_sweeps_delegation(create_job): - create_job.return_value = ('steve', quantum.QuantumJob()) +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_sweeps_delegation(create_job_async): + create_job_async.return_value = ('steve', quantum.QuantumJob()) program = cg.EngineProgram('my-proj', 'my-prog', EngineContext()) param_resolver = cirq.ParamResolver({}) job = program.run_sweep( @@ -174,9 +176,10 @@ def test_run_sweeps_delegation(create_job): assert job._job == quantum.QuantumJob() -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_batch_delegation(create_job): - create_job.return_value = ('kittens', quantum.QuantumJob()) +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_batch_delegation(create_job_async): + create_job_async.return_value = ('kittens', quantum.QuantumJob()) program = cg.EngineProgram('my-meow', 'my-meow', EngineContext(), result_type=ResultType.Batch) resolver_list = [cirq.Points('cats', [1.0, 2.0, 3.0]), cirq.Points('cats', [4.0, 5.0, 6.0])] job = program.run_batch( @@ -185,27 +188,29 @@ def test_run_batch_delegation(create_job): assert job._job == quantum.QuantumJob() -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_calibration_delegation(create_job): - create_job.return_value = ('dogs', quantum.QuantumJob()) +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_calibration_delegation(create_job_async): + create_job_async.return_value = ('dogs', quantum.QuantumJob()) program = cg.EngineProgram('woof', 'woof', EngineContext(), result_type=ResultType.Calibration) job = program.run_calibration(processor_ids=['lazydog']) assert job._job == quantum.QuantumJob() -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_calibration_no_processors(create_job): - create_job.return_value = ('dogs', quantum.QuantumJob()) +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_calibration_no_processors(create_job_async): + create_job_async.return_value = ('dogs', quantum.QuantumJob()) program = cg.EngineProgram('woof', 'woof', EngineContext(), result_type=ResultType.Calibration) with pytest.raises(ValueError, match='No processors specified'): _ = program.run_calibration(job_id='spot') -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_batch_no_sweeps(create_job): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_batch_no_sweeps(create_job_async): # Running with no sweeps is fine. Uses program's batch size to create # proper empty sweeps. - create_job.return_value = ('kittens', quantum.QuantumJob()) + create_job_async.return_value = ('kittens', quantum.QuantumJob()) program = cg.EngineProgram( 'my-meow', 'my-meow', @@ -216,7 +221,7 @@ def test_run_batch_no_sweeps(create_job): job = program.run_batch(job_id='steve', repetitions=10, processor_ids=['lazykitty']) assert job._job == quantum.QuantumJob() batch_run_context = v2.batch_pb2.BatchRunContext() - create_job.call_args[1]['run_context'].Unpack(batch_run_context) + create_job_async.call_args[1]['run_context'].Unpack(batch_run_context) assert len(batch_run_context.run_contexts) == 1 @@ -242,11 +247,12 @@ def test_run_in_batch_mode(): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_delegation(create_job, get_results): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_delegation(create_job_async, get_results_async): dt = datetime.datetime.now(tz=datetime.timezone.utc) - create_job.return_value = ( + create_job_async.return_value = ( 'steve', quantum.QuantumJob( name='projects/a/programs/b/jobs/steve', @@ -254,7 +260,7 @@ def test_run_delegation(create_job, get_results): update_time=dt, ), ) - get_results.return_value = quantum.QuantumResult( + get_results_async.return_value = quantum.QuantumResult( result=util.pack_any( Merge( """sweep_results: [{ @@ -297,15 +303,16 @@ def test_run_delegation(create_job, get_results): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs') -def test_list_jobs(list_jobs): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs_async') +def test_list_jobs(list_jobs_async): job1 = quantum.QuantumJob(name='projects/proj/programs/prog1/jobs/job1') job2 = quantum.QuantumJob(name='projects/otherproj/programs/prog1/jobs/job2') - list_jobs.return_value = [job1, job2] + list_jobs_async.return_value = [job1, job2] ctx = EngineContext() result = cg.EngineProgram(project_id='proj', program_id='prog1', context=ctx).list_jobs() - list_jobs.assert_called_once_with( + list_jobs_async.assert_called_once_with( 'proj', 'prog1', created_after=None, @@ -341,40 +348,43 @@ def test_create_time(): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_update_time(get_program): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_update_time(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram( + get_program_async.return_value = quantum.QuantumProgram( update_time=timestamp_pb2.Timestamp(seconds=1581515101) ) assert program.update_time() == datetime.datetime( 2020, 2, 12, 13, 45, 1, tzinfo=datetime.timezone.utc ) - get_program.assert_called_once_with('a', 'b', False) + get_program_async.assert_called_once_with('a', 'b', False) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_description(get_program): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_description(get_program_async): program = cg.EngineProgram( 'a', 'b', EngineContext(), _program=quantum.QuantumProgram(description='hello') ) assert program.description() == 'hello' - get_program.return_value = quantum.QuantumProgram(description='hello') + get_program_async.return_value = quantum.QuantumProgram(description='hello') assert cg.EngineProgram('a', 'b', EngineContext()).description() == 'hello' - get_program.assert_called_once_with('a', 'b', False) + get_program_async.assert_called_once_with('a', 'b', False) -@mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_description') -def test_set_description(set_program_description): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_description_async') +def test_set_description(set_program_description_async): program = cg.EngineProgram('a', 'b', EngineContext()) - set_program_description.return_value = quantum.QuantumProgram(description='world') + set_program_description_async.return_value = quantum.QuantumProgram(description='world') assert program.set_description('world').description() == 'world' - set_program_description.assert_called_with('a', 'b', 'world') + set_program_description_async.assert_called_with('a', 'b', 'world') - set_program_description.return_value = quantum.QuantumProgram(description='') + set_program_description_async.return_value = quantum.QuantumProgram(description='') assert program.set_description('').description() == '' - set_program_description.assert_called_with('a', 'b', '') + set_program_description_async.assert_called_with('a', 'b', '') def test_labels(): @@ -384,92 +394,101 @@ def test_labels(): assert program.labels() == {'t': '1'} -@mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_labels') -def test_set_labels(set_program_labels): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_labels_async') +def test_set_labels(set_program_labels_async): program = cg.EngineProgram('a', 'b', EngineContext()) - set_program_labels.return_value = quantum.QuantumProgram(labels={'a': '1', 'b': '1'}) + set_program_labels_async.return_value = quantum.QuantumProgram(labels={'a': '1', 'b': '1'}) assert program.set_labels({'a': '1', 'b': '1'}).labels() == {'a': '1', 'b': '1'} - set_program_labels.assert_called_with('a', 'b', {'a': '1', 'b': '1'}) + set_program_labels_async.assert_called_with('a', 'b', {'a': '1', 'b': '1'}) - set_program_labels.return_value = quantum.QuantumProgram() + set_program_labels_async.return_value = quantum.QuantumProgram() assert program.set_labels({}).labels() == {} - set_program_labels.assert_called_with('a', 'b', {}) + set_program_labels_async.assert_called_with('a', 'b', {}) -@mock.patch('cirq_google.engine.engine_client.EngineClient.add_program_labels') -def test_add_labels(add_program_labels): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.add_program_labels_async') +def test_add_labels(add_program_labels_async): program = cg.EngineProgram( 'a', 'b', EngineContext(), _program=quantum.QuantumProgram(labels={}) ) assert program.labels() == {} - add_program_labels.return_value = quantum.QuantumProgram(labels={'a': '1'}) + add_program_labels_async.return_value = quantum.QuantumProgram(labels={'a': '1'}) assert program.add_labels({'a': '1'}).labels() == {'a': '1'} - add_program_labels.assert_called_with('a', 'b', {'a': '1'}) + add_program_labels_async.assert_called_with('a', 'b', {'a': '1'}) - add_program_labels.return_value = quantum.QuantumProgram(labels={'a': '2', 'b': '1'}) + add_program_labels_async.return_value = quantum.QuantumProgram(labels={'a': '2', 'b': '1'}) assert program.add_labels({'a': '2', 'b': '1'}).labels() == {'a': '2', 'b': '1'} - add_program_labels.assert_called_with('a', 'b', {'a': '2', 'b': '1'}) + add_program_labels_async.assert_called_with('a', 'b', {'a': '2', 'b': '1'}) -@mock.patch('cirq_google.engine.engine_client.EngineClient.remove_program_labels') -def test_remove_labels(remove_program_labels): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.remove_program_labels_async') +def test_remove_labels(remove_program_labels_async): program = cg.EngineProgram( 'a', 'b', EngineContext(), _program=quantum.QuantumProgram(labels={'a': '1', 'b': '1'}) ) assert program.labels() == {'a': '1', 'b': '1'} - remove_program_labels.return_value = quantum.QuantumProgram(labels={'b': '1'}) + remove_program_labels_async.return_value = quantum.QuantumProgram(labels={'b': '1'}) assert program.remove_labels(['a']).labels() == {'b': '1'} - remove_program_labels.assert_called_with('a', 'b', ['a']) + remove_program_labels_async.assert_called_with('a', 'b', ['a']) - remove_program_labels.return_value = quantum.QuantumProgram(labels={}) + remove_program_labels_async.return_value = quantum.QuantumProgram(labels={}) assert program.remove_labels(['a', 'b', 'c']).labels() == {} - remove_program_labels.assert_called_with('a', 'b', ['a', 'b', 'c']) + remove_program_labels_async.assert_called_with('a', 'b', ['a', 'b', 'c']) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_v1(get_program): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_v1(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram(code=util.pack_any(v1.program_pb2.Program())) + get_program_async.return_value = quantum.QuantumProgram( + code=util.pack_any(v1.program_pb2.Program()) + ) with pytest.raises(ValueError, match='v1 Program is not supported'): program.get_circuit() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_v2(get_program): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_v2(get_program_async): circuit = cirq.Circuit( cirq.X(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result') ) program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) assert program.get_circuit() == circuit - get_program.assert_called_once_with('a', 'b', True) + get_program_async.assert_called_once_with('a', 'b', True) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_batch(get_program): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_batch(get_program_async): circuit = cirq.Circuit( cirq.X(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result') ) program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) with pytest.raises(ValueError, match='A program number must be specified'): program.get_circuit() with pytest.raises(ValueError, match='Only 1 in the batch but index 1 was specified'): program.get_circuit(1) assert program.get_circuit(0) == circuit - get_program.assert_called_once_with('a', 'b', True) + get_program_async.assert_called_once_with('a', 'b', True) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_batch_size(get_program): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_batch_size(get_program_async): # Has to fetch from engine if not _program specified. program = cg.EngineProgram('a', 'b', EngineContext(), result_type=ResultType.Batch) - get_program.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) assert program.batch_size() == 1 # If _program specified, uses that value. @@ -487,7 +506,7 @@ def test_get_batch_size(get_program): _ = program.batch_size() with pytest.raises(ValueError, match='cirq.google.api.v2.Program'): - get_program.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) program = cg.EngineProgram('a', 'b', EngineContext(), result_type=ResultType.Batch) _ = program.batch_size() @@ -500,10 +519,11 @@ def mock_grpc_client(): yield _fixture -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_v2_unknown_gateset(get_program): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_v2_unknown_gateset(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram( + get_program_async.return_value = quantum.QuantumProgram( code=util.pack_any( v2.program_pb2.Program(language=v2.program_pb2.Language(gate_set="BAD_GATESET")) ) @@ -513,10 +533,11 @@ def test_get_circuit_v2_unknown_gateset(get_program): program.get_circuit() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_unsupported_program_type(get_program): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_unsupported_program_type(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram( + get_program_async.return_value = quantum.QuantumProgram( code=any_pb2.Any(type_url='type.googleapis.com/unknown.proto') ) @@ -524,21 +545,23 @@ def test_get_circuit_unsupported_program_type(get_program): program.get_circuit() -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_program') -def test_delete(delete_program): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_program_async') +def test_delete(delete_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) program.delete() - delete_program.assert_called_with('a', 'b', delete_jobs=False) + delete_program_async.assert_called_with('a', 'b', delete_jobs=False) program.delete(delete_jobs=True) - delete_program.assert_called_with('a', 'b', delete_jobs=True) + delete_program_async.assert_called_with('a', 'b', delete_jobs=True) -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job') -def test_delete_jobs(delete_job): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job_async') +def test_delete_jobs(delete_job_async): program = cg.EngineProgram('a', 'b', EngineContext()) program.delete_job('c') - delete_job.assert_called_with('a', 'b', 'c') + delete_job_async.assert_called_with('a', 'b', 'c') def test_str(): diff --git a/cirq-google/cirq_google/engine/engine_test.py b/cirq-google/cirq_google/engine/engine_test.py index fac5ebf2495..d01e47f1db5 100644 --- a/cirq-google/cirq_google/engine/engine_test.py +++ b/cirq-google/cirq_google/engine/engine_test.py @@ -29,6 +29,7 @@ from cirq_google.engine import util from cirq_google.cloud import quantum from cirq_google.engine.engine import EngineContext +from cirq_google.engine.test_utils import uses_async_mock _CIRCUIT = cirq.Circuit( @@ -276,9 +277,9 @@ def test_make_random_id(): @pytest.fixture(scope='session', autouse=True) -def mock_grpc_client(): +def mock_grpc_client_async(): with mock.patch( - 'cirq_google.engine.engine_client.quantum.QuantumEngineServiceClient' + 'cirq_google.engine.engine_client.quantum.QuantumEngineServiceAsyncClient', autospec=True ) as _fixture: yield _fixture @@ -299,7 +300,7 @@ def test_create_context(client): assert context.copy() == context -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_create_engine(client): with pytest.raises( ValueError, match='provide context or proto_version, service_args and verbose' @@ -321,40 +322,38 @@ def test_create_engine(client): ).context.proto_version == cg.engine.engine.ProtoVersion.V2 ) - assert client.called_with({'args': 'test'}, True) + client.assert_called_with({'args': 'test'}, True) def test_engine_str(): engine = cg.Engine( - 'proj', - proto_version=cg.engine.engine.ProtoVersion.V2, - service_args={'args': 'test'}, - verbose=True, + 'proj', proto_version=cg.engine.engine.ProtoVersion.V2, service_args={}, verbose=True ) - assert str(engine) == 'Engine(project_id=\'proj\')' + assert str(engine) == "Engine(project_id='proj')" _DT = datetime.datetime.now(tz=datetime.timezone.utc) def setup_run_circuit_with_result_(client, result): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( execution_status={'state': 'SUCCESS'}, update_time=_DT ) - client().get_job_results.return_value = quantum.QuantumResult(result=result) + client().get_job_results_async.return_value = quantum.QuantumResult(result=result) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit(client): setup_run_circuit_with_result_(client, _A_RESULT) @@ -367,27 +366,22 @@ def test_run_circuit(client): assert result.params.param_dict == {'a': 1} assert result.measurements == {'q': np.array([[0]], dtype='uint8')} client.assert_called_with(service_args={'client_info': 1}, verbose=None) - client.create_program.called_once_with() - client.create_job.called_once_with( - 'projects/project-id/programs/test', - quantum.QuantumJob( - name='projects/project-id/programs/test/jobs/job-id', - scheduling_config={ - 'priority': 50, - 'processor_selector': {'processor_names': ['projects/project-id/processors/mysim']}, - }, - run_context=util.pack_any( - v2.run_context_pb2.RunContext( - parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=1)] - ) - ), - update_time=_DT, + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once_with( + project_id='proj', + program_id='prog', + job_id='job-id', + processor_ids=['mysim'], + run_context=util.pack_any( + v2.run_context_pb2.RunContext( + parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=1)] + ) ), - False, + description=None, + labels=None, ) - - client.get_job.called_once_with('proj', 'prog') - client.get_job_result.called_once_with() + client().get_job_async.assert_called_once_with('proj', 'prog', 'job-id', False) + client().get_job_results_async.assert_called_once_with('proj', 'prog', 'job-id') def test_no_gate_set(): @@ -401,19 +395,20 @@ def test_unsupported_program_type(): engine.run(program="this isn't even the right type of thing!") -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_failed(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={ 'state': 'FAILURE', @@ -431,19 +426,20 @@ def test_run_circuit_failed(client): engine.run(program=_CIRCUIT) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_failed_missing_processor_name(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={ 'state': 'FAILURE', @@ -460,19 +456,20 @@ def test_run_circuit_failed_missing_processor_name(client): engine.run(program=_CIRCUIT) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_cancelled(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'CANCELLED'} ) @@ -483,29 +480,30 @@ def test_run_circuit_cancelled(client): engine.run(program=_CIRCUIT) -@mock.patch('cirq_google.engine.engine_client.EngineClient') -@mock.patch('time.sleep', return_value=None) -def test_run_circuit_timeout(patched_time_sleep, client): - client().create_program.return_value = ( +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) +def test_run_circuit_timeout(client): + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'RUNNING'} ) - engine = cg.Engine(project_id='project-id', timeout=600) - with pytest.raises(RuntimeError, match='Timed out'): + engine = cg.Engine(project_id='project-id', timeout=1) + with pytest.raises(TimeoutError): engine.run(program=_CIRCUIT) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_params(client): setup_run_circuit_with_result_(client, _RESULTS) @@ -520,21 +518,22 @@ def test_run_sweep_params(client): assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.run_context_pb2.RunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 2 for i, v in enumerate([1.0, 2.0]): assert sweeps[i].repetitions == 1 assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_multiple_times(client): setup_run_circuit_with_result_(client, _RESULTS) @@ -542,10 +541,10 @@ def test_run_multiple_times(client): program = engine.create_program(program=_CIRCUIT) program.run(param_resolver=cirq.ParamResolver({'a': 1})) run_context = v2.run_context_pb2.RunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps1 = run_context.parameter_sweeps job2 = program.run_sweep(repetitions=2, params=cirq.Points('a', [3, 4])) - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps2 = run_context.parameter_sweeps results = job2.results() assert engine.context.proto_version == cg.engine.engine.ProtoVersion.V2 @@ -561,11 +560,12 @@ def test_run_multiple_times(client): assert len(sweeps2) == 1 assert sweeps2[0].repetitions == 2 assert sweeps2[0].sweep.single_sweep.points.points == [3, 4] - assert client().get_job.call_count == 2 - assert client().get_job_results.call_count == 2 + assert client().get_job_async.call_count == 2 + assert client().get_job_results_async.call_count == 2 -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_v2(client): setup_run_circuit_with_result_(client, _RESULTS_V2) @@ -577,19 +577,20 @@ def test_run_sweep_v2(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.run_context_pb2.RunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 1 assert sweeps[0].repetitions == 1 assert sweeps[0].sweep.single_sweep.points.points == [1, 2] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch(client): setup_run_circuit_with_result_(client, _BATCH_RESULTS_V2) @@ -606,10 +607,10 @@ def test_run_batch(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.batch_pb2.BatchRunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) assert len(run_context.run_contexts) == 2 for idx, rc in enumerate(run_context.run_contexts): sweeps = rc.parameter_sweeps @@ -619,11 +620,12 @@ def test_run_batch(client): assert sweeps[0].sweep.single_sweep.points.points == [1.0, 2.0] if idx == 1: assert sweeps[0].sweep.single_sweep.points.points == [3.0, 4.0] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch_no_params(client): # OK to run with no params, it should use empty sweeps for each # circuit. @@ -633,7 +635,7 @@ def test_run_batch_no_params(client): # Validate correct number of params have been created and that they # are empty sweeps. run_context = v2.batch_pb2.BatchRunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) assert len(run_context.run_contexts) == 2 for rc in run_context.run_contexts: sweeps = rc.parameter_sweeps @@ -672,7 +674,8 @@ def test_bad_sweep_proto(): program.run_sweep() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_calibration(client): setup_run_circuit_with_result_(client, _CALIBRATION_RESULTS_V2) @@ -696,7 +699,7 @@ def test_run_calibration(client): assert results[1].error_message == 'Second success' # assert label is correct - client().create_job.assert_called_once_with( + client().create_job_async.assert_called_once_with( project_id='proj', program_id='prog', job_id='job-id', @@ -725,7 +728,8 @@ def test_run_calibration_validation_fails(): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_bad_result_proto(client): result = any_pb2.Any() result.CopyFrom(_RESULTS_V2) @@ -752,14 +756,15 @@ def test_get_program(): assert cg.Engine(project_id='proj').get_program('prog').program_id == 'prog' -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_programs') -def test_list_programs(list_programs): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_programs_async') +def test_list_programs(list_programs_async): prog1 = quantum.QuantumProgram(name='projects/proj/programs/prog-YBGR48THF3JHERZW200804') prog2 = quantum.QuantumProgram(name='projects/otherproj/programs/prog-V3ZRTV6TTAFNTYJV200804') - list_programs.return_value = [prog1, prog2] + list_programs_async.return_value = [prog1, prog2] result = cg.Engine(project_id='proj').list_programs() - list_programs.assert_called_once_with( + list_programs_async.assert_called_once_with( 'proj', created_after=None, created_before=None, has_labels=None ) assert [(p.program_id, p.project_id, p._program) for p in result] == [ @@ -768,23 +773,25 @@ def test_list_programs(list_programs): ] -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_create_program(client): - client().create_program.return_value = ('prog', quantum.QuantumProgram()) + client().create_program_async.return_value = ('prog', quantum.QuantumProgram()) result = cg.Engine(project_id='proj').create_program(_CIRCUIT, 'prog') - client().create_program.assert_called_once() + client().create_program_async.assert_called_once() assert result.program_id == 'prog' -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs') -def test_list_jobs(list_jobs): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs_async') +def test_list_jobs(list_jobs_async): job1 = quantum.QuantumJob(name='projects/proj/programs/prog1/jobs/job1') job2 = quantum.QuantumJob(name='projects/proj/programs/prog2/jobs/job2') - list_jobs.return_value = [job1, job2] + list_jobs_async.return_value = [job1, job2] ctx = EngineContext() result = cg.Engine(project_id='proj', context=ctx).list_jobs() - list_jobs.assert_called_once_with( + list_jobs_async.assert_called_once_with( 'proj', None, created_after=None, @@ -798,14 +805,15 @@ def test_list_jobs(list_jobs): ] -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_processors') -def test_list_processors(list_processors): +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_processors_async') +def test_list_processors(list_processors_async): processor1 = quantum.QuantumProcessor(name='projects/proj/processors/xmonsim') processor2 = quantum.QuantumProcessor(name='projects/proj/processors/gmonsim') - list_processors.return_value = [processor1, processor2] + list_processors_async.return_value = [processor1, processor2] result = cg.Engine(project_id='proj').list_processors() - list_processors.assert_called_once_with('proj') + list_processors_async.assert_called_once_with('proj') assert [p.processor_id for p in result] == ['xmonsim', 'gmonsim'] @@ -813,7 +821,8 @@ def test_get_processor(): assert cg.Engine(project_id='proj').get_processor('xmonsim').processor_id == 'xmonsim' -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@uses_async_mock +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_sampler(client): setup_run_circuit_with_result_(client, _RESULTS) @@ -827,7 +836,7 @@ def test_sampler(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - assert client().create_program.call_args[0][0] == 'proj' + assert client().create_program_async.call_args[0][0] == 'proj' with cirq.testing.assert_deprecated('sampler', deadline='1.0'): _ = engine.sampler(processor_id='tmp') @@ -839,12 +848,12 @@ def test_sampler(client): @mock.patch('cirq_google.cloud.quantum.QuantumEngineServiceClient') def test_get_engine(build): # Default project id present. - with mock.patch('google.auth.default', lambda: (None, 'project!')): + with mock.patch('google.auth.default', lambda *args, **kwargs: (None, 'project!')): eng = cirq_google.get_engine() assert eng.project_id == 'project!' # Nothing present. - with mock.patch('google.auth.default', lambda: (None, None)): + with mock.patch('google.auth.default', lambda *args, **kwargs: (None, None)): with pytest.raises(EnvironmentError, match='GOOGLE_CLOUD_PROJECT'): _ = cirq_google.get_engine() _ = cirq_google.get_engine('project!')