From d0e9c16e8ea95cc92e14f4a6bda743a572181b0a Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Wed, 16 Feb 2022 00:03:24 -0800 Subject: [PATCH 1/9] wip - Add async support in EngineSampler --- cirq-core/requirements.txt | 2 +- cirq-google/cirq_google/engine/engine.py | 83 +++-- .../cirq_google/engine/engine_client.py | 306 +++++++++++++----- .../cirq_google/engine/engine_client_test.py | 112 ++++--- cirq-google/cirq_google/engine/engine_job.py | 73 +++-- .../cirq_google/engine/engine_job_test.py | 53 ++- .../engine/engine_processor_test.py | 118 +++---- .../cirq_google/engine/engine_program.py | 103 +++--- .../cirq_google/engine/engine_program_test.py | 172 +++++----- .../cirq_google/engine/engine_sampler.py | 20 +- .../cirq_google/engine/engine_sampler_test.py | 27 +- cirq-google/cirq_google/engine/engine_test.py | 191 ++++++----- 12 files changed, 737 insertions(+), 523 deletions(-) diff --git a/cirq-core/requirements.txt b/cirq-core/requirements.txt index 33166c7369b..9db8801eef1 100644 --- a/cirq-core/requirements.txt +++ b/cirq-core/requirements.txt @@ -3,7 +3,7 @@ # functools.cached_property was introduced in python 3.8 backports.cached_property~=1.0.1; python_version < '3.8' -duet~=0.2.0 +duet~=0.2.6 matplotlib~=3.0 networkx~=2.4 numpy~=1.16 diff --git a/cirq-google/cirq_google/engine/engine.py b/cirq-google/cirq_google/engine/engine.py index edf86ea33cd..f859bf1be16 100644 --- a/cirq-google/cirq_google/engine/engine.py +++ b/cirq-google/cirq_google/engine/engine.py @@ -29,6 +29,7 @@ import string from typing import Dict, Iterable, List, Optional, Sequence, Set, TypeVar, Union, TYPE_CHECKING +import duet import google.auth from google.protobuf import any_pb2 @@ -209,7 +210,7 @@ def __str__(self) -> str: return f'Engine(project_id={self.project_id!r})' @util.deprecated_gate_set_parameter - def run( + async def run_async( self, program: cirq.AbstractCircuit, program_id: Optional[str] = None, @@ -255,23 +256,25 @@ def run( Raises: ValueError: If no gate set is provided. """ - return list( - self.run_sweep( - program=program, - program_id=program_id, - job_id=job_id, - params=[param_resolver], - repetitions=repetitions, - processor_ids=processor_ids, - program_description=program_description, - program_labels=program_labels, - job_description=job_description, - job_labels=job_labels, - ) - )[0] + job = await self.run_sweep_async( + program=program, + program_id=program_id, + job_id=job_id, + params=[param_resolver], + repetitions=repetitions, + processor_ids=processor_ids, + program_description=program_description, + program_labels=program_labels, + job_description=job_description, + job_labels=job_labels, + ) + results = await job.results_async() + return results[0] + + run = duet.sync(run_async) @util.deprecated_gate_set_parameter - def run_sweep( + async def run_sweep_async( self, program: cirq.AbstractCircuit, program_id: Optional[str] = None, @@ -321,10 +324,10 @@ def run_sweep( Raises: ValueError: If no gate set is provided. """ - engine_program = self.create_program( + engine_program = await self.create_program_async( program, program_id, description=program_description, labels=program_labels ) - return engine_program.run_sweep( + return await engine_program.run_sweep_async( job_id=job_id, params=params, repetitions=repetitions, @@ -333,8 +336,10 @@ def run_sweep( labels=job_labels, ) + run_sweep = duet.sync(run_sweep_async) + @util.deprecated_gate_set_parameter - def run_batch( + async def run_batch_async( self, programs: Sequence[cirq.AbstractCircuit], program_id: Optional[str] = None, @@ -406,7 +411,7 @@ def run_batch( engine_program = self.create_batch_program( programs, program_id, description=program_description, labels=program_labels ) - return engine_program.run_batch( + return await engine_program.run_batch_async( job_id=job_id, params_list=params_list, repetitions=repetitions, @@ -415,6 +420,8 @@ def run_batch( labels=job_labels, ) + run_batch = duet.sync(run_batch_async) + @util.deprecated_gate_set_parameter def run_calibration( self, @@ -494,7 +501,7 @@ def run_calibration( ) @util.deprecated_gate_set_parameter - def create_program( + async def create_program_async( self, program: cirq.AbstractCircuit, program_id: Optional[str] = None, @@ -525,7 +532,7 @@ def create_program( if not program_id: program_id = _make_random_id('prog-') - new_program_id, new_program = self.context.client.create_program( + new_program_id, new_program = await self.context.client.create_program_async( self.project_id, program_id, code=self.context._serialize_program(program, gate_set), @@ -537,8 +544,10 @@ def create_program( self.project_id, new_program_id, self.context, new_program ) + create_program = duet.sync(create_program_async) + @util.deprecated_gate_set_parameter - def create_batch_program( + async def create_batch_program_async( self, programs: Sequence[cirq.AbstractCircuit], program_id: Optional[str] = None, @@ -575,7 +584,7 @@ def create_batch_program( for program in programs: gate_set.serialize(program, msg=batch.programs.add()) - new_program_id, new_program = self.context.client.create_program( + new_program_id, new_program = await self.context.client.create_program_async( self.project_id, program_id, code=util.pack_any(batch), @@ -587,8 +596,10 @@ def create_batch_program( self.project_id, new_program_id, self.context, new_program, result_type=ResultType.Batch ) + create_batch_program = duet.sync(create_batch_program_async) + @util.deprecated_gate_set_parameter - def create_calibration_program( + async def create_calibration_program_async( self, layers: List['cirq_google.CalibrationLayer'], program_id: Optional[str] = None, @@ -633,7 +644,7 @@ def create_calibration_program( arg_to_proto(layer.args[arg], out=new_layer.args[arg]) gate_set.serialize(layer.program, msg=new_layer.layer) - new_program_id, new_program = self.context.client.create_program( + new_program_id, new_program = await self.context.client.create_program_async( self.project_id, program_id, code=util.pack_any(calibration), @@ -649,6 +660,8 @@ def create_calibration_program( result_type=ResultType.Calibration, ) + create_calibration_program = duet.sync(create_calibration_program_async) + def get_program(self, program_id: str) -> engine_program.EngineProgram: """Returns an EngineProgram for an existing Quantum Engine program. @@ -660,7 +673,7 @@ def get_program(self, program_id: str) -> engine_program.EngineProgram: """ return engine_program.EngineProgram(self.project_id, program_id, self.context) - def list_programs( + async def list_programs_async( self, created_before: Optional[Union[datetime.datetime, datetime.date]] = None, created_after: Optional[Union[datetime.datetime, datetime.date]] = None, @@ -682,7 +695,7 @@ def list_programs( """ client = self.context.client - response = client.list_programs( + response = await client.list_programs_async( self.project_id, created_before=created_before, created_after=created_after, @@ -698,7 +711,9 @@ def list_programs( for p in response ] - def list_jobs( + list_programs = duet.sync(list_programs_async) + + async def list_jobs_async( self, created_before: Optional[Union[datetime.datetime, datetime.date]] = None, created_after: Optional[Union[datetime.datetime, datetime.date]] = None, @@ -731,7 +746,7 @@ def list_jobs( `quantum.ExecutionStatus.State` enum for accepted values. """ client = self.context.client - response = client.list_jobs( + response = await client.list_jobs_async( self.project_id, None, created_before=created_before, @@ -750,7 +765,9 @@ def list_jobs( for j in response ] - def list_processors(self) -> List[engine_processor.EngineProcessor]: + list_jobs = duet.sync(list_jobs_async) + + async def list_processors_async(self) -> List[engine_processor.EngineProcessor]: """Returns a list of Processors that the user has visibility to in the current Engine project. The names of these processors are used to identify devices when scheduling jobs and gathering calibration metrics. @@ -759,7 +776,7 @@ def list_processors(self) -> List[engine_processor.EngineProcessor]: A list of EngineProcessors to access status, device and calibration information. """ - response = self.context.client.list_processors(self.project_id) + response = await self.context.client.list_processors_async(self.project_id) return [ engine_processor.EngineProcessor( self.project_id, engine_client._ids_from_processor_name(p.name)[1], self.context, p @@ -767,6 +784,8 @@ def list_processors(self) -> List[engine_processor.EngineProcessor]: for p in response ] + list_processors = duet.sync(list_processors_async) + def get_processor(self, processor_id: str) -> engine_processor.EngineProcessor: """Returns an EngineProcessor for a Quantum Engine processor. diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 2a60534809c..7c002faa99f 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -12,17 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import datetime import sys -import time -from typing import Callable, Dict, List, Optional, Sequence, Set, TypeVar, Tuple, Union +import threading +from typing import ( + AsyncIterable, + Awaitable, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + TypeVar, + Tuple, + Union, +) import warnings +import duet import proto from google.api_core.exceptions import GoogleAPICallError, NotFound from google.protobuf import any_pb2, field_mask_pb2 from google.protobuf.timestamp_pb2 import Timestamp +from cirq._compat import cached_property from cirq_google.cloud import quantum _M = TypeVar('_M', bound=proto.Message) @@ -38,6 +53,26 @@ def __init__(self, message): RETRYABLE_ERROR_CODES = [500, 503] +class AsyncioExecutor: + def __init__(self) -> None: + loop_future: duet.AwaitableFuture[asyncio.AbstractEventLoop] = duet.AwaitableFuture() + thread = threading.Thread(target=asyncio.run, args=(self._main(loop_future),), daemon=True) + thread.start() + self.loop = loop_future.result() + + @staticmethod + async def _main(loop_future: duet.AwaitableFuture) -> None: + loop = asyncio.get_running_loop() + loop_future.set_result(loop) + while True: + await asyncio.sleep(1) + + def submit(self, func, *args, **kw) -> duet.AwaitableFuture: + """Dispatch the given function to be run in a duet coroutine.""" + future = asyncio.run_coroutine_threadsafe(func(*args, **kw), self.loop) + return duet.AwaitableFuture.wrap(future) + + class EngineClient: """Client for the Quantum Engine API that deals with the engine protos and the gRPC client but not cirq protos or objects. All users are likely better @@ -69,22 +104,41 @@ def __init__( if not service_args: service_args = {} - # Suppress warnings about using Application Default Credentials. - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - self.grpc_client = quantum.QuantumEngineServiceClient(**service_args) + self._service_args = service_args + self._executor = AsyncioExecutor() + + @cached_property + def grpc_client(self) -> quantum.QuantumEngineServiceAsyncClient: + """Creates an async grpc client for the quantum engine service.""" + + async def make_client(): + # Suppress warnings about using Application Default Credentials. + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + return quantum.QuantumEngineServiceAsyncClient(**self._service_args) + + return self._executor.submit(make_client).result() - def _make_request(self, func: Callable[[_M], _R], request: _M) -> _R: - return self._run_retry(lambda: func(request)) + async def _make_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R: + return await self._run_retry_async(lambda: func(request)) - def _run_retry(self, func: Callable[[], _R]) -> _R: + async def _make_list_request_async( + self, func: Callable[[_M], Awaitable[AsyncIterable[_R]]], request: _M + ) -> _R: + async def new_func(): + pager = await func(request) + return [item async for item in pager] + + return await self._run_retry_async(new_func) + + async def _run_retry_async(self, func: Callable[[], Awaitable[_R]]) -> _R: # Start with a 100ms retry delay with exponential backoff to # max_retry_delay_seconds current_delay = 0.1 while True: try: - return func() + return await self._executor.submit(func) except GoogleAPICallError as err: message = err.message # Raise RuntimeError for exceptions that are not retryable. @@ -97,10 +151,10 @@ def _run_retry(self, func: Callable[[], _R]) -> _R: if self.verbose: print(message, file=sys.stderr) print(f'Waiting {current_delay} seconds before retrying.', file=sys.stderr) - time.sleep(current_delay) + await duet.sleep(current_delay) current_delay *= 2 - def create_program( + async def create_program_async( self, project_id: str, program_id: Optional[str], @@ -132,10 +186,12 @@ def create_program( request = quantum.CreateQuantumProgramRequest( parent=parent_name, quantum_program=program, overwrite_existing_source_code=False ) - program = self._make_request(self.grpc_client.create_quantum_program, request) + program = await self._make_request_async(self.grpc_client.create_quantum_program, request) return _ids_from_program_name(program.name)[1], program - def get_program( + create_program = duet.sync(create_program_async) + + async def get_program_async( self, project_id: str, program_id: str, return_code: bool ) -> quantum.QuantumProgram: """Returns a previously created quantum program. @@ -148,9 +204,11 @@ def get_program( request = quantum.GetQuantumProgramRequest( name=_program_name_from_ids(project_id, program_id), return_code=return_code ) - return self._make_request(self.grpc_client.get_quantum_program, request) + return await self._make_request_async(self.grpc_client.get_quantum_program, request) + + get_program = duet.sync(get_program_async) - def list_programs( + async def list_programs_async( self, project_id: str, created_before: Optional[Union[datetime.datetime, datetime.date]] = None, @@ -187,9 +245,11 @@ def list_programs( request = quantum.ListQuantumProgramsRequest( parent=_project_name(project_id), filter=" AND ".join(filters) ) - return self._make_request(self.grpc_client.list_quantum_programs, request) + return await self._make_request_async(self.grpc_client.list_quantum_programs, request) - def set_program_description( + list_programs = duet.sync(list_programs_async) + + async def set_program_description_async( self, project_id: str, program_id: str, description: str ) -> quantum.QuantumProgram: """Sets the description for a previously created quantum program. @@ -210,9 +270,11 @@ def set_program_description( ), update_mask=field_mask_pb2.FieldMask(paths=['description']), ) - return self._make_request(self.grpc_client.update_quantum_program, request) + return await self._make_request_async(self.grpc_client.update_quantum_program, request) + + set_program_description = duet.sync(set_program_description_async) - def _set_program_labels( + async def _set_program_labels_async( self, project_id: str, program_id: str, labels: Dict[str, str], fingerprint: str ) -> quantum.QuantumProgram: program_resource_name = _program_name_from_ids(project_id, program_id) @@ -223,9 +285,9 @@ def _set_program_labels( ), update_mask=field_mask_pb2.FieldMask(paths=['labels']), ) - return self._make_request(self.grpc_client.update_quantum_program, request) + return await self._make_request_async(self.grpc_client.update_quantum_program, request) - def set_program_labels( + async def set_program_labels_async( self, project_id: str, program_id: str, labels: Dict[str, str] ) -> quantum.QuantumProgram: """Sets (overwriting) the labels for a previously created quantum @@ -240,9 +302,13 @@ def set_program_labels( The updated quantum program. """ program = self.get_program(project_id, program_id, False) - return self._set_program_labels(project_id, program_id, labels, program.label_fingerprint) + return await self._set_program_labels_async( + project_id, program_id, labels, program.label_fingerprint + ) - def add_program_labels( + set_program_labels = duet.sync(set_program_labels_async) + + async def add_program_labels_async( self, project_id: str, program_id: str, labels: Dict[str, str] ) -> quantum.QuantumProgram: """Adds new labels to a previously created quantum program. @@ -255,16 +321,20 @@ def add_program_labels( Returns: The updated quantum program. """ - program = self.get_program(project_id, program_id, False) + program = await self.get_program_async(project_id, program_id, False) old_labels = program.labels new_labels = dict(old_labels) new_labels.update(labels) if new_labels != old_labels: fingerprint = program.label_fingerprint - return self._set_program_labels(project_id, program_id, new_labels, fingerprint) + return await self._set_program_labels_async( + project_id, program_id, new_labels, fingerprint + ) return program - def remove_program_labels( + add_program_labels = duet.sync(add_program_labels_async) + + async def remove_program_labels_async( self, project_id: str, program_id: str, label_keys: List[str] ) -> quantum.QuantumProgram: """Removes labels with given keys from the labels of a previously @@ -278,17 +348,23 @@ def remove_program_labels( Returns: The updated quantum program. """ - program = self.get_program(project_id, program_id, False) + program = await self.get_program_async(project_id, program_id, False) old_labels = program.labels new_labels = dict(old_labels) for key in label_keys: new_labels.pop(key, None) if new_labels != old_labels: fingerprint = program.label_fingerprint - return self._set_program_labels(project_id, program_id, new_labels, fingerprint) + return await self._set_program_labels_async( + project_id, program_id, new_labels, fingerprint + ) return program - def delete_program(self, project_id: str, program_id: str, delete_jobs: bool = False) -> None: + remove_program_labels = duet.sync(remove_program_labels_async) + + async def delete_program_async( + self, project_id: str, program_id: str, delete_jobs: bool = False + ) -> None: """Deletes a previously created quantum program. Args: @@ -300,9 +376,11 @@ def delete_program(self, project_id: str, program_id: str, delete_jobs: bool = F request = quantum.DeleteQuantumProgramRequest( name=_program_name_from_ids(project_id, program_id), delete_jobs=delete_jobs ) - self._make_request(self.grpc_client.delete_quantum_program, request) + await self._make_request_async(self.grpc_client.delete_quantum_program, request) + + delete_program = duet.sync(delete_program_async) - def create_job( + async def create_job_async( self, project_id: str, program_id: str, @@ -360,10 +438,12 @@ def create_job( quantum_job=job, overwrite_existing_run_context=False, ) - job = self._make_request(self.grpc_client.create_quantum_job, request) + job = await self._make_request_async(self.grpc_client.create_quantum_job, request) return _ids_from_job_name(job.name)[2], job - def list_jobs( + create_job = duet.sync(create_job_async) + + async def list_jobs_async( self, project_id: str, program_id: Optional[str] = None, @@ -433,9 +513,11 @@ def list_jobs( program_id = "-" parent = _program_name_from_ids(project_id, program_id) request = quantum.ListQuantumJobsRequest(parent=parent, filter=" AND ".join(filters)) - return self._make_request(self.grpc_client.list_quantum_jobs, request) + return await self._make_request_async(self.grpc_client.list_quantum_jobs, request) + + list_jobs = duet.sync(list_jobs_async) - def get_job( + async def get_job_async( self, project_id: str, program_id: str, job_id: str, return_run_context: bool ) -> quantum.QuantumJob: """Returns a previously created job. @@ -452,9 +534,11 @@ def get_job( name=_job_name_from_ids(project_id, program_id, job_id), return_run_context=return_run_context, ) - return self._make_request(self.grpc_client.get_quantum_job, request) + return await self._make_request_async(self.grpc_client.get_quantum_job, request) - def set_job_description( + get_job = duet.sync(get_job_async) + + async def set_job_description_async( self, project_id: str, program_id: str, job_id: str, description: str ) -> quantum.QuantumJob: """Sets the description for a previously created quantum job. @@ -474,9 +558,11 @@ def set_job_description( quantum_job=quantum.QuantumJob(name=job_resource_name, description=description), update_mask=field_mask_pb2.FieldMask(paths=['description']), ) - return self._make_request(self.grpc_client.update_quantum_job, request) + return await self._make_request_async(self.grpc_client.update_quantum_job, request) + + set_job_description = duet.sync(set_job_description_async) - def _set_job_labels( + async def _set_job_labels_async( self, project_id: str, program_id: str, @@ -492,9 +578,9 @@ def _set_job_labels( ), update_mask=field_mask_pb2.FieldMask(paths=['labels']), ) - return self._make_request(self.grpc_client.update_quantum_job, request) + return await self._make_request_async(self.grpc_client.update_quantum_job, request) - def set_job_labels( + async def set_job_labels_async( self, project_id: str, program_id: str, job_id: str, labels: Dict[str, str] ) -> quantum.QuantumJob: """Sets (overwriting) the labels for a previously created quantum job. @@ -508,10 +594,14 @@ def set_job_labels( Returns: The updated quantum job. """ - job = self.get_job(project_id, program_id, job_id, False) - return self._set_job_labels(project_id, program_id, job_id, labels, job.label_fingerprint) + job = await self.get_job_async(project_id, program_id, job_id, False) + return await self._set_job_labels_async( + project_id, program_id, job_id, labels, job.label_fingerprint + ) - def add_job_labels( + set_job_labels = duet.sync(set_job_labels_async) + + async def add_job_labels_async( self, project_id: str, program_id: str, job_id: str, labels: Dict[str, str] ) -> quantum.QuantumJob: """Adds new labels to a previously created quantum job. @@ -525,16 +615,20 @@ def add_job_labels( Returns: The updated quantum job. """ - job = self.get_job(project_id, program_id, job_id, False) + job = await self.get_job_async(project_id, program_id, job_id, False) old_labels = job.labels new_labels = dict(old_labels) new_labels.update(labels) if new_labels != old_labels: fingerprint = job.label_fingerprint - return self._set_job_labels(project_id, program_id, job_id, new_labels, fingerprint) + return await self._set_job_labels_async( + project_id, program_id, job_id, new_labels, fingerprint + ) return job - def remove_job_labels( + add_job_labels = duet.sync(add_job_labels_async) + + async def remove_job_labels_async( self, project_id: str, program_id: str, job_id: str, label_keys: List[str] ) -> quantum.QuantumJob: """Removes labels with given keys from the labels of a previously @@ -549,17 +643,21 @@ def remove_job_labels( Returns: The updated quantum job. """ - job = self.get_job(project_id, program_id, job_id, False) + job = await self.get_job_async(project_id, program_id, job_id, False) old_labels = job.labels new_labels = dict(old_labels) for key in label_keys: new_labels.pop(key, None) if new_labels != old_labels: fingerprint = job.label_fingerprint - return self._set_job_labels(project_id, program_id, job_id, new_labels, fingerprint) + return await self._set_job_labels_async( + project_id, program_id, job_id, new_labels, fingerprint + ) return job - def delete_job(self, project_id: str, program_id: str, job_id: str) -> None: + remove_job_labels = duet.sync(remove_job_labels_async) + + async def delete_job_async(self, project_id: str, program_id: str, job_id: str) -> None: """Deletes a previously created quantum job. Args: @@ -570,9 +668,11 @@ def delete_job(self, project_id: str, program_id: str, job_id: str) -> None: request = quantum.DeleteQuantumJobRequest( name=_job_name_from_ids(project_id, program_id, job_id) ) - self._make_request(self.grpc_client.delete_quantum_job, request) + await self._make_request_async(self.grpc_client.delete_quantum_job, request) + + delete_job = duet.sync(delete_job_async) - def cancel_job(self, project_id: str, program_id: str, job_id: str) -> None: + async def cancel_job_async(self, project_id: str, program_id: str, job_id: str) -> None: """Cancels the given job. Args: @@ -583,9 +683,11 @@ def cancel_job(self, project_id: str, program_id: str, job_id: str) -> None: request = quantum.CancelQuantumJobRequest( name=_job_name_from_ids(project_id, program_id, job_id) ) - self._make_request(self.grpc_client.cancel_quantum_job, request) + await self._make_request_async(self.grpc_client.cancel_quantum_job, request) - def get_job_results( + cancel_job = duet.sync(cancel_job_async) + + async def get_job_results_async( self, project_id: str, program_id: str, job_id: str ) -> quantum.QuantumResult: """Returns the results of a completed job. @@ -601,9 +703,11 @@ def get_job_results( request = quantum.GetQuantumResultRequest( parent=_job_name_from_ids(project_id, program_id, job_id) ) - return self._make_request(self.grpc_client.get_quantum_result, request) + return await self._make_request_async(self.grpc_client.get_quantum_result, request) + + get_job_results = duet.sync(get_job_results_async) - def list_processors(self, project_id: str) -> List[quantum.QuantumProcessor]: + async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]: """Returns a list of Processors that the user has visibility to in the current Engine project. The names of these processors are used to identify devices when scheduling jobs and gathering calibration metrics. @@ -615,10 +719,15 @@ def list_processors(self, project_id: str) -> List[quantum.QuantumProcessor]: A list of metadata of each processor. """ request = quantum.ListQuantumProcessorsRequest(parent=_project_name(project_id), filter='') - response = self._make_request(self.grpc_client.list_quantum_processors, request) - return list(response) + return await self._make_list_request_async( + self.grpc_client.list_quantum_processors, request + ) + + list_processors = duet.sync(list_processors_async) - def get_processor(self, project_id: str, processor_id: str) -> quantum.QuantumProcessor: + async def get_processor_async( + self, project_id: str, processor_id: str + ) -> quantum.QuantumProcessor: """Returns a quantum processor. Args: @@ -631,9 +740,11 @@ def get_processor(self, project_id: str, processor_id: str) -> quantum.QuantumPr request = quantum.GetQuantumProcessorRequest( name=_processor_name_from_ids(project_id, processor_id) ) - return self._make_request(self.grpc_client.get_quantum_processor, request) + return await self._make_request_async(self.grpc_client.get_quantum_processor, request) + + get_processor = duet.sync(get_processor_async) - def list_calibrations( + async def list_calibrations_async( self, project_id: str, processor_id: str, filter_str: str = '' ) -> List[quantum.QuantumCalibration]: """Returns a list of quantum calibrations. @@ -652,10 +763,13 @@ def list_calibrations( request = quantum.ListQuantumCalibrationsRequest( parent=_processor_name_from_ids(project_id, processor_id), filter=filter_str ) - response = self._make_request(self.grpc_client.list_quantum_calibrations, request) - return list(response) + return await self._make_list_request_async( + self.grpc_client.list_quantum_calibrations, request + ) + + list_calibrations = duet.sync(list_calibrations_async) - def get_calibration( + async def get_calibration_async( self, project_id: str, processor_id: str, calibration_timestamp_seconds: int ) -> quantum.QuantumCalibration: """Returns a quantum calibration. @@ -672,9 +786,11 @@ def get_calibration( request = quantum.GetQuantumCalibrationRequest( name=_calibration_name_from_ids(project_id, processor_id, calibration_timestamp_seconds) ) - return self._make_request(self.grpc_client.get_quantum_calibration, request) + return await self._make_request_async(self.grpc_client.get_quantum_calibration, request) + + get_calibration = duet.sync(get_calibration_async) - def get_current_calibration( + async def get_current_calibration_async( self, project_id: str, processor_id: str ) -> Optional[quantum.QuantumCalibration]: """Returns the current quantum calibration for a processor if it has one. @@ -693,13 +809,15 @@ def get_current_calibration( request = quantum.GetQuantumCalibrationRequest( name=_processor_name_from_ids(project_id, processor_id) + '/calibrations/current' ) - return self._make_request(self.grpc_client.get_quantum_calibration, request) + return await self._make_request_async(self.grpc_client.get_quantum_calibration, request) except EngineException as err: if isinstance(err.__cause__, NotFound): return None raise - def create_reservation( + get_current_calibration = duet.sync(get_current_calibration_async) + + async def create_reservation_async( self, project_id: str, processor_id: str, @@ -729,9 +847,13 @@ def create_reservation( request = quantum.CreateQuantumReservationRequest( parent=parent, quantum_reservation=reservation ) - return self._make_request(self.grpc_client.create_quantum_reservation, request) + return await self._make_request_async(self.grpc_client.create_quantum_reservation, request) + + create_reservation = duet.sync(create_reservation_async) - def cancel_reservation(self, project_id: str, processor_id: str, reservation_id: str): + async def cancel_reservation_async( + self, project_id: str, processor_id: str, reservation_id: str + ): """Cancels a quantum reservation. This action is only valid if the associated [QuantumProcessor] @@ -752,9 +874,13 @@ def cancel_reservation(self, project_id: str, processor_id: str, reservation_id: """ name = _reservation_name_from_ids(project_id, processor_id, reservation_id) request = quantum.CancelQuantumReservationRequest(name=name) - return self._make_request(self.grpc_client.cancel_quantum_reservation, request) + return await self._make_request_async(self.grpc_client.cancel_quantum_reservation, request) - def delete_reservation(self, project_id: str, processor_id: str, reservation_id: str): + cancel_reservation = duet.sync(cancel_reservation_async) + + async def delete_reservation_async( + self, project_id: str, processor_id: str, reservation_id: str + ): """Deletes a quantum reservation. This action is only valid if the associated [QuantumProcessor] @@ -771,9 +897,11 @@ def delete_reservation(self, project_id: str, processor_id: str, reservation_id: """ name = _reservation_name_from_ids(project_id, processor_id, reservation_id) request = quantum.DeleteQuantumReservationRequest(name=name) - return self._make_request(self.grpc_client.delete_quantum_reservation, request) + return await self._make_request_async(self.grpc_client.delete_quantum_reservation, request) - def get_reservation( + delete_reservation = duet.sync(delete_reservation_async) + + async def get_reservation_async( self, project_id: str, processor_id: str, reservation_id: str ) -> Optional[quantum.QuantumReservation]: """Gets a quantum reservation from the engine. @@ -789,13 +917,15 @@ def get_reservation( try: name = _reservation_name_from_ids(project_id, processor_id, reservation_id) request = quantum.GetQuantumReservationRequest(name=name) - return self._make_request(self.grpc_client.get_quantum_reservation, request) + return await self._make_request_async(self.grpc_client.get_quantum_reservation, request) except EngineException as err: if isinstance(err.__cause__, NotFound): return None raise - def list_reservations( + get_reservation = duet.sync(get_reservation_async) + + async def list_reservations_async( self, project_id: str, processor_id: str, filter_str: str = '' ) -> List[quantum.QuantumReservation]: """Returns a list of quantum reservations. @@ -819,10 +949,13 @@ def list_reservations( request = quantum.ListQuantumReservationsRequest( parent=_processor_name_from_ids(project_id, processor_id), filter=filter_str ) - response = self._make_request(self.grpc_client.list_quantum_reservations, request) - return list(response) + return await self._make_list_request_async( + self.grpc_client.list_quantum_reservations, request + ) + + list_reservations = duet.sync(list_reservations_async) - def update_reservation( + async def update_reservation_async( self, project_id: str, processor_id: str, @@ -870,9 +1003,11 @@ def update_reservation( quantum_reservation=reservation, update_mask=field_mask_pb2.FieldMask(paths=paths), ) - return self._make_request(self.grpc_client.update_quantum_reservation, request) + return await self._make_request_async(self.grpc_client.update_quantum_reservation, request) - def list_time_slots( + update_reservation = duet.sync(update_reservation_async) + + async def list_time_slots_async( self, project_id: str, processor_id: str, filter_str: str = '' ) -> List[quantum.QuantumTimeSlot]: """Returns a list of quantum time slots on a processor. @@ -890,8 +1025,11 @@ def list_time_slots( request = quantum.ListQuantumTimeSlotsRequest( parent=_processor_name_from_ids(project_id, processor_id), filter=filter_str ) - response = self._make_request(self.grpc_client.list_quantum_time_slots, request) - return list(response) + return await self._make_list_request_async( + self.grpc_client.list_quantum_time_slots, request + ) + + list_time_slots = duet.sync(list_time_slots_async) def _project_name(project_id: str) -> str: diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index d509bfb328a..4d3878b5ffb 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -14,22 +14,26 @@ """Tests for EngineClient.""" import datetime from unittest import mock + import pytest from google.api_core import exceptions from google.protobuf import any_pb2 from google.protobuf.field_mask_pb2 import FieldMask from google.protobuf.timestamp_pb2 import Timestamp + +import duet + from cirq_google.engine.engine_client import EngineClient, EngineException from cirq_google.cloud import quantum def setup_mock_(client_constructor): - grpc_client = mock.Mock() + grpc_client = mock.AsyncMock() client_constructor.return_value = grpc_client return grpc_client -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_create_program(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -95,7 +99,7 @@ def test_create_program(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_program(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -114,7 +118,7 @@ def test_get_program(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_program(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -160,7 +164,7 @@ def test_list_program(client_constructor): ), ], ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_program_filters( client_constructor, expected_filter, created_before, created_after, labels ): @@ -175,13 +179,13 @@ def test_list_program_filters( assert grpc_client.list_quantum_programs.call_args[0][0].filter == expected_filter -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_program_filters_invalid_type(client_constructor): with pytest.raises(ValueError, match=""): EngineClient().list_programs(project_id='proj', created_before="Unsupported date/time") -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_program_description(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -210,7 +214,7 @@ def test_set_program_description(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_program_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -245,7 +249,7 @@ def test_set_program_labels(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_add_program_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -287,7 +291,7 @@ def test_add_program_labels(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_remove_program_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -327,7 +331,7 @@ def test_remove_program_labels(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_delete_program(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -343,7 +347,7 @@ def test_delete_program(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_create_job(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -474,7 +478,7 @@ def test_create_job(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_job(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -497,7 +501,7 @@ def test_get_job(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_job_description(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -526,7 +530,7 @@ def test_set_job_description(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_job_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -563,7 +567,7 @@ def test_set_job_labels(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_add_job_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -607,7 +611,7 @@ def test_add_job_labels(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_remove_job_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -647,7 +651,7 @@ def test_remove_job_labels(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_delete_job(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -658,7 +662,7 @@ def test_delete_job(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_cancel_job(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -669,7 +673,7 @@ def test_cancel_job(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_job_results(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -683,7 +687,7 @@ def test_job_results(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_jobs(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -787,7 +791,7 @@ def test_list_jobs(client_constructor): ), ], ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_jobs_filters( client_constructor, expected_filter, @@ -813,7 +817,15 @@ def test_list_jobs_filters( assert grpc_client.list_quantum_jobs.call_args[0][0].filter == expected_filter -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +class Pager: + def __init__(self, items): + self.items = items + + def __aiter__(self): + return duet.aiter(self.items) + + +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_processors(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -821,7 +833,7 @@ def test_list_processors(client_constructor): quantum.QuantumProcessor(name='projects/proj/processor/processor0'), quantum.QuantumProcessor(name='projects/proj/processor/processor1'), ] - grpc_client.list_quantum_processors.return_value = results + grpc_client.list_quantum_processors.return_value = Pager(results) client = EngineClient() assert client.list_processors('proj') == results @@ -830,7 +842,7 @@ def test_list_processors(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_processor(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -844,7 +856,7 @@ def test_get_processor(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_calibrations(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -852,7 +864,7 @@ def test_list_calibrations(client_constructor): quantum.QuantumCalibration(name='projects/proj/processor/processor0/calibrations/123456'), quantum.QuantumCalibration(name='projects/proj/processor/processor1/calibrations/224466'), ] - grpc_client.list_quantum_calibrations.return_value = results + grpc_client.list_quantum_calibrations.return_value = Pager(results) client = EngineClient() assert client.list_calibrations('proj', 'processor0') == results @@ -861,7 +873,7 @@ def test_list_calibrations(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_calibration(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -879,7 +891,7 @@ def test_get_calibration(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_current_calibration(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -897,7 +909,7 @@ def test_get_current_calibration(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_current_calibration_does_not_exist(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -912,7 +924,7 @@ def test_get_current_calibration_does_not_exist(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_current_calibration_error(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -923,7 +935,7 @@ def test_get_current_calibration_error(client_constructor): client.get_current_calibration('proj', 'processor0') -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_api_doesnt_retry_not_found_errors(client_constructor): grpc_client = setup_mock_(client_constructor) grpc_client.get_quantum_program.side_effect = exceptions.NotFound('not found') @@ -934,7 +946,7 @@ def test_api_doesnt_retry_not_found_errors(client_constructor): assert grpc_client.get_quantum_program.call_count == 1 -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_api_retry_5xx_errors(client_constructor): grpc_client = setup_mock_(client_constructor) grpc_client.get_quantum_program.side_effect = exceptions.ServiceUnavailable('internal error') @@ -945,9 +957,9 @@ def test_api_retry_5xx_errors(client_constructor): assert grpc_client.get_quantum_program.call_count == 3 -@mock.patch('time.sleep', return_value=None) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) -def test_api_retry_times(client_constructor, mock_time): +@mock.patch('duet.sleep', return_value=duet.completed_future(None)) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) +def test_api_retry_times(client_constructor, mock_sleep): grpc_client = setup_mock_(client_constructor) grpc_client.get_quantum_program.side_effect = exceptions.ServiceUnavailable('internal error') @@ -956,11 +968,11 @@ def test_api_retry_times(client_constructor, mock_time): client.get_program('proj', 'prog', False) assert grpc_client.get_quantum_program.call_count == 3 - assert len(mock_time.call_args_list) == 2 - assert all(x == y for (x, _), y in zip(mock_time.call_args_list, [(0.1,), (0.2,)])) + assert len(mock_sleep.call_args_list) == 2 + assert all(x == y for (x, _), y in zip(mock_sleep.call_args_list, [(0.1,), (0.2,)])) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_create_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) start = datetime.datetime.fromtimestamp(1000000000) @@ -986,7 +998,7 @@ def test_create_reservation(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_cancel_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' @@ -1005,7 +1017,7 @@ def test_cancel_reservation(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_delete_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' @@ -1024,7 +1036,7 @@ def test_delete_reservation(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' @@ -1043,7 +1055,7 @@ def test_get_reservation(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_reservation_not_found(client_constructor): grpc_client = setup_mock_(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' @@ -1056,7 +1068,7 @@ def test_get_reservation_not_found(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_reservation_exception(client_constructor): grpc_client = setup_mock_(client_constructor) grpc_client.get_quantum_reservation.side_effect = exceptions.BadRequest('boom') @@ -1066,7 +1078,7 @@ def test_get_reservation_exception(client_constructor): client.get_reservation('proj', 'processor0', 'goog') -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' @@ -1084,13 +1096,13 @@ def test_list_reservation(client_constructor): whitelisted_users=['dstrain@google.com'], ), ] - grpc_client.list_quantum_reservations.return_value = results + grpc_client.list_quantum_reservations.return_value = Pager(results) client = EngineClient() assert client.list_reservations('proj', 'processor0') == results -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_update_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' @@ -1123,7 +1135,7 @@ def test_update_reservation(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_update_reservation_remove_all_users(client_constructor): grpc_client = setup_mock_(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' @@ -1144,7 +1156,7 @@ def test_update_reservation_remove_all_users(client_constructor): ) -@mock.patch.object(quantum, 'QuantumEngineServiceClient', autospec=True) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_time_slots(client_constructor): grpc_client = setup_mock_(client_constructor) results = [ @@ -1167,7 +1179,7 @@ def test_list_time_slots(client_constructor): ), ), ] - grpc_client.list_quantum_time_slots.return_value = results + grpc_client.list_quantum_time_slots.return_value = Pager(results) client = EngineClient() assert client.list_time_slots('proj', 'processor0') == results diff --git a/cirq-google/cirq_google/engine/engine_job.py b/cirq-google/cirq_google/engine/engine_job.py index dda063ec679..5ca191a84a4 100644 --- a/cirq-google/cirq_google/engine/engine_job.py +++ b/cirq-google/cirq_google/engine/engine_job.py @@ -13,10 +13,10 @@ # limitations under the License. """A helper for jobs that have been created on the Quantum Engine.""" import datetime -import time from typing import Dict, Iterator, List, Optional, overload, Sequence, Tuple, TYPE_CHECKING +import duet from google.protobuf import any_pb2 import cirq @@ -107,20 +107,25 @@ def program(self) -> 'engine_program.EngineProgram': return engine_program.EngineProgram(self.project_id, self.program_id, self.context) + async def _get_job_async(self, return_run_context: bool = False) -> quantum.QuantumJob: + return await self.context.client.get_job_async( + self.project_id, self.program_id, self.job_id, return_run_context + ) + + _get_job = duet.sync(_get_job_async) + def _inner_job(self) -> quantum.QuantumJob: if self._job is None: - self._job = self.context.client.get_job( - self.project_id, self.program_id, self.job_id, False - ) + self._job = self._get_job() return self._job - def _refresh_job(self) -> quantum.QuantumJob: + async def _refresh_job_async(self) -> quantum.QuantumJob: if self._job is None or self._job.execution_status.state not in TERMINAL_STATES: - self._job = self.context.client.get_job( - self.project_id, self.program_id, self.job_id, False - ) + self._job = await self._get_job_async() return self._job + _refresh_job = duet.sync(_refresh_job_async) + def create_time(self) -> 'datetime.datetime': """Returns when the job was created.""" return self._inner_job().create_time @@ -224,10 +229,7 @@ def get_repetitions_and_sweeps(self) -> Tuple[int, List[cirq.Sweep]]: A tuple of the repetition count and list of sweeps. """ if self._job is None or self._job.run_context is None: - self._job = self.context.client.get_job( - self.project_id, self.program_id, self.job_id, True - ) - + self._job = self._get_job(return_run_context=True) return _deserialize_run_context(self._job.run_context) def get_processor(self) -> 'Optional[engine_processor.EngineProcessor]': @@ -260,42 +262,26 @@ def delete(self) -> None: """Deletes the job and result, if any.""" self.context.client.delete_job(self.project_id, self.program_id, self.job_id) - def batched_results(self) -> Sequence[Sequence[EngineResult]]: + async def batched_results_async(self) -> Sequence[Sequence[EngineResult]]: """Returns the job results, blocking until the job is complete. This method is intended for batched jobs. Instead of flattening results into a single list, this will return a Sequence[Result] for each circuit in the batch. """ - self.results() + await self.results_async() if self._batched_results is None: raise ValueError('batched_results called for a non-batch result.') return self._batched_results - def _wait_for_result(self): - job = self._refresh_job() - total_seconds_waited = 0.0 - timeout = self.context.timeout - while True: - if timeout and total_seconds_waited >= timeout: - break - if job.execution_status.state in TERMINAL_STATES: - break - time.sleep(0.5) - total_seconds_waited += 0.5 - job = self._refresh_job() - _raise_on_failure(job) - response = self.context.client.get_job_results( - self.project_id, self.program_id, self.job_id - ) - return response.result + batched_results = duet.sync(batched_results_async) - def results(self) -> Sequence[EngineResult]: + async def results_async(self) -> Sequence[EngineResult]: """Returns the job results, blocking until the job is complete.""" import cirq_google.engine.engine as engine_base if self._results is None: - result = self._wait_for_result() + result = await self._await_result_async() result_type = result.type_url[len(engine_base.TYPE_PREFIX) :] if ( result_type == 'cirq.google.api.v1.Result' @@ -317,7 +303,22 @@ def results(self) -> Sequence[EngineResult]: raise ValueError(f'invalid result proto version: {result_type}') return self._results - def calibration_results(self) -> Sequence[CalibrationResult]: + results = duet.sync(results_async) + + async def _await_result_async(self) -> quantum.QuantumResult: + async with duet.timeout_scope(self.context.timeout): + while True: + job = await self._refresh_job_async() + if job.execution_status.state in TERMINAL_STATES: + break + await duet.sleep(0.5) + _raise_on_failure(job) + response = await self.context.client.get_job_results_async( + self.project_id, self.program_id, self.job_id + ) + return response.result + + async def calibration_results_async(self) -> Sequence[CalibrationResult]: """Returns the results of a run_calibration() call. This function will fail if any other type of results were returned @@ -326,7 +327,7 @@ def calibration_results(self) -> Sequence[CalibrationResult]: import cirq_google.engine.engine as engine_base if self._calibration_results is None: - result = self._wait_for_result() + result = await self._await_result_async() result_type = result.type_url[len(engine_base.TYPE_PREFIX) :] if result_type != 'cirq.google.api.v2.FocusedCalibrationResult': raise ValueError(f'Did not find calibration results, instead found: {result_type}') @@ -343,6 +344,8 @@ def calibration_results(self) -> Sequence[CalibrationResult]: self._calibration_results = cal_results return self._calibration_results + calibration_results = duet.sync(calibration_results_async) + def _get_job_results_v1(self, result: v1.program_pb2.Result) -> Sequence[EngineResult]: # coverage: ignore job_id = self.id() diff --git a/cirq-google/cirq_google/engine/engine_job_test.py b/cirq-google/cirq_google/engine/engine_job_test.py index 87301983491..4694e835001 100644 --- a/cirq-google/cirq_google/engine/engine_job_test.py +++ b/cirq-google/cirq_google/engine/engine_job_test.py @@ -70,7 +70,7 @@ def test_create_time(): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_update_time(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) get_job.return_value = quantum.QuantumJob( @@ -82,7 +82,7 @@ def test_update_time(get_job): get_job.assert_called_once_with('a', 'b', 'steve', False) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_description(get_job): job = cg.EngineJob( 'a', 'b', 'steve', EngineContext(), _job=quantum.QuantumJob(description='hello') @@ -93,7 +93,7 @@ def test_description(get_job): get_job.assert_called_once_with('a', 'b', 'steve', False) -@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_description') +@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_description_async') def test_set_description(set_job_description): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) set_job_description.return_value = quantum.QuantumJob(description='world') @@ -112,7 +112,7 @@ def test_labels(): assert job.labels() == {'t': '1'} -@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_labels') +@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_labels_async') def test_set_labels(set_job_labels): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) set_job_labels.return_value = quantum.QuantumJob(labels={'a': '1', 'b': '1'}) @@ -124,7 +124,7 @@ def test_set_labels(set_job_labels): set_job_labels.assert_called_with('a', 'b', 'steve', {}) -@mock.patch('cirq_google.engine.engine_client.EngineClient.add_job_labels') +@mock.patch('cirq_google.engine.engine_client.EngineClient.add_job_labels_async') def test_add_labels(add_job_labels): job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=quantum.QuantumJob(labels={})) assert job.labels() == {} @@ -138,7 +138,7 @@ def test_add_labels(add_job_labels): add_job_labels.assert_called_with('a', 'b', 'steve', {'a': '2', 'b': '1'}) -@mock.patch('cirq_google.engine.engine_client.EngineClient.remove_job_labels') +@mock.patch('cirq_google.engine.engine_client.EngineClient.remove_job_labels_async') def test_remove_labels(remove_job_labels): job = cg.EngineJob( 'a', 'b', 'steve', EngineContext(), _job=quantum.QuantumJob(labels={'a': '1', 'b': '1'}) @@ -171,7 +171,7 @@ def test_processor_ids(): assert job.processor_ids() == ['p'] -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_status(get_job): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.RUNNING) @@ -216,7 +216,7 @@ def test_failure_with_no_error(): assert not job.failure() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_get_repetitions_and_sweeps(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) get_job.return_value = quantum.QuantumJob( @@ -230,7 +230,7 @@ def test_get_repetitions_and_sweeps(get_job): get_job.assert_called_once_with('a', 'b', 'steve', True) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_get_repetitions_and_sweeps_v1(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) get_job.return_value = quantum.QuantumJob( @@ -244,7 +244,7 @@ def test_get_repetitions_and_sweeps_v1(get_job): job.get_repetitions_and_sweeps() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_get_repetitions_and_sweeps_unsupported(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) get_job.return_value = quantum.QuantumJob( @@ -312,7 +312,7 @@ def test_get_calibration(get_calibration): get_calibration.assert_called_once_with('a', 'p', 123) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration_async') def test_calibration__with_no_calibration(get_calibration): job = cg.EngineJob( 'a', @@ -329,14 +329,14 @@ def test_calibration__with_no_calibration(get_calibration): assert not get_calibration.called -@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_job') +@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_job_async') def test_cancel(cancel_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) job.cancel() cancel_job.assert_called_once_with('a', 'b', 'steve') -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job') +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job_async') def test_delete(delete_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) job.delete() @@ -504,7 +504,7 @@ def test_delete(delete_job): UPDATE_TIME = datetime.datetime.now(tz=datetime.timezone.utc) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -520,7 +520,7 @@ def test_results(get_job_results): get_job_results.assert_called_once_with('a', 'b', 'steve') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results_iter(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -535,7 +535,7 @@ def test_results_iter(get_job_results): assert results[1] == 'q=1010' -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results_getitem(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -550,7 +550,7 @@ def test_results_getitem(get_job_results): _ = job[2] -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_batched_results(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -577,7 +577,7 @@ def test_batched_results(get_job_results): assert str(data[1][1]) == 'q=1001' -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_batched_results_not_a_batch(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -589,7 +589,7 @@ def test_batched_results_not_a_batch(get_job_results): job.batched_results() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_calibration_results(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -608,7 +608,7 @@ def test_calibration_results(get_job_results): assert data[0].metrics['theta'] == {(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)): [0.9999]} -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_calibration_defaults(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -628,7 +628,7 @@ def test_calibration_defaults(get_job_results): assert len(data[0].metrics) == 0 -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_calibration_results_not_a_calibration(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -640,7 +640,7 @@ def test_calibration_results_not_a_calibration(get_job_results): job.calibration_results() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results_len(get_job_results): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), @@ -652,16 +652,15 @@ def test_results_len(get_job_results): assert len(job) == 2 -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job') -@mock.patch('time.sleep', return_value=None) -def test_timeout(patched_time_sleep, get_job): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') +def test_timeout(get_job): qjob = quantum.QuantumJob( execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.RUNNING), update_time=UPDATE_TIME, ) get_job.return_value = qjob - job = cg.EngineJob('a', 'b', 'steve', EngineContext(timeout=500)) - with pytest.raises(RuntimeError, match='Timed out'): + job = cg.EngineJob('a', 'b', 'steve', EngineContext(timeout=0.1)) + with pytest.raises(TimeoutError): job.results() diff --git a/cirq-google/cirq_google/engine/engine_processor_test.py b/cirq-google/cirq_google/engine/engine_processor_test.py index 9cd45334222..3b058857877 100644 --- a/cirq-google/cirq_google/engine/engine_processor_test.py +++ b/cirq-google/cirq_google/engine/engine_processor_test.py @@ -232,7 +232,7 @@ def test_engine_repr(): assert 'the-processor-id' in repr(processor) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') def test_health(get_processor): get_processor.return_value = quantum.QuantumProcessor(health=quantum.QuantumProcessor.Health.OK) processor = cg.EngineProcessor( @@ -244,7 +244,7 @@ def test_health(get_processor): assert processor.health() == 'OK' -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') def test_expected_down_time(get_processor): processor = cg.EngineProcessor('a', 'p', EngineContext(), _processor=quantum.QuantumProcessor()) assert not processor.expected_down_time() @@ -352,7 +352,7 @@ def test_get_missing_device(): _ = processor.get_device(gate_sets=[_GATE_SET]) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations') +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations_async') def test_list_calibrations(list_calibrations): list_calibrations.return_value = [_CALIBRATION] processor = cg.EngineProcessor('a', 'p', EngineContext()) @@ -391,7 +391,7 @@ def test_list_calibrations(list_calibrations): list_calibrations.assert_called_with('a', 'p', f'timestamp >= {today_midnight_timestamp}') -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations') +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations_async') def test_list_calibrations_old_params(list_calibrations): # Disable pylint warnings for use of deprecated parameters # pylint: disable=unexpected-keyword-arg @@ -409,7 +409,7 @@ def test_list_calibrations_old_params(list_calibrations): list_calibrations.assert_called_with('a', 'p', 'timestamp <= 1562600000') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration_async') def test_get_calibration(get_calibration): get_calibration.return_value = _CALIBRATION processor = cg.EngineProcessor('a', 'p', EngineContext()) @@ -419,7 +419,7 @@ def test_get_calibration(get_calibration): get_calibration.assert_called_once_with('a', 'p', 1562544000021) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration_async') def test_current_calibration(get_current_calibration): get_current_calibration.return_value = _CALIBRATION processor = cg.EngineProcessor('a', 'p', EngineContext()) @@ -429,7 +429,7 @@ def test_current_calibration(get_current_calibration): get_current_calibration.assert_called_once_with('a', 'p') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration_async') def test_missing_latest_calibration(get_current_calibration): get_current_calibration.return_value = None processor = cg.EngineProcessor('a', 'p', EngineContext()) @@ -437,7 +437,7 @@ def test_missing_latest_calibration(get_current_calibration): get_current_calibration.assert_called_once_with('a', 'p') -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_reservation') +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_reservation_async') def test_create_reservation(create_reservation): name = 'projects/proj/processors/p0/reservations/psherman-wallaby-way' result = quantum.QuantumReservation( @@ -462,7 +462,7 @@ def test_create_reservation(create_reservation): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation') +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation_async') def test_delete_reservation(delete_reservation): name = 'projects/proj/processors/p0/reservations/rid' result = quantum.QuantumReservation( @@ -477,7 +477,7 @@ def test_delete_reservation(delete_reservation): delete_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation') +@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation_async') def test_cancel_reservation(cancel_reservation): name = 'projects/proj/processors/p0/reservations/rid' result = quantum.QuantumReservation( @@ -492,8 +492,8 @@ def test_cancel_reservation(cancel_reservation): cancel_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation_async') def test_remove_reservation_delete(delete_reservation, get_reservation): name = 'projects/proj/processors/p0/reservations/rid' now = int(datetime.datetime.now().timestamp()) @@ -515,8 +515,8 @@ def test_remove_reservation_delete(delete_reservation, get_reservation): delete_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') -@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') +@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation_async') def test_remove_reservation_cancel(cancel_reservation, get_reservation): name = 'projects/proj/processors/p0/reservations/rid' now = int(datetime.datetime.now().timestamp()) @@ -538,7 +538,7 @@ def test_remove_reservation_cancel(cancel_reservation, get_reservation): cancel_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') def test_remove_reservation_not_found(get_reservation): get_reservation.return_value = None processor = cg.EngineProcessor( @@ -551,8 +551,8 @@ def test_remove_reservation_not_found(get_reservation): processor.remove_reservation('rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') def test_remove_reservation_failures(get_reservation, get_processor): name = 'projects/proj/processors/p0/reservations/rid' now = int(datetime.datetime.now().timestamp()) @@ -576,7 +576,7 @@ def test_remove_reservation_failures(get_reservation, get_processor): processor.remove_reservation('rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation') +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') def test_get_reservation(get_reservation): name = 'projects/proj/processors/p0/reservations/rid' result = quantum.QuantumReservation( @@ -591,7 +591,7 @@ def test_get_reservation(get_reservation): get_reservation.assert_called_once_with('proj', 'p0', 'rid') -@mock.patch('cirq_google.engine.engine_client.EngineClient.update_reservation') +@mock.patch('cirq_google.engine.engine_client.EngineClient.update_reservation_async') def test_update_reservation(update_reservation): name = 'projects/proj/processors/p0/reservations/rid' result = quantum.QuantumReservation( @@ -610,7 +610,7 @@ def test_update_reservation(update_reservation): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations') +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations_async') def test_list_reservation(list_reservations): name = 'projects/proj/processors/p0/reservations/rid' results = [ @@ -640,7 +640,7 @@ def test_list_reservation(list_reservations): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots') +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots_async') def test_get_schedule(list_time_slots): results = [ quantum.QuantumTimeSlot( @@ -675,7 +675,7 @@ def test_get_schedule(list_time_slots): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots') +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots_async') def test_get_schedule_filter_by_time_slot(list_time_slots): results = [ quantum.QuantumTimeSlot( @@ -737,7 +737,7 @@ def wrapper(*args, **kwargs): @_allow_deprecated_freezegun @freezegun.freeze_time() -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots') +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots_async') def test_get_schedule_time_filter_behavior(list_time_slots): list_time_slots.return_value = [] processor = cg.EngineProcessor('proj', 'p0', EngineContext()) @@ -781,7 +781,7 @@ def test_get_schedule_time_filter_behavior(list_time_slots): @_allow_deprecated_freezegun @freezegun.freeze_time() -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations') +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations_async') def test_list_reservations_time_filter_behavior(list_reservations): list_reservations.return_value = [] processor = cg.EngineProcessor('proj', 'p0', EngineContext()) @@ -823,22 +823,24 @@ def test_list_reservations_time_filter_behavior(list_reservations): list_reservations.assert_called_with('proj', 'p0', f'start_time < {utc_ts}') -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_params(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( execution_status={'state': 'SUCCESS'}, update_time=_to_timestamp('2019-07-09T23:39:59Z') ) - client().get_job_results.return_value = quantum.QuantumResult(result=util.pack_any(_RESULTS_V2)) + client().get_job_results_async.return_value = quantum.QuantumResult( + result=util.pack_any(_RESULTS_V2) + ) processor = cg.EngineProcessor('a', 'p', EngineContext()) job = processor.run_sweep( @@ -855,36 +857,36 @@ def test_run_sweep_params(client): assert result.job_finished_time is not None assert results == cirq.read_json(json_text=cirq.to_json(results)) - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.run_context_pb2.RunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 2 for i, v in enumerate([1.0, 2.0]): assert sweeps[i].repetitions == 1 assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( execution_status={'state': 'SUCCESS'}, update_time=_to_timestamp('2019-07-09T23:39:59Z') ) - client().get_job_results.return_value = quantum.QuantumResult(result=_BATCH_RESULTS_V2) + client().get_job_results_async.return_value = quantum.QuantumResult(result=_BATCH_RESULTS_V2) processor = cg.EngineProcessor('a', 'p', EngineContext()) job = processor.run_batch( @@ -900,10 +902,10 @@ def test_run_batch(client): assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} for result in results: assert result.job_id == job.id() - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.batch_pb2.BatchRunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) assert len(run_context.run_contexts) == 2 for idx, rc in enumerate(run_context.run_contexts): sweeps = rc.parameter_sweeps @@ -913,24 +915,26 @@ def test_run_batch(client): assert sweeps[0].sweep.single_sweep.points.points == [1.0, 2.0] if idx == 1: assert sweeps[0].sweep.single_sweep.points.points == [3.0, 4.0] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_calibration(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob(execution_status={'state': 'SUCCESS'}) - client().get_job_results.return_value = quantum.QuantumResult(result=_CALIBRATION_RESULTS_V2) + client().get_job_async.return_value = quantum.QuantumJob(execution_status={'state': 'SUCCESS'}) + client().get_job_results_async.return_value = quantum.QuantumResult( + result=_CALIBRATION_RESULTS_V2 + ) q1 = cirq.GridQubit(2, 3) q2 = cirq.GridQubit(2, 4) @@ -952,7 +956,7 @@ def test_run_calibration(client): assert results[1].error_message == 'Second success' # assert label is correct - client().create_job.assert_called_once_with( + client().create_job_async.assert_called_once_with( project_id='proj', program_id='prog', job_id='job-id', @@ -963,22 +967,24 @@ def test_run_calibration(client): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_sampler(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( execution_status={'state': 'SUCCESS'}, update_time=_to_timestamp('2019-07-09T23:39:59Z') ) - client().get_job_results.return_value = quantum.QuantumResult(result=util.pack_any(_RESULTS_V2)) + client().get_job_results_async.return_value = quantum.QuantumResult( + result=util.pack_any(_RESULTS_V2) + ) processor = cg.EngineProcessor('proj', 'mysim', EngineContext()) sampler = processor.get_sampler() results = sampler.run_sweep( @@ -989,7 +995,7 @@ def test_sampler(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - assert client().create_program.call_args[0][0] == 'proj' + assert client().create_program_async.call_args[0][0] == 'proj' def test_str(): diff --git a/cirq-google/cirq_google/engine/engine_program.py b/cirq-google/cirq_google/engine/engine_program.py index b7d2f28fd79..289f69d8df4 100644 --- a/cirq-google/cirq_google/engine/engine_program.py +++ b/cirq-google/cirq_google/engine/engine_program.py @@ -15,6 +15,7 @@ import datetime from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union +import duet from google.protobuf import any_pb2 import cirq @@ -63,7 +64,7 @@ def __init__( self._program = _program self.result_type = result_type - def run_sweep( + async def run_sweep_async( self, job_id: Optional[str] = None, params: cirq.Sweepable = None, @@ -105,7 +106,7 @@ def run_sweep( job_id = engine_base._make_random_id('job-') run_context = self.context._serialize_run_context(params, repetitions) - created_job_id, job = self.context.client.create_job( + created_job_id, job = await self.context.client.create_job_async( project_id=self.project_id, program_id=self.program_id, job_id=job_id, @@ -118,7 +119,9 @@ def run_sweep( self.project_id, self.program_id, created_job_id, self.context, job ) - def run_batch( + run_sweep = duet.sync(run_sweep_async) + + async def run_batch_async( self, job_id: Optional[str] = None, params_list: List[cirq.Sweepable] = None, @@ -179,7 +182,7 @@ def run_batch( (params, repetitions) for params in params_list ) - created_job_id, job = self.context.client.create_job( + created_job_id, job = await self.context.client.create_job_async( project_id=self.project_id, program_id=self.program_id, job_id=job_id, @@ -197,7 +200,9 @@ def run_batch( result_type=ResultType.Batch, ) - def run_calibration( + run_batch = duet.sync(run_batch_async) + + async def run_calibration_async( self, job_id: Optional[str] = None, processor_ids: Sequence[str] = (), @@ -241,7 +246,7 @@ def run_calibration( # on a run context in order to succeed validation. run_context = v2.run_context_pb2.RunContext() - created_job_id, job = self.context.client.create_job( + created_job_id, job = await self.context.client.create_job_async( project_id=self.project_id, program_id=self.program_id, job_id=job_id, @@ -259,7 +264,9 @@ def run_calibration( result_type=ResultType.Batch, ) - def run( + run_calibration = duet.sync(run_calibration_async) + + async def run_async( self, job_id: Optional[str] = None, param_resolver: cirq.ParamResolver = cirq.ParamResolver({}), @@ -286,16 +293,18 @@ def run( Returns: A single Result for this run. """ - return list( - self.run_sweep( - job_id=job_id, - params=[param_resolver], - repetitions=repetitions, - processor_ids=processor_ids, - description=description, - labels=labels, - ) - )[0] + job = await self.run_sweep_async( + job_id=job_id, + params=[param_resolver], + repetitions=repetitions, + processor_ids=processor_ids, + description=description, + labels=labels, + ) + results = await job.results_async() + return results[0] + + run = duet.sync(run_async) def engine(self) -> 'engine_base.Engine': """Returns the parent Engine object. @@ -318,13 +327,13 @@ def get_job(self, job_id: str) -> engine_job.EngineJob: """ return engine_job.EngineJob(self.project_id, self.program_id, job_id, self.context) - def list_jobs( + async def list_jobs_async( self, created_before: Optional[Union[datetime.datetime, datetime.date]] = None, created_after: Optional[Union[datetime.datetime, datetime.date]] = None, has_labels: Optional[Dict[str, str]] = None, execution_states: Optional[Set[quantum.ExecutionStatus.State]] = None, - ): + ) -> Sequence[engine_job.EngineJob]: """Returns the list of jobs for this program. Args: @@ -347,7 +356,7 @@ def list_jobs( `quantum.ExecutionStatus.State` enum for accepted values. """ client = self.context.client - response = client.list_jobs( + response = await client.list_jobs_async( self.project_id, self.program_id, created_before=created_before, @@ -366,6 +375,8 @@ def list_jobs( for j in response ] + list_jobs = duet.sync(list_jobs_async) + def _inner_program(self) -> quantum.QuantumProgram: if self._program is None: self._program = self.context.client.get_program(self.project_id, self.program_id, False) @@ -384,7 +395,7 @@ def description(self) -> str: """Returns the description of the program.""" return self._inner_program().description - def set_description(self, description: str) -> 'EngineProgram': + async def set_description_async(self, description: str) -> 'EngineProgram': """Sets the description of the program. Params: @@ -393,16 +404,18 @@ def set_description(self, description: str) -> 'EngineProgram': Returns: This EngineProgram. """ - self._program = self.context.client.set_program_description( + self._program = await self.context.client.set_program_description_async( self.project_id, self.program_id, description ) return self + set_description = duet.sync(set_description_async) + def labels(self) -> Dict[str, str]: """Returns the labels of the program.""" return self._inner_program().labels - def set_labels(self, labels: Dict[str, str]) -> 'EngineProgram': + async def set_labels_async(self, labels: Dict[str, str]) -> 'EngineProgram': """Sets (overwriting) the labels for a previously created quantum program. @@ -412,12 +425,14 @@ def set_labels(self, labels: Dict[str, str]) -> 'EngineProgram': Returns: This EngineProgram. """ - self._program = self.context.client.set_program_labels( + self._program = await self.context.client.set_program_labels_async( self.project_id, self.program_id, labels ) return self - def add_labels(self, labels: Dict[str, str]) -> 'EngineProgram': + set_labels = duet.sync(set_labels_async) + + async def add_labels_async(self, labels: Dict[str, str]) -> 'EngineProgram': """Adds new labels to a previously created quantum program. Params: @@ -426,12 +441,14 @@ def add_labels(self, labels: Dict[str, str]) -> 'EngineProgram': Returns: This EngineProgram. """ - self._program = self.context.client.add_program_labels( + self._program = await self.context.client.add_program_labels_async( self.project_id, self.program_id, labels ) return self - def remove_labels(self, keys: List[str]) -> 'EngineProgram': + add_labels = duet.sync(add_labels_async) + + async def remove_labels_async(self, keys: List[str]) -> 'EngineProgram': """Removes labels with given keys from the labels of a previously created quantum program. @@ -441,12 +458,14 @@ def remove_labels(self, keys: List[str]) -> 'EngineProgram': Returns: This EngineProgram. """ - self._program = self.context.client.remove_program_labels( + self._program = await self.context.client.remove_program_labels_async( self.project_id, self.program_id, keys ) return self - def get_circuit(self, program_num: Optional[int] = None) -> cirq.Circuit: + remove_labels = duet.sync(remove_labels_async) + + async def get_circuit_async(self, program_num: Optional[int] = None) -> cirq.Circuit: """Returns the cirq Circuit for the Quantum Engine program. This is only supported if the program was created with the V2 protos. @@ -459,10 +478,14 @@ def get_circuit(self, program_num: Optional[int] = None) -> cirq.Circuit: The program's cirq Circuit. """ if self._program is None or self._program.code is None: - self._program = self.context.client.get_program(self.project_id, self.program_id, True) + self._program = await self.context.client.get_program_async( + self.project_id, self.program_id, True + ) return _deserialize_program(self._program.code, program_num) - def batch_size(self) -> int: + get_circuit = duet.sync(get_circuit_async) + + async def batch_size_async(self) -> int: """Returns the number of programs in a batch program. Raises: @@ -475,7 +498,9 @@ def batch_size(self) -> int: import cirq_google.engine.engine as engine_base if self._program is None or self._program.code is None: - self._program = self.context.client.get_program(self.project_id, self.program_id, True) + self._program = await self.context.client.get_program_async( + self.project_id, self.program_id, True + ) code = self._program.code code_type = code.type_url[len(engine_base.TYPE_PREFIX) :] if code_type == 'cirq.google.api.v2.BatchProgram': @@ -483,20 +508,26 @@ def batch_size(self) -> int: return len(batch.programs) raise ValueError(f'Program was not a batch program but instead was of type {code_type}.') - def delete(self, delete_jobs: bool = False) -> None: + batch_size = duet.sync(batch_size_async) + + async def delete_async(self, delete_jobs: bool = False) -> None: """Deletes a previously created quantum program. Params: delete_jobs: If True will delete all the program's jobs, other this will fail if the program contains any jobs. """ - self.context.client.delete_program( + await self.context.client.delete_program_async( self.project_id, self.program_id, delete_jobs=delete_jobs ) - def delete_job(self, job_id: str) -> None: + delete = duet.sync(delete_async) + + async def delete_job_async(self, job_id: str) -> None: """Deletes the job and result, if any.""" - self.context.client.delete_job(self.project_id, self.program_id, job_id) + await self.context.client.delete_job_async(self.project_id, self.program_id, job_id) + + delete_job = duet.sync(delete_job_async) def __str__(self) -> str: return f'EngineProgram(project_id=\'{self.project_id}\', program_id=\'{self.program_id}\')' diff --git a/cirq-google/cirq_google/engine/engine_program_test.py b/cirq-google/cirq_google/engine/engine_program_test.py index a232df279c3..9ce307529c5 100644 --- a/cirq-google/cirq_google/engine/engine_program_test.py +++ b/cirq-google/cirq_google/engine/engine_program_test.py @@ -163,9 +163,9 @@ ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_sweeps_delegation(create_job): - create_job.return_value = ('steve', quantum.QuantumJob()) +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_sweeps_delegation(create_job_async): + create_job_async.return_value = ('steve', quantum.QuantumJob()) program = cg.EngineProgram('my-proj', 'my-prog', EngineContext()) param_resolver = cirq.ParamResolver({}) job = program.run_sweep( @@ -174,9 +174,9 @@ def test_run_sweeps_delegation(create_job): assert job._job == quantum.QuantumJob() -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_batch_delegation(create_job): - create_job.return_value = ('kittens', quantum.QuantumJob()) +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_batch_delegation(create_job_async): + create_job_async.return_value = ('kittens', quantum.QuantumJob()) program = cg.EngineProgram('my-meow', 'my-meow', EngineContext(), result_type=ResultType.Batch) resolver_list = [cirq.Points('cats', [1.0, 2.0, 3.0]), cirq.Points('cats', [4.0, 5.0, 6.0])] job = program.run_batch( @@ -185,27 +185,27 @@ def test_run_batch_delegation(create_job): assert job._job == quantum.QuantumJob() -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_calibration_delegation(create_job): - create_job.return_value = ('dogs', quantum.QuantumJob()) +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_calibration_delegation(create_job_async): + create_job_async.return_value = ('dogs', quantum.QuantumJob()) program = cg.EngineProgram('woof', 'woof', EngineContext(), result_type=ResultType.Calibration) job = program.run_calibration(processor_ids=['lazydog']) assert job._job == quantum.QuantumJob() -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_calibration_no_processors(create_job): - create_job.return_value = ('dogs', quantum.QuantumJob()) +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_calibration_no_processors(create_job_async): + create_job_async.return_value = ('dogs', quantum.QuantumJob()) program = cg.EngineProgram('woof', 'woof', EngineContext(), result_type=ResultType.Calibration) with pytest.raises(ValueError, match='No processors specified'): _ = program.run_calibration(job_id='spot') -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_batch_no_sweeps(create_job): +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_batch_no_sweeps(create_job_async): # Running with no sweeps is fine. Uses program's batch size to create # proper empty sweeps. - create_job.return_value = ('kittens', quantum.QuantumJob()) + create_job_async.return_value = ('kittens', quantum.QuantumJob()) program = cg.EngineProgram( 'my-meow', 'my-meow', @@ -216,7 +216,7 @@ def test_run_batch_no_sweeps(create_job): job = program.run_batch(job_id='steve', repetitions=10, processor_ids=['lazykitty']) assert job._job == quantum.QuantumJob() batch_run_context = v2.batch_pb2.BatchRunContext() - create_job.call_args[1]['run_context'].Unpack(batch_run_context) + create_job_async.call_args[1]['run_context'].Unpack(batch_run_context) assert len(batch_run_context.run_contexts) == 1 @@ -242,11 +242,11 @@ def test_run_in_batch_mode(): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results') -@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job') -def test_run_delegation(create_job, get_results): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') +@mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') +def test_run_delegation(create_job_async, get_results_async): dt = datetime.datetime.now(tz=datetime.timezone.utc) - create_job.return_value = ( + create_job_async.return_value = ( 'steve', quantum.QuantumJob( name='projects/a/programs/b/jobs/steve', @@ -254,7 +254,7 @@ def test_run_delegation(create_job, get_results): update_time=dt, ), ) - get_results.return_value = quantum.QuantumResult( + get_results_async.return_value = quantum.QuantumResult( result=util.pack_any( Merge( """sweep_results: [{ @@ -297,15 +297,15 @@ def test_run_delegation(create_job, get_results): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs') -def test_list_jobs(list_jobs): +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs_async') +def test_list_jobs(list_jobs_async): job1 = quantum.QuantumJob(name='projects/proj/programs/prog1/jobs/job1') job2 = quantum.QuantumJob(name='projects/otherproj/programs/prog1/jobs/job2') - list_jobs.return_value = [job1, job2] + list_jobs_async.return_value = [job1, job2] ctx = EngineContext() result = cg.EngineProgram(project_id='proj', program_id='prog1', context=ctx).list_jobs() - list_jobs.assert_called_once_with( + list_jobs_async.assert_called_once_with( 'proj', 'prog1', created_after=None, @@ -341,40 +341,40 @@ def test_create_time(): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_update_time(get_program): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_update_time(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram( + get_program_async.return_value = quantum.QuantumProgram( update_time=timestamp_pb2.Timestamp(seconds=1581515101) ) assert program.update_time() == datetime.datetime( 2020, 2, 12, 13, 45, 1, tzinfo=datetime.timezone.utc ) - get_program.assert_called_once_with('a', 'b', False) + get_program_async.assert_called_once_with('a', 'b', False) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_description(get_program): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_description(get_program_async): program = cg.EngineProgram( 'a', 'b', EngineContext(), _program=quantum.QuantumProgram(description='hello') ) assert program.description() == 'hello' - get_program.return_value = quantum.QuantumProgram(description='hello') + get_program_async.return_value = quantum.QuantumProgram(description='hello') assert cg.EngineProgram('a', 'b', EngineContext()).description() == 'hello' - get_program.assert_called_once_with('a', 'b', False) + get_program_async.assert_called_once_with('a', 'b', False) -@mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_description') -def test_set_description(set_program_description): +@mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_description_async') +def test_set_description(set_program_description_async): program = cg.EngineProgram('a', 'b', EngineContext()) - set_program_description.return_value = quantum.QuantumProgram(description='world') + set_program_description_async.return_value = quantum.QuantumProgram(description='world') assert program.set_description('world').description() == 'world' - set_program_description.assert_called_with('a', 'b', 'world') + set_program_description_async.assert_called_with('a', 'b', 'world') - set_program_description.return_value = quantum.QuantumProgram(description='') + set_program_description_async.return_value = quantum.QuantumProgram(description='') assert program.set_description('').description() == '' - set_program_description.assert_called_with('a', 'b', '') + set_program_description_async.assert_called_with('a', 'b', '') def test_labels(): @@ -384,92 +384,94 @@ def test_labels(): assert program.labels() == {'t': '1'} -@mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_labels') -def test_set_labels(set_program_labels): +@mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_labels_async') +def test_set_labels(set_program_labels_async): program = cg.EngineProgram('a', 'b', EngineContext()) - set_program_labels.return_value = quantum.QuantumProgram(labels={'a': '1', 'b': '1'}) + set_program_labels_async.return_value = quantum.QuantumProgram(labels={'a': '1', 'b': '1'}) assert program.set_labels({'a': '1', 'b': '1'}).labels() == {'a': '1', 'b': '1'} - set_program_labels.assert_called_with('a', 'b', {'a': '1', 'b': '1'}) + set_program_labels_async.assert_called_with('a', 'b', {'a': '1', 'b': '1'}) - set_program_labels.return_value = quantum.QuantumProgram() + set_program_labels_async.return_value = quantum.QuantumProgram() assert program.set_labels({}).labels() == {} - set_program_labels.assert_called_with('a', 'b', {}) + set_program_labels_async.assert_called_with('a', 'b', {}) -@mock.patch('cirq_google.engine.engine_client.EngineClient.add_program_labels') -def test_add_labels(add_program_labels): +@mock.patch('cirq_google.engine.engine_client.EngineClient.add_program_labels_async') +def test_add_labels(add_program_labels_async): program = cg.EngineProgram( 'a', 'b', EngineContext(), _program=quantum.QuantumProgram(labels={}) ) assert program.labels() == {} - add_program_labels.return_value = quantum.QuantumProgram(labels={'a': '1'}) + add_program_labels_async.return_value = quantum.QuantumProgram(labels={'a': '1'}) assert program.add_labels({'a': '1'}).labels() == {'a': '1'} - add_program_labels.assert_called_with('a', 'b', {'a': '1'}) + add_program_labels_async.assert_called_with('a', 'b', {'a': '1'}) - add_program_labels.return_value = quantum.QuantumProgram(labels={'a': '2', 'b': '1'}) + add_program_labels_async.return_value = quantum.QuantumProgram(labels={'a': '2', 'b': '1'}) assert program.add_labels({'a': '2', 'b': '1'}).labels() == {'a': '2', 'b': '1'} - add_program_labels.assert_called_with('a', 'b', {'a': '2', 'b': '1'}) + add_program_labels_async.assert_called_with('a', 'b', {'a': '2', 'b': '1'}) -@mock.patch('cirq_google.engine.engine_client.EngineClient.remove_program_labels') -def test_remove_labels(remove_program_labels): +@mock.patch('cirq_google.engine.engine_client.EngineClient.remove_program_labels_async') +def test_remove_labels(remove_program_labels_async): program = cg.EngineProgram( 'a', 'b', EngineContext(), _program=quantum.QuantumProgram(labels={'a': '1', 'b': '1'}) ) assert program.labels() == {'a': '1', 'b': '1'} - remove_program_labels.return_value = quantum.QuantumProgram(labels={'b': '1'}) + remove_program_labels_async.return_value = quantum.QuantumProgram(labels={'b': '1'}) assert program.remove_labels(['a']).labels() == {'b': '1'} - remove_program_labels.assert_called_with('a', 'b', ['a']) + remove_program_labels_async.assert_called_with('a', 'b', ['a']) - remove_program_labels.return_value = quantum.QuantumProgram(labels={}) + remove_program_labels_async.return_value = quantum.QuantumProgram(labels={}) assert program.remove_labels(['a', 'b', 'c']).labels() == {} - remove_program_labels.assert_called_with('a', 'b', ['a', 'b', 'c']) + remove_program_labels_async.assert_called_with('a', 'b', ['a', 'b', 'c']) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_v1(get_program): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_v1(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram(code=util.pack_any(v1.program_pb2.Program())) + get_program_async.return_value = quantum.QuantumProgram( + code=util.pack_any(v1.program_pb2.Program()) + ) with pytest.raises(ValueError, match='v1 Program is not supported'): program.get_circuit() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_v2(get_program): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_v2(get_program_async): circuit = cirq.Circuit( cirq.X(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result') ) program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) assert program.get_circuit() == circuit - get_program.assert_called_once_with('a', 'b', True) + get_program_async.assert_called_once_with('a', 'b', True) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_batch(get_program): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_batch(get_program_async): circuit = cirq.Circuit( cirq.X(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result') ) program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) with pytest.raises(ValueError, match='A program number must be specified'): program.get_circuit() with pytest.raises(ValueError, match='Only 1 in the batch but index 1 was specified'): program.get_circuit(1) assert program.get_circuit(0) == circuit - get_program.assert_called_once_with('a', 'b', True) + get_program_async.assert_called_once_with('a', 'b', True) -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_batch_size(get_program): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_batch_size(get_program_async): # Has to fetch from engine if not _program specified. program = cg.EngineProgram('a', 'b', EngineContext(), result_type=ResultType.Batch) - get_program.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) + get_program_async.return_value = quantum.QuantumProgram(code=_BATCH_PROGRAM_V2) assert program.batch_size() == 1 # If _program specified, uses that value. @@ -487,7 +489,7 @@ def test_get_batch_size(get_program): _ = program.batch_size() with pytest.raises(ValueError, match='cirq.google.api.v2.Program'): - get_program.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) + get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2) program = cg.EngineProgram('a', 'b', EngineContext(), result_type=ResultType.Batch) _ = program.batch_size() @@ -500,10 +502,10 @@ def mock_grpc_client(): yield _fixture -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_v2_unknown_gateset(get_program): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_v2_unknown_gateset(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram( + get_program_async.return_value = quantum.QuantumProgram( code=util.pack_any( v2.program_pb2.Program(language=v2.program_pb2.Language(gate_set="BAD_GATESET")) ) @@ -513,10 +515,10 @@ def test_get_circuit_v2_unknown_gateset(get_program): program.get_circuit() -@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program') -def test_get_circuit_unsupported_program_type(get_program): +@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') +def test_get_circuit_unsupported_program_type(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) - get_program.return_value = quantum.QuantumProgram( + get_program_async.return_value = quantum.QuantumProgram( code=any_pb2.Any(type_url='type.googleapis.com/unknown.proto') ) @@ -524,21 +526,21 @@ def test_get_circuit_unsupported_program_type(get_program): program.get_circuit() -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_program') -def test_delete(delete_program): +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_program_async') +def test_delete(delete_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) program.delete() - delete_program.assert_called_with('a', 'b', delete_jobs=False) + delete_program_async.assert_called_with('a', 'b', delete_jobs=False) program.delete(delete_jobs=True) - delete_program.assert_called_with('a', 'b', delete_jobs=True) + delete_program_async.assert_called_with('a', 'b', delete_jobs=True) -@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job') -def test_delete_jobs(delete_job): +@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job_async') +def test_delete_jobs(delete_job_async): program = cg.EngineProgram('a', 'b', EngineContext()) program.delete_job('c') - delete_job.assert_called_with('a', 'b', 'c') + delete_job_async.assert_called_with('a', 'b', 'c') def test_str(): diff --git a/cirq-google/cirq_google/engine/engine_sampler.py b/cirq-google/cirq_google/engine/engine_sampler.py index 4af074b5923..aff1ce24bc7 100644 --- a/cirq-google/cirq_google/engine/engine_sampler.py +++ b/cirq-google/cirq_google/engine/engine_sampler.py @@ -14,6 +14,8 @@ from typing import List, Optional, Sequence, TYPE_CHECKING, Union +import duet + import cirq from cirq_google import engine from cirq_google.engine import util @@ -48,26 +50,28 @@ def __init__( self._processor_ids = [processor_id] if isinstance(processor_id, str) else processor_id self._engine = engine - def run_sweep( + async def run_sweep_async( self, program: Union[cirq.AbstractCircuit, 'cirq_google.EngineProgram'], params: cirq.Sweepable, repetitions: int = 1, ) -> Sequence[cirq.Result]: if isinstance(program, engine.EngineProgram): - job = program.run_sweep( + job = await program.run_sweep_async( params=params, repetitions=repetitions, processor_ids=self._processor_ids ) else: - job = self._engine.run_sweep( + job = await self._engine.run_sweep_async( program=program, params=params, repetitions=repetitions, processor_ids=self._processor_ids, ) - return job.results() + return await job.results_async() + + run_sweep = duet.sync(run_sweep_async) - def run_batch( + async def run_batch_async( self, programs: Sequence[cirq.AbstractCircuit], params_list: Optional[List[cirq.Sweepable]] = None, @@ -91,16 +95,18 @@ def run_batch( # All repetitions are the same so batching can be done efficiently if isinstance(repetitions, List): repetitions = repetitions[0] - job = self._engine.run_batch( + job = await self._engine.run_batch_async( programs=programs, params_list=params_list, repetitions=repetitions, processor_ids=self._processor_ids, ) - return job.batched_results() + return await job.batched_results_async() # Varying number of repetitions so no speedup return super().run_batch(programs, params_list, repetitions) + run_batch = duet.sync(run_batch_async) + @property def engine(self) -> 'cirq_google.Engine': return self._engine diff --git a/cirq-google/cirq_google/engine/engine_sampler_test.py b/cirq-google/cirq_google/engine/engine_sampler_test.py index bb194e94dd0..94e4e27c784 100644 --- a/cirq-google/cirq_google/engine/engine_sampler_test.py +++ b/cirq-google/cirq_google/engine/engine_sampler_test.py @@ -24,26 +24,30 @@ @pytest.mark.parametrize('circuit', [cirq.Circuit(), cirq.FrozenCircuit()]) def test_run_circuit(circuit): engine = mock.Mock() + engine.run_sweep_async = mock.AsyncMock() sampler = cg.QuantumEngineSampler(engine=engine, processor_id='tmp') params = [cirq.ParamResolver({'a': 1})] sampler.run_sweep(circuit, params, 5) - engine.run_sweep.assert_called_with( + engine.run_sweep_async.assert_called_with( params=params, processor_ids=['tmp'], program=circuit, repetitions=5 ) def test_run_engine_program(): engine = mock.Mock() + engine.run_sweep_async = mock.AsyncMock() sampler = cg.QuantumEngineSampler(engine=engine, processor_id='tmp') program = mock.Mock(spec=cg.EngineProgram) + program.run_sweep_async = mock.AsyncMock() params = [cirq.ParamResolver({'a': 1})] sampler.run_sweep(program, params, 5) - program.run_sweep.assert_called_with(params=params, processor_ids=['tmp'], repetitions=5) - engine.run_sweep.assert_not_called() + program.run_sweep_async.assert_called_with(params=params, processor_ids=['tmp'], repetitions=5) + engine.run_sweep_async.assert_not_called() def test_run_batch(): engine = mock.Mock() + engine.run_batch_async = mock.AsyncMock() sampler = cg.QuantumEngineSampler(engine=engine, processor_id='tmp') a = cirq.LineQubit(0) circuit1 = cirq.Circuit(cirq.X(a)) @@ -53,13 +57,14 @@ def test_run_batch(): circuits = [circuit1, circuit2] params_list = [params1, params2] sampler.run_batch(circuits, params_list, 5) - engine.run_batch.assert_called_with( + engine.run_batch_async.assert_called_with( params_list=params_list, processor_ids=['tmp'], programs=circuits, repetitions=5 ) def test_run_batch_identical_repetitions(): engine = mock.Mock() + engine.run_batch_async = mock.AsyncMock() sampler = cg.QuantumEngineSampler(engine=engine, processor_id='tmp') a = cirq.LineQubit(0) circuit1 = cirq.Circuit(cirq.X(a)) @@ -69,7 +74,7 @@ def test_run_batch_identical_repetitions(): circuits = [circuit1, circuit2] params_list = [params1, params2] sampler.run_batch(circuits, params_list, [5, 5]) - engine.run_batch.assert_called_with( + engine.run_batch_async.assert_called_with( params_list=params_list, processor_ids=['tmp'], programs=circuits, repetitions=5 ) @@ -91,8 +96,10 @@ def test_run_batch_bad_number_of_repetitions(): def test_run_batch_differing_repetitions(): engine = mock.Mock() job = mock.Mock() - job.results.return_value = [] - engine.run_sweep.return_value = job + job.results_async = mock.AsyncMock() + job.results_async.return_value = [] + engine.run_sweep_async = mock.AsyncMock() + engine.run_sweep_async.return_value = job sampler = cg.QuantumEngineSampler(engine=engine, processor_id='tmp') a = cirq.LineQubit(0) circuit1 = cirq.Circuit(cirq.X(a)) @@ -103,10 +110,10 @@ def test_run_batch_differing_repetitions(): params_list = [params1, params2] repetitions = [1, 2] sampler.run_batch(circuits, params_list, repetitions) - engine.run_sweep.assert_called_with( + engine.run_sweep_async.assert_called_with( params=params2, processor_ids=['tmp'], program=circuit2, repetitions=2 ) - engine.run_batch.assert_not_called() + engine.run_batch_async.assert_not_called() def test_engine_sampler_engine_property(): @@ -123,6 +130,6 @@ def test_get_engine_sampler_explicit_project_id(): def test_get_engine_sampler(): with mock.patch.object(cirq_google.cloud.quantum, 'QuantumEngineServiceClient', autospec=True): - with mock.patch('google.auth.default', lambda: (None, 'myproj')): + with mock.patch('google.auth.default', lambda *args, **kwargs: (None, 'myproj')): sampler = cg.get_engine_sampler(processor_id='hi mom') assert hasattr(sampler, 'run_sweep') diff --git a/cirq-google/cirq_google/engine/engine_test.py b/cirq-google/cirq_google/engine/engine_test.py index 6ca9c203955..8dc3d6bf021 100644 --- a/cirq-google/cirq_google/engine/engine_test.py +++ b/cirq-google/cirq_google/engine/engine_test.py @@ -276,9 +276,9 @@ def test_make_random_id(): @pytest.fixture(scope='session', autouse=True) -def mock_grpc_client(): +def mock_grpc_client_async(): with mock.patch( - 'cirq_google.engine.engine_client.quantum.QuantumEngineServiceClient' + 'cirq_google.engine.engine_client.quantum.QuantumEngineServiceAsyncClient', autospec=True ) as _fixture: yield _fixture @@ -299,7 +299,7 @@ def test_create_context(client): assert context.copy() == context -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_create_engine(client): with pytest.raises( ValueError, match='provide context or proto_version, service_args and verbose' @@ -321,40 +321,37 @@ def test_create_engine(client): ).context.proto_version == cg.engine.engine.ProtoVersion.V2 ) - assert client.called_with({'args': 'test'}, True) + client.assert_called_with({'args': 'test'}, True) def test_engine_str(): engine = cg.Engine( - 'proj', - proto_version=cg.engine.engine.ProtoVersion.V2, - service_args={'args': 'test'}, - verbose=True, + 'proj', proto_version=cg.engine.engine.ProtoVersion.V2, service_args={}, verbose=True ) - assert str(engine) == 'Engine(project_id=\'proj\')' + assert str(engine) == "Engine(project_id='proj')" _DT = datetime.datetime.now(tz=datetime.timezone.utc) def setup_run_circuit_with_result_(client, result): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( execution_status={'state': 'SUCCESS'}, update_time=_DT ) - client().get_job_results.return_value = quantum.QuantumResult(result=result) + client().get_job_results_async.return_value = quantum.QuantumResult(result=result) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit(client): setup_run_circuit_with_result_(client, _A_RESULT) @@ -367,27 +364,22 @@ def test_run_circuit(client): assert result.params.param_dict == {'a': 1} assert result.measurements == {'q': np.array([[0]], dtype='uint8')} client.assert_called_with(service_args={'client_info': 1}, verbose=None) - client.create_program.called_once_with() - client.create_job.called_once_with( - 'projects/project-id/programs/test', - quantum.QuantumJob( - name='projects/project-id/programs/test/jobs/job-id', - scheduling_config={ - 'priority': 50, - 'processor_selector': {'processor_names': ['projects/project-id/processors/mysim']}, - }, - run_context=util.pack_any( - v2.run_context_pb2.RunContext( - parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=1)] - ) - ), - update_time=_DT, + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once_with( + project_id='proj', + program_id='prog', + job_id='job-id', + processor_ids=['mysim'], + run_context=util.pack_any( + v2.run_context_pb2.RunContext( + parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=1)] + ) ), - False, + description=None, + labels=None, ) - - client.get_job.called_once_with('proj', 'prog') - client.get_job_result.called_once_with() + client().get_job_async.assert_called_once_with('proj', 'prog', 'job-id', False) + client().get_job_results_async.assert_called_once_with('proj', 'prog', 'job-id') def test_no_gate_set(): @@ -401,19 +393,19 @@ def test_unsupported_program_type(): engine.run(program="this isn't even the right type of thing!") -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_failed(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={ 'state': 'FAILURE', @@ -431,19 +423,19 @@ def test_run_circuit_failed(client): engine.run(program=_CIRCUIT) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_failed_missing_processor_name(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={ 'state': 'FAILURE', @@ -460,19 +452,19 @@ def test_run_circuit_failed_missing_processor_name(client): engine.run(program=_CIRCUIT) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_cancelled(client): - client().create_program.return_value = ( + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'CANCELLED'} ) @@ -483,29 +475,28 @@ def test_run_circuit_cancelled(client): engine.run(program=_CIRCUIT) -@mock.patch('cirq_google.engine.engine_client.EngineClient') -@mock.patch('time.sleep', return_value=None) -def test_run_circuit_timeout(patched_time_sleep, client): - client().create_program.return_value = ( +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) +def test_run_circuit_timeout(client): + client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), ) - client().create_job.return_value = ( + client().create_job_async.return_value = ( 'job-id', quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} ), ) - client().get_job.return_value = quantum.QuantumJob( + client().get_job_async.return_value = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'RUNNING'} ) - engine = cg.Engine(project_id='project-id', timeout=600) - with pytest.raises(RuntimeError, match='Timed out'): + engine = cg.Engine(project_id='project-id', timeout=1) + with pytest.raises(TimeoutError): engine.run(program=_CIRCUIT) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_params(client): setup_run_circuit_with_result_(client, _RESULTS) @@ -520,21 +511,21 @@ def test_run_sweep_params(client): assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.run_context_pb2.RunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 2 for i, v in enumerate([1.0, 2.0]): assert sweeps[i].repetitions == 1 assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_multiple_times(client): setup_run_circuit_with_result_(client, _RESULTS) @@ -542,10 +533,10 @@ def test_run_multiple_times(client): program = engine.create_program(program=_CIRCUIT) program.run(param_resolver=cirq.ParamResolver({'a': 1})) run_context = v2.run_context_pb2.RunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps1 = run_context.parameter_sweeps job2 = program.run_sweep(repetitions=2, params=cirq.Points('a', [3, 4])) - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps2 = run_context.parameter_sweeps results = job2.results() assert engine.context.proto_version == cg.engine.engine.ProtoVersion.V2 @@ -561,11 +552,11 @@ def test_run_multiple_times(client): assert len(sweeps2) == 1 assert sweeps2[0].repetitions == 2 assert sweeps2[0].sweep.single_sweep.points.points == [3, 4] - assert client().get_job.call_count == 2 - assert client().get_job_results.call_count == 2 + assert client().get_job_async.call_count == 2 + assert client().get_job_results_async.call_count == 2 -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_v2(client): setup_run_circuit_with_result_(client, _RESULTS_V2) @@ -577,19 +568,19 @@ def test_run_sweep_v2(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.run_context_pb2.RunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 1 assert sweeps[0].repetitions == 1 assert sweeps[0].sweep.single_sweep.points.points == [1, 2] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch(client): setup_run_circuit_with_result_(client, _BATCH_RESULTS_V2) @@ -606,10 +597,10 @@ def test_run_batch(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - client().create_program.assert_called_once() - client().create_job.assert_called_once() + client().create_program_async.assert_called_once() + client().create_job_async.assert_called_once() run_context = v2.batch_pb2.BatchRunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) assert len(run_context.run_contexts) == 2 for idx, rc in enumerate(run_context.run_contexts): sweeps = rc.parameter_sweeps @@ -619,11 +610,11 @@ def test_run_batch(client): assert sweeps[0].sweep.single_sweep.points.points == [1.0, 2.0] if idx == 1: assert sweeps[0].sweep.single_sweep.points.points == [3.0, 4.0] - client().get_job.assert_called_once() - client().get_job_results.assert_called_once() + client().get_job_async.assert_called_once() + client().get_job_results_async.assert_called_once() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch_no_params(client): # OK to run with no params, it should use empty sweeps for each # circuit. @@ -633,7 +624,7 @@ def test_run_batch_no_params(client): # Validate correct number of params have been created and that they # are empty sweeps. run_context = v2.batch_pb2.BatchRunContext() - client().create_job.call_args[1]['run_context'].Unpack(run_context) + client().create_job_async.call_args[1]['run_context'].Unpack(run_context) assert len(run_context.run_contexts) == 2 for rc in run_context.run_contexts: sweeps = rc.parameter_sweeps @@ -672,7 +663,7 @@ def test_bad_sweep_proto(): program.run_sweep() -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_calibration(client): setup_run_circuit_with_result_(client, _CALIBRATION_RESULTS_V2) @@ -696,7 +687,7 @@ def test_run_calibration(client): assert results[1].error_message == 'Second success' # assert label is correct - client().create_job.assert_called_once_with( + client().create_job_async.assert_called_once_with( project_id='proj', program_id='prog', job_id='job-id', @@ -725,7 +716,7 @@ def test_run_calibration_validation_fails(): ) -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_bad_result_proto(client): result = any_pb2.Any() result.CopyFrom(_RESULTS_V2) @@ -752,14 +743,14 @@ def test_get_program(): assert cg.Engine(project_id='proj').get_program('prog').program_id == 'prog' -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_programs') -def test_list_programs(list_programs): +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_programs_async') +def test_list_programs(list_programs_async): prog1 = quantum.QuantumProgram(name='projects/proj/programs/prog-YBGR48THF3JHERZW200804') prog2 = quantum.QuantumProgram(name='projects/otherproj/programs/prog-V3ZRTV6TTAFNTYJV200804') - list_programs.return_value = [prog1, prog2] + list_programs_async.return_value = [prog1, prog2] result = cg.Engine(project_id='proj').list_programs() - list_programs.assert_called_once_with( + list_programs_async.assert_called_once_with( 'proj', created_after=None, created_before=None, has_labels=None ) assert [(p.program_id, p.project_id, p._program) for p in result] == [ @@ -768,23 +759,23 @@ def test_list_programs(list_programs): ] -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_create_program(client): - client().create_program.return_value = ('prog', quantum.QuantumProgram()) + client().create_program_async.return_value = ('prog', quantum.QuantumProgram()) result = cg.Engine(project_id='proj').create_program(_CIRCUIT, 'prog') - client().create_program.assert_called_once() + client().create_program_async.assert_called_once() assert result.program_id == 'prog' -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs') -def test_list_jobs(list_jobs): +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs_async') +def test_list_jobs(list_jobs_async): job1 = quantum.QuantumJob(name='projects/proj/programs/prog1/jobs/job1') job2 = quantum.QuantumJob(name='projects/proj/programs/prog2/jobs/job2') - list_jobs.return_value = [job1, job2] + list_jobs_async.return_value = [job1, job2] ctx = EngineContext() result = cg.Engine(project_id='proj', context=ctx).list_jobs() - list_jobs.assert_called_once_with( + list_jobs_async.assert_called_once_with( 'proj', None, created_after=None, @@ -798,14 +789,14 @@ def test_list_jobs(list_jobs): ] -@mock.patch('cirq_google.engine.engine_client.EngineClient.list_processors') -def test_list_processors(list_processors): +@mock.patch('cirq_google.engine.engine_client.EngineClient.list_processors_async') +def test_list_processors(list_processors_async): processor1 = quantum.QuantumProcessor(name='projects/proj/processors/xmonsim') processor2 = quantum.QuantumProcessor(name='projects/proj/processors/gmonsim') - list_processors.return_value = [processor1, processor2] + list_processors_async.return_value = [processor1, processor2] result = cg.Engine(project_id='proj').list_processors() - list_processors.assert_called_once_with('proj') + list_processors_async.assert_called_once_with('proj') assert [p.processor_id for p in result] == ['xmonsim', 'gmonsim'] @@ -813,7 +804,7 @@ def test_get_processor(): assert cg.Engine(project_id='proj').get_processor('xmonsim').processor_id == 'xmonsim' -@mock.patch('cirq_google.engine.engine_client.EngineClient') +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_sampler(client): setup_run_circuit_with_result_(client, _RESULTS) @@ -827,7 +818,7 @@ def test_sampler(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - assert client().create_program.call_args[0][0] == 'proj' + assert client().create_program_async.call_args[0][0] == 'proj' with cirq.testing.assert_deprecated('sampler', deadline='1.0'): _ = engine.sampler(processor_id='tmp') @@ -836,12 +827,12 @@ def test_sampler(client): @mock.patch('cirq_google.cloud.quantum.QuantumEngineServiceClient') def test_get_engine(build): # Default project id present. - with mock.patch('google.auth.default', lambda: (None, 'project!')): + with mock.patch('google.auth.default', lambda *args, **kwargs: (None, 'project!')): eng = cirq_google.get_engine() assert eng.project_id == 'project!' # Nothing present. - with mock.patch('google.auth.default', lambda: (None, None)): + with mock.patch('google.auth.default', lambda *args, **kwargs: (None, None)): with pytest.raises(EnvironmentError, match='GOOGLE_CLOUD_PROJECT'): _ = cirq_google.get_engine() _ = cirq_google.get_engine('project!') From 4c36675a5e17d900ecb111ddc7aeff082237cd40 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Fri, 22 Apr 2022 09:26:44 -0700 Subject: [PATCH 2/9] wip - xfail tests that use AsyncMock in python < 3.8 --- .../cirq_google/engine/engine_client_test.py | 42 +++++++++++++++++++ .../cirq_google/engine/engine_job_test.py | 24 +++++++++++ .../engine/engine_processor_test.py | 26 ++++++++++++ .../cirq_google/engine/engine_program_test.py | 21 ++++++++++ .../cirq_google/engine/engine_sampler_test.py | 6 +++ cirq-google/cirq_google/engine/engine_test.py | 18 ++++++++ cirq-google/cirq_google/engine/util_test.py | 22 ++++++++++ 7 files changed, 159 insertions(+) create mode 100644 cirq-google/cirq_google/engine/util_test.py diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index 4d3878b5ffb..03432e1031b 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -24,6 +24,7 @@ import duet from cirq_google.engine.engine_client import EngineClient, EngineException +from cirq_google.engine.util_test import uses_async_mock from cirq_google.cloud import quantum @@ -33,6 +34,7 @@ def setup_mock_(client_constructor): return grpc_client +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_create_program(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -99,6 +101,7 @@ def test_create_program(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_program(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -118,6 +121,7 @@ def test_get_program(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_program(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -135,6 +139,7 @@ def test_list_program(client_constructor): ) +@uses_async_mock @pytest.mark.parametrize( 'expected_filter, created_after, created_before, labels', [ @@ -179,12 +184,14 @@ def test_list_program_filters( assert grpc_client.list_quantum_programs.call_args[0][0].filter == expected_filter +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_program_filters_invalid_type(client_constructor): with pytest.raises(ValueError, match=""): EngineClient().list_programs(project_id='proj', created_before="Unsupported date/time") +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_program_description(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -214,6 +221,7 @@ def test_set_program_description(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_program_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -249,6 +257,7 @@ def test_set_program_labels(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_add_program_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -291,6 +300,7 @@ def test_add_program_labels(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_remove_program_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -331,6 +341,7 @@ def test_remove_program_labels(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_delete_program(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -347,6 +358,7 @@ def test_delete_program(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_create_job(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -478,6 +490,7 @@ def test_create_job(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_job(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -501,6 +514,7 @@ def test_get_job(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_job_description(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -530,6 +544,7 @@ def test_set_job_description(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_job_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -567,6 +582,7 @@ def test_set_job_labels(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_add_job_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -611,6 +627,7 @@ def test_add_job_labels(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_remove_job_labels(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -651,6 +668,7 @@ def test_remove_job_labels(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_delete_job(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -662,6 +680,7 @@ def test_delete_job(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_cancel_job(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -673,6 +692,7 @@ def test_cancel_job(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_job_results(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -687,6 +707,7 @@ def test_job_results(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_jobs(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -709,6 +730,7 @@ def test_list_jobs(client_constructor): ) +@uses_async_mock @pytest.mark.parametrize( 'expected_filter, ' 'created_after, ' @@ -825,6 +847,7 @@ def __aiter__(self): return duet.aiter(self.items) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_processors(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -842,6 +865,7 @@ def test_list_processors(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_processor(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -856,6 +880,7 @@ def test_get_processor(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_calibrations(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -873,6 +898,7 @@ def test_list_calibrations(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_calibration(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -891,6 +917,7 @@ def test_get_calibration(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_current_calibration(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -909,6 +936,7 @@ def test_get_current_calibration(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_current_calibration_does_not_exist(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -924,6 +952,7 @@ def test_get_current_calibration_does_not_exist(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_current_calibration_error(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -935,6 +964,7 @@ def test_get_current_calibration_error(client_constructor): client.get_current_calibration('proj', 'processor0') +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_api_doesnt_retry_not_found_errors(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -946,6 +976,7 @@ def test_api_doesnt_retry_not_found_errors(client_constructor): assert grpc_client.get_quantum_program.call_count == 1 +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_api_retry_5xx_errors(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -957,6 +988,7 @@ def test_api_retry_5xx_errors(client_constructor): assert grpc_client.get_quantum_program.call_count == 3 +@uses_async_mock @mock.patch('duet.sleep', return_value=duet.completed_future(None)) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_api_retry_times(client_constructor, mock_sleep): @@ -972,6 +1004,7 @@ def test_api_retry_times(client_constructor, mock_sleep): assert all(x == y for (x, _), y in zip(mock_sleep.call_args_list, [(0.1,), (0.2,)])) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_create_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -998,6 +1031,7 @@ def test_create_reservation(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_cancel_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -1017,6 +1051,7 @@ def test_cancel_reservation(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_delete_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -1036,6 +1071,7 @@ def test_delete_reservation(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -1055,6 +1091,7 @@ def test_get_reservation(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_reservation_not_found(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -1068,6 +1105,7 @@ def test_get_reservation_not_found(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_reservation_exception(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -1078,6 +1116,7 @@ def test_get_reservation_exception(client_constructor): client.get_reservation('proj', 'processor0', 'goog') +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -1102,6 +1141,7 @@ def test_list_reservation(client_constructor): assert client.list_reservations('proj', 'processor0') == results +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_update_reservation(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -1135,6 +1175,7 @@ def test_update_reservation(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_update_reservation_remove_all_users(client_constructor): grpc_client = setup_mock_(client_constructor) @@ -1156,6 +1197,7 @@ def test_update_reservation_remove_all_users(client_constructor): ) +@uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_time_slots(client_constructor): grpc_client = setup_mock_(client_constructor) diff --git a/cirq-google/cirq_google/engine/engine_job_test.py b/cirq-google/cirq_google/engine/engine_job_test.py index 4694e835001..e0e1d768d3c 100644 --- a/cirq-google/cirq_google/engine/engine_job_test.py +++ b/cirq-google/cirq_google/engine/engine_job_test.py @@ -25,6 +25,7 @@ from cirq_google.engine import util from cirq_google.cloud import quantum from cirq_google.engine.engine import EngineContext +from cirq_google.engine.util_test import uses_async_mock @pytest.fixture(scope='session', autouse=True) @@ -70,6 +71,7 @@ def test_create_time(): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_update_time(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) @@ -82,6 +84,7 @@ def test_update_time(get_job): get_job.assert_called_once_with('a', 'b', 'steve', False) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_description(get_job): job = cg.EngineJob( @@ -93,6 +96,7 @@ def test_description(get_job): get_job.assert_called_once_with('a', 'b', 'steve', False) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_description_async') def test_set_description(set_job_description): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) @@ -112,6 +116,7 @@ def test_labels(): assert job.labels() == {'t': '1'} +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_labels_async') def test_set_labels(set_job_labels): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) @@ -124,6 +129,7 @@ def test_set_labels(set_job_labels): set_job_labels.assert_called_with('a', 'b', 'steve', {}) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.add_job_labels_async') def test_add_labels(add_job_labels): job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=quantum.QuantumJob(labels={})) @@ -138,6 +144,7 @@ def test_add_labels(add_job_labels): add_job_labels.assert_called_with('a', 'b', 'steve', {'a': '2', 'b': '1'}) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.remove_job_labels_async') def test_remove_labels(remove_job_labels): job = cg.EngineJob( @@ -171,6 +178,7 @@ def test_processor_ids(): assert job.processor_ids() == ['p'] +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_status(get_job): qjob = quantum.QuantumJob( @@ -216,6 +224,7 @@ def test_failure_with_no_error(): assert not job.failure() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_get_repetitions_and_sweeps(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) @@ -230,6 +239,7 @@ def test_get_repetitions_and_sweeps(get_job): get_job.assert_called_once_with('a', 'b', 'steve', True) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_get_repetitions_and_sweeps_v1(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) @@ -244,6 +254,7 @@ def test_get_repetitions_and_sweeps_v1(get_job): job.get_repetitions_and_sweeps() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_get_repetitions_and_sweeps_unsupported(get_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) @@ -312,6 +323,7 @@ def test_get_calibration(get_calibration): get_calibration.assert_called_once_with('a', 'p', 123) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration_async') def test_calibration__with_no_calibration(get_calibration): job = cg.EngineJob( @@ -329,6 +341,7 @@ def test_calibration__with_no_calibration(get_calibration): assert not get_calibration.called +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_job_async') def test_cancel(cancel_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) @@ -336,6 +349,7 @@ def test_cancel(cancel_job): cancel_job.assert_called_once_with('a', 'b', 'steve') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job_async') def test_delete(delete_job): job = cg.EngineJob('a', 'b', 'steve', EngineContext()) @@ -504,6 +518,7 @@ def test_delete(delete_job): UPDATE_TIME = datetime.datetime.now(tz=datetime.timezone.utc) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results(get_job_results): qjob = quantum.QuantumJob( @@ -520,6 +535,7 @@ def test_results(get_job_results): get_job_results.assert_called_once_with('a', 'b', 'steve') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results_iter(get_job_results): qjob = quantum.QuantumJob( @@ -535,6 +551,7 @@ def test_results_iter(get_job_results): assert results[1] == 'q=1010' +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results_getitem(get_job_results): qjob = quantum.QuantumJob( @@ -550,6 +567,7 @@ def test_results_getitem(get_job_results): _ = job[2] +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_batched_results(get_job_results): qjob = quantum.QuantumJob( @@ -577,6 +595,7 @@ def test_batched_results(get_job_results): assert str(data[1][1]) == 'q=1001' +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_batched_results_not_a_batch(get_job_results): qjob = quantum.QuantumJob( @@ -589,6 +608,7 @@ def test_batched_results_not_a_batch(get_job_results): job.batched_results() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_calibration_results(get_job_results): qjob = quantum.QuantumJob( @@ -608,6 +628,7 @@ def test_calibration_results(get_job_results): assert data[0].metrics['theta'] == {(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)): [0.9999]} +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_calibration_defaults(get_job_results): qjob = quantum.QuantumJob( @@ -628,6 +649,7 @@ def test_calibration_defaults(get_job_results): assert len(data[0].metrics) == 0 +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_calibration_results_not_a_calibration(get_job_results): qjob = quantum.QuantumJob( @@ -640,6 +662,7 @@ def test_calibration_results_not_a_calibration(get_job_results): job.calibration_results() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_results_len(get_job_results): qjob = quantum.QuantumJob( @@ -652,6 +675,7 @@ def test_results_len(get_job_results): assert len(job) == 2 +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async') def test_timeout(get_job): qjob = quantum.QuantumJob( diff --git a/cirq-google/cirq_google/engine/engine_processor_test.py b/cirq-google/cirq_google/engine/engine_processor_test.py index 3b058857877..2d73a8d9a7a 100644 --- a/cirq-google/cirq_google/engine/engine_processor_test.py +++ b/cirq-google/cirq_google/engine/engine_processor_test.py @@ -28,6 +28,7 @@ from cirq_google.api import v2 from cirq_google.engine import util from cirq_google.engine.engine import EngineContext +from cirq_google.engine.util_test import uses_async_mock from cirq_google.cloud import quantum @@ -232,6 +233,7 @@ def test_engine_repr(): assert 'the-processor-id' in repr(processor) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') def test_health(get_processor): get_processor.return_value = quantum.QuantumProcessor(health=quantum.QuantumProcessor.Health.OK) @@ -244,6 +246,7 @@ def test_health(get_processor): assert processor.health() == 'OK' +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') def test_expected_down_time(get_processor): processor = cg.EngineProcessor('a', 'p', EngineContext(), _processor=quantum.QuantumProcessor()) @@ -352,6 +355,7 @@ def test_get_missing_device(): _ = processor.get_device(gate_sets=[_GATE_SET]) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations_async') def test_list_calibrations(list_calibrations): list_calibrations.return_value = [_CALIBRATION] @@ -391,6 +395,7 @@ def test_list_calibrations(list_calibrations): list_calibrations.assert_called_with('a', 'p', f'timestamp >= {today_midnight_timestamp}') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.list_calibrations_async') def test_list_calibrations_old_params(list_calibrations): # Disable pylint warnings for use of deprecated parameters @@ -409,6 +414,7 @@ def test_list_calibrations_old_params(list_calibrations): list_calibrations.assert_called_with('a', 'p', 'timestamp <= 1562600000') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration_async') def test_get_calibration(get_calibration): get_calibration.return_value = _CALIBRATION @@ -419,6 +425,7 @@ def test_get_calibration(get_calibration): get_calibration.assert_called_once_with('a', 'p', 1562544000021) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration_async') def test_current_calibration(get_current_calibration): get_current_calibration.return_value = _CALIBRATION @@ -429,6 +436,7 @@ def test_current_calibration(get_current_calibration): get_current_calibration.assert_called_once_with('a', 'p') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_current_calibration_async') def test_missing_latest_calibration(get_current_calibration): get_current_calibration.return_value = None @@ -437,6 +445,7 @@ def test_missing_latest_calibration(get_current_calibration): get_current_calibration.assert_called_once_with('a', 'p') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.create_reservation_async') def test_create_reservation(create_reservation): name = 'projects/proj/processors/p0/reservations/psherman-wallaby-way' @@ -462,6 +471,7 @@ def test_create_reservation(create_reservation): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation_async') def test_delete_reservation(delete_reservation): name = 'projects/proj/processors/p0/reservations/rid' @@ -477,6 +487,7 @@ def test_delete_reservation(delete_reservation): delete_reservation.assert_called_once_with('proj', 'p0', 'rid') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation_async') def test_cancel_reservation(cancel_reservation): name = 'projects/proj/processors/p0/reservations/rid' @@ -492,6 +503,7 @@ def test_cancel_reservation(cancel_reservation): cancel_reservation.assert_called_once_with('proj', 'p0', 'rid') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') @mock.patch('cirq_google.engine.engine_client.EngineClient.delete_reservation_async') def test_remove_reservation_delete(delete_reservation, get_reservation): @@ -515,6 +527,7 @@ def test_remove_reservation_delete(delete_reservation, get_reservation): delete_reservation.assert_called_once_with('proj', 'p0', 'rid') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') @mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_reservation_async') def test_remove_reservation_cancel(cancel_reservation, get_reservation): @@ -538,6 +551,7 @@ def test_remove_reservation_cancel(cancel_reservation, get_reservation): cancel_reservation.assert_called_once_with('proj', 'p0', 'rid') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') def test_remove_reservation_not_found(get_reservation): get_reservation.return_value = None @@ -551,6 +565,7 @@ def test_remove_reservation_not_found(get_reservation): processor.remove_reservation('rid') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') @mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') def test_remove_reservation_failures(get_reservation, get_processor): @@ -576,6 +591,7 @@ def test_remove_reservation_failures(get_reservation, get_processor): processor.remove_reservation('rid') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_reservation_async') def test_get_reservation(get_reservation): name = 'projects/proj/processors/p0/reservations/rid' @@ -591,6 +607,7 @@ def test_get_reservation(get_reservation): get_reservation.assert_called_once_with('proj', 'p0', 'rid') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.update_reservation_async') def test_update_reservation(update_reservation): name = 'projects/proj/processors/p0/reservations/rid' @@ -610,6 +627,7 @@ def test_update_reservation(update_reservation): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations_async') def test_list_reservation(list_reservations): name = 'projects/proj/processors/p0/reservations/rid' @@ -640,6 +658,7 @@ def test_list_reservation(list_reservations): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots_async') def test_get_schedule(list_time_slots): results = [ @@ -675,6 +694,7 @@ def test_get_schedule(list_time_slots): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots_async') def test_get_schedule_filter_by_time_slot(list_time_slots): results = [ @@ -735,6 +755,7 @@ def wrapper(*args, **kwargs): return wrapper +@uses_async_mock @_allow_deprecated_freezegun @freezegun.freeze_time() @mock.patch('cirq_google.engine.engine_client.EngineClient.list_time_slots_async') @@ -779,6 +800,7 @@ def test_get_schedule_time_filter_behavior(list_time_slots): list_time_slots.assert_called_with('proj', 'p0', f'start_time < {utc_ts}') +@uses_async_mock @_allow_deprecated_freezegun @freezegun.freeze_time() @mock.patch('cirq_google.engine.engine_client.EngineClient.list_reservations_async') @@ -823,6 +845,7 @@ def test_list_reservations_time_filter_behavior(list_reservations): list_reservations.assert_called_with('proj', 'p0', f'start_time < {utc_ts}') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_params(client): client().create_program_async.return_value = ( @@ -871,6 +894,7 @@ def test_run_sweep_params(client): client().get_job_results_async.assert_called_once() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch(client): client().create_program_async.return_value = ( @@ -919,6 +943,7 @@ def test_run_batch(client): client().get_job_results_async.assert_called_once() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_calibration(client): client().create_program_async.return_value = ( @@ -967,6 +992,7 @@ def test_run_calibration(client): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_sampler(client): client().create_program_async.return_value = ( diff --git a/cirq-google/cirq_google/engine/engine_program_test.py b/cirq-google/cirq_google/engine/engine_program_test.py index 9ce307529c5..1887fe8a0c3 100644 --- a/cirq-google/cirq_google/engine/engine_program_test.py +++ b/cirq-google/cirq_google/engine/engine_program_test.py @@ -27,6 +27,7 @@ from cirq_google.cloud import quantum from cirq_google.engine.engine import EngineContext from cirq_google.engine.result_type import ResultType +from cirq_google.engine.util_test import uses_async_mock _BATCH_PROGRAM_V2 = util.pack_any( @@ -163,6 +164,7 @@ ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') def test_run_sweeps_delegation(create_job_async): create_job_async.return_value = ('steve', quantum.QuantumJob()) @@ -174,6 +176,7 @@ def test_run_sweeps_delegation(create_job_async): assert job._job == quantum.QuantumJob() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') def test_run_batch_delegation(create_job_async): create_job_async.return_value = ('kittens', quantum.QuantumJob()) @@ -185,6 +188,7 @@ def test_run_batch_delegation(create_job_async): assert job._job == quantum.QuantumJob() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') def test_run_calibration_delegation(create_job_async): create_job_async.return_value = ('dogs', quantum.QuantumJob()) @@ -201,6 +205,7 @@ def test_run_calibration_no_processors(create_job_async): _ = program.run_calibration(job_id='spot') +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') def test_run_batch_no_sweeps(create_job_async): # Running with no sweeps is fine. Uses program's batch size to create @@ -242,6 +247,7 @@ def test_run_in_batch_mode(): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') @mock.patch('cirq_google.engine.engine_client.EngineClient.create_job_async') def test_run_delegation(create_job_async, get_results_async): @@ -297,6 +303,7 @@ def test_run_delegation(create_job_async, get_results_async): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs_async') def test_list_jobs(list_jobs_async): job1 = quantum.QuantumJob(name='projects/proj/programs/prog1/jobs/job1') @@ -341,6 +348,7 @@ def test_create_time(): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') def test_update_time(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) @@ -353,6 +361,7 @@ def test_update_time(get_program_async): get_program_async.assert_called_once_with('a', 'b', False) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') def test_description(get_program_async): program = cg.EngineProgram( @@ -365,6 +374,7 @@ def test_description(get_program_async): get_program_async.assert_called_once_with('a', 'b', False) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_description_async') def test_set_description(set_program_description_async): program = cg.EngineProgram('a', 'b', EngineContext()) @@ -384,6 +394,7 @@ def test_labels(): assert program.labels() == {'t': '1'} +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.set_program_labels_async') def test_set_labels(set_program_labels_async): program = cg.EngineProgram('a', 'b', EngineContext()) @@ -396,6 +407,7 @@ def test_set_labels(set_program_labels_async): set_program_labels_async.assert_called_with('a', 'b', {}) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.add_program_labels_async') def test_add_labels(add_program_labels_async): program = cg.EngineProgram( @@ -412,6 +424,7 @@ def test_add_labels(add_program_labels_async): add_program_labels_async.assert_called_with('a', 'b', {'a': '2', 'b': '1'}) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.remove_program_labels_async') def test_remove_labels(remove_program_labels_async): program = cg.EngineProgram( @@ -428,6 +441,7 @@ def test_remove_labels(remove_program_labels_async): remove_program_labels_async.assert_called_with('a', 'b', ['a', 'b', 'c']) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') def test_get_circuit_v1(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) @@ -439,6 +453,7 @@ def test_get_circuit_v1(get_program_async): program.get_circuit() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') def test_get_circuit_v2(get_program_async): circuit = cirq.Circuit( @@ -451,6 +466,7 @@ def test_get_circuit_v2(get_program_async): get_program_async.assert_called_once_with('a', 'b', True) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') def test_get_circuit_batch(get_program_async): circuit = cirq.Circuit( @@ -467,6 +483,7 @@ def test_get_circuit_batch(get_program_async): get_program_async.assert_called_once_with('a', 'b', True) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') def test_get_batch_size(get_program_async): # Has to fetch from engine if not _program specified. @@ -502,6 +519,7 @@ def mock_grpc_client(): yield _fixture +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') def test_get_circuit_v2_unknown_gateset(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) @@ -515,6 +533,7 @@ def test_get_circuit_v2_unknown_gateset(get_program_async): program.get_circuit() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async') def test_get_circuit_unsupported_program_type(get_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) @@ -526,6 +545,7 @@ def test_get_circuit_unsupported_program_type(get_program_async): program.get_circuit() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.delete_program_async') def test_delete(delete_program_async): program = cg.EngineProgram('a', 'b', EngineContext()) @@ -536,6 +556,7 @@ def test_delete(delete_program_async): delete_program_async.assert_called_with('a', 'b', delete_jobs=True) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job_async') def test_delete_jobs(delete_job_async): program = cg.EngineProgram('a', 'b', EngineContext()) diff --git a/cirq-google/cirq_google/engine/engine_sampler_test.py b/cirq-google/cirq_google/engine/engine_sampler_test.py index 94e4e27c784..14199bcf5d2 100644 --- a/cirq-google/cirq_google/engine/engine_sampler_test.py +++ b/cirq-google/cirq_google/engine/engine_sampler_test.py @@ -19,8 +19,10 @@ import cirq import cirq_google as cg import cirq_google.cloud.quantum +from cirq_google.engine.util_test import uses_async_mock +@uses_async_mock @pytest.mark.parametrize('circuit', [cirq.Circuit(), cirq.FrozenCircuit()]) def test_run_circuit(circuit): engine = mock.Mock() @@ -33,6 +35,7 @@ def test_run_circuit(circuit): ) +@uses_async_mock def test_run_engine_program(): engine = mock.Mock() engine.run_sweep_async = mock.AsyncMock() @@ -45,6 +48,7 @@ def test_run_engine_program(): engine.run_sweep_async.assert_not_called() +@uses_async_mock def test_run_batch(): engine = mock.Mock() engine.run_batch_async = mock.AsyncMock() @@ -62,6 +66,7 @@ def test_run_batch(): ) +@uses_async_mock def test_run_batch_identical_repetitions(): engine = mock.Mock() engine.run_batch_async = mock.AsyncMock() @@ -93,6 +98,7 @@ def test_run_batch_bad_number_of_repetitions(): sampler.run_batch(circuits, params_list, [5, 5, 5]) +@uses_async_mock def test_run_batch_differing_repetitions(): engine = mock.Mock() job = mock.Mock() diff --git a/cirq-google/cirq_google/engine/engine_test.py b/cirq-google/cirq_google/engine/engine_test.py index 8dc3d6bf021..8bf1c97406d 100644 --- a/cirq-google/cirq_google/engine/engine_test.py +++ b/cirq-google/cirq_google/engine/engine_test.py @@ -29,6 +29,7 @@ from cirq_google.engine import util from cirq_google.cloud import quantum from cirq_google.engine.engine import EngineContext +from cirq_google.engine.util_test import uses_async_mock _CIRCUIT = cirq.Circuit( @@ -351,6 +352,7 @@ def setup_run_circuit_with_result_(client, result): client().get_job_results_async.return_value = quantum.QuantumResult(result=result) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit(client): setup_run_circuit_with_result_(client, _A_RESULT) @@ -393,6 +395,7 @@ def test_unsupported_program_type(): engine.run(program="this isn't even the right type of thing!") +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_failed(client): client().create_program_async.return_value = ( @@ -423,6 +426,7 @@ def test_run_circuit_failed(client): engine.run(program=_CIRCUIT) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_failed_missing_processor_name(client): client().create_program_async.return_value = ( @@ -452,6 +456,7 @@ def test_run_circuit_failed_missing_processor_name(client): engine.run(program=_CIRCUIT) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_cancelled(client): client().create_program_async.return_value = ( @@ -475,6 +480,7 @@ def test_run_circuit_cancelled(client): engine.run(program=_CIRCUIT) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_timeout(client): client().create_program_async.return_value = ( @@ -496,6 +502,7 @@ def test_run_circuit_timeout(client): engine.run(program=_CIRCUIT) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_params(client): setup_run_circuit_with_result_(client, _RESULTS) @@ -525,6 +532,7 @@ def test_run_sweep_params(client): client().get_job_results_async.assert_called_once() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_multiple_times(client): setup_run_circuit_with_result_(client, _RESULTS) @@ -556,6 +564,7 @@ def test_run_multiple_times(client): assert client().get_job_results_async.call_count == 2 +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_v2(client): setup_run_circuit_with_result_(client, _RESULTS_V2) @@ -580,6 +589,7 @@ def test_run_sweep_v2(client): client().get_job_results_async.assert_called_once() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch(client): setup_run_circuit_with_result_(client, _BATCH_RESULTS_V2) @@ -614,6 +624,7 @@ def test_run_batch(client): client().get_job_results_async.assert_called_once() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch_no_params(client): # OK to run with no params, it should use empty sweeps for each @@ -663,6 +674,7 @@ def test_bad_sweep_proto(): program.run_sweep() +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_calibration(client): setup_run_circuit_with_result_(client, _CALIBRATION_RESULTS_V2) @@ -716,6 +728,7 @@ def test_run_calibration_validation_fails(): ) +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_bad_result_proto(client): result = any_pb2.Any() @@ -743,6 +756,7 @@ def test_get_program(): assert cg.Engine(project_id='proj').get_program('prog').program_id == 'prog' +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.list_programs_async') def test_list_programs(list_programs_async): prog1 = quantum.QuantumProgram(name='projects/proj/programs/prog-YBGR48THF3JHERZW200804') @@ -759,6 +773,7 @@ def test_list_programs(list_programs_async): ] +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_create_program(client): client().create_program_async.return_value = ('prog', quantum.QuantumProgram()) @@ -767,6 +782,7 @@ def test_create_program(client): assert result.program_id == 'prog' +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.list_jobs_async') def test_list_jobs(list_jobs_async): job1 = quantum.QuantumJob(name='projects/proj/programs/prog1/jobs/job1') @@ -789,6 +805,7 @@ def test_list_jobs(list_jobs_async): ] +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient.list_processors_async') def test_list_processors(list_processors_async): processor1 = quantum.QuantumProcessor(name='projects/proj/processors/xmonsim') @@ -804,6 +821,7 @@ def test_get_processor(): assert cg.Engine(project_id='proj').get_processor('xmonsim').processor_id == 'xmonsim' +@uses_async_mock @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_sampler(client): setup_run_circuit_with_result_(client, _RESULTS) diff --git a/cirq-google/cirq_google/engine/util_test.py b/cirq-google/cirq_google/engine/util_test.py new file mode 100644 index 00000000000..fcf7516b82f --- /dev/null +++ b/cirq-google/cirq_google/engine/util_test.py @@ -0,0 +1,22 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pytest + +# Annotation for tests that use unittest.mock.AsyncMock, added in python 3.8. +# Tests using AsyncMock are expected to fail in 3.6 and 3.7. +# See: https://docs.python.org/3/library/unittest.mock.html#unittest.mock.AsyncMock +uses_async_mock = pytest.mark.xfail(sys.version_info < (3, 8, 0), reason='') From 62babd3fee82806cd3cdec00c76943f06b5d1b42 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Fri, 6 May 2022 14:08:24 -0700 Subject: [PATCH 3/9] Instantiate AsyncioExecutor lazily --- cirq-google/cirq_google/engine/engine_client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 7c002faa99f..79c4560e5ec 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -105,7 +105,10 @@ def __init__( service_args = {} self._service_args = service_args - self._executor = AsyncioExecutor() + + @cached_property + def _executor(self) -> AsyncioExecutor: + return AsyncioExecutor() @cached_property def grpc_client(self) -> quantum.QuantumEngineServiceAsyncClient: From a6d7fff2589e62a4c94bf6106e7f574480fab903 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 14 Jun 2022 13:58:29 -0700 Subject: [PATCH 4/9] Revert changes to engine_sampler.py and tests --- .../cirq_google/engine/engine_sampler.py | 18 +++++------ .../cirq_google/engine/engine_sampler_test.py | 31 ++++++------------- 2 files changed, 16 insertions(+), 33 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_sampler.py b/cirq-google/cirq_google/engine/engine_sampler.py index bb6971399fc..1b4e927b957 100644 --- a/cirq-google/cirq_google/engine/engine_sampler.py +++ b/cirq-google/cirq_google/engine/engine_sampler.py @@ -51,28 +51,26 @@ def __init__( self._processor_ids = [processor_id] if isinstance(processor_id, str) else processor_id self._engine = engine - async def run_sweep_async( + def run_sweep( self, program: Union[cirq.AbstractCircuit, 'cirq_google.EngineProgram'], params: cirq.Sweepable, repetitions: int = 1, ) -> Sequence[cirq.Result]: if isinstance(program, engine.EngineProgram): - job = await program.run_sweep_async( + job = program.run_sweep( params=params, repetitions=repetitions, processor_ids=self._processor_ids ) else: - job = await self._engine.run_sweep_async( + job = self._engine.run_sweep( program=program, params=params, repetitions=repetitions, processor_ids=self._processor_ids, ) - return await job.results_async() + return job.results() - run_sweep = duet.sync(run_sweep_async) - - async def run_batch_async( + def run_batch( self, programs: Sequence[cirq.AbstractCircuit], params_list: Optional[List[cirq.Sweepable]] = None, @@ -96,18 +94,16 @@ async def run_batch_async( # All repetitions are the same so batching can be done efficiently if isinstance(repetitions, List): repetitions = repetitions[0] - job = await self._engine.run_batch_async( + job = self._engine.run_batch( programs=programs, params_list=params_list, repetitions=repetitions, processor_ids=self._processor_ids, ) - return await job.batched_results_async() + return job.batched_results() # Varying number of repetitions so no speedup return super().run_batch(programs, params_list, repetitions) - run_batch = duet.sync(run_batch_async) - @property def engine(self) -> 'cirq_google.Engine': return self._engine diff --git a/cirq-google/cirq_google/engine/engine_sampler_test.py b/cirq-google/cirq_google/engine/engine_sampler_test.py index f6d1c69b914..4f5067052ac 100644 --- a/cirq-google/cirq_google/engine/engine_sampler_test.py +++ b/cirq-google/cirq_google/engine/engine_sampler_test.py @@ -19,45 +19,37 @@ import cirq import cirq_google as cg import cirq_google.cloud.quantum -from cirq_google.engine.util_test import uses_async_mock -@uses_async_mock @pytest.mark.parametrize('circuit', [cirq.Circuit(), cirq.FrozenCircuit()]) def test_run_circuit(circuit): engine = mock.Mock() - engine.run_sweep_async = mock.AsyncMock() with cirq.testing.assert_deprecated( 'Use cirq_google.ProcessorSampler instead.', deadline='v0.16' ): sampler = cg.QuantumEngineSampler(engine=engine, processor_id='tmp') params = [cirq.ParamResolver({'a': 1})] sampler.run_sweep(circuit, params, 5) - engine.run_sweep_async.assert_called_with( + engine.run_sweep.assert_called_with( params=params, processor_ids=['tmp'], program=circuit, repetitions=5 ) -@uses_async_mock def test_run_engine_program(): engine = mock.Mock() - engine.run_sweep_async = mock.AsyncMock() with cirq.testing.assert_deprecated( 'Use cirq_google.ProcessorSampler instead.', deadline='v0.16' ): sampler = cg.QuantumEngineSampler(engine=engine, processor_id='tmp') program = mock.Mock(spec=cg.EngineProgram) - program.run_sweep_async = mock.AsyncMock() params = [cirq.ParamResolver({'a': 1})] sampler.run_sweep(program, params, 5) - program.run_sweep_async.assert_called_with(params=params, processor_ids=['tmp'], repetitions=5) - engine.run_sweep_async.assert_not_called() + program.run_sweep.assert_called_with(params=params, processor_ids=['tmp'], repetitions=5) + engine.run_sweep.assert_not_called() -@uses_async_mock def test_run_batch(): engine = mock.Mock() - engine.run_batch_async = mock.AsyncMock() with cirq.testing.assert_deprecated( 'Use cirq_google.ProcessorSampler instead.', deadline='v0.16' ): @@ -70,15 +62,13 @@ def test_run_batch(): circuits = [circuit1, circuit2] params_list = [params1, params2] sampler.run_batch(circuits, params_list, 5) - engine.run_batch_async.assert_called_with( + engine.run_batch.assert_called_with( params_list=params_list, processor_ids=['tmp'], programs=circuits, repetitions=5 ) -@uses_async_mock def test_run_batch_identical_repetitions(): engine = mock.Mock() - engine.run_batch_async = mock.AsyncMock() with cirq.testing.assert_deprecated( 'Use cirq_google.ProcessorSampler instead.', deadline='v0.16' ): @@ -91,7 +81,7 @@ def test_run_batch_identical_repetitions(): circuits = [circuit1, circuit2] params_list = [params1, params2] sampler.run_batch(circuits, params_list, [5, 5]) - engine.run_batch_async.assert_called_with( + engine.run_batch.assert_called_with( params_list=params_list, processor_ids=['tmp'], programs=circuits, repetitions=5 ) @@ -113,14 +103,11 @@ def test_run_batch_bad_number_of_repetitions(): sampler.run_batch(circuits, params_list, [5, 5, 5]) -@uses_async_mock def test_run_batch_differing_repetitions(): engine = mock.Mock() job = mock.Mock() - job.results_async = mock.AsyncMock() - job.results_async.return_value = [] - engine.run_sweep_async = mock.AsyncMock() - engine.run_sweep_async.return_value = job + job.results.return_value = [] + engine.run_sweep.return_value = job with cirq.testing.assert_deprecated( 'Use cirq_google.ProcessorSampler instead.', deadline='v0.16' ): @@ -134,10 +121,10 @@ def test_run_batch_differing_repetitions(): params_list = [params1, params2] repetitions = [1, 2] sampler.run_batch(circuits, params_list, repetitions) - engine.run_sweep_async.assert_called_with( + engine.run_sweep.assert_called_with( params=params2, processor_ids=['tmp'], program=circuit2, repetitions=2 ) - engine.run_batch_async.assert_not_called() + engine.run_batch.assert_not_called() def test_engine_sampler_engine_property(): From 757437cae023b377d6a4577dfe779683a146d6b3 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 14 Jun 2022 14:05:44 -0700 Subject: [PATCH 5/9] Remove unused import --- cirq-google/cirq_google/engine/engine_sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_sampler.py b/cirq-google/cirq_google/engine/engine_sampler.py index 1b4e927b957..d9d2011b20b 100644 --- a/cirq-google/cirq_google/engine/engine_sampler.py +++ b/cirq-google/cirq_google/engine/engine_sampler.py @@ -14,8 +14,6 @@ from typing import List, Optional, Sequence, TYPE_CHECKING, Union -import duet - import cirq from cirq_google import engine from cirq_google.engine import util From 09e0eab3770249250df7c22ec2a6aaa493e3a37a Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 14 Jun 2022 19:00:27 -0700 Subject: [PATCH 6/9] Add docstring for AsyncioExecutor --- cirq-google/cirq_google/engine/engine_client.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 79c4560e5ec..bd9ee805b47 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -54,6 +54,12 @@ def __init__(self, message): class AsyncioExecutor: + """Runs asyncio coroutines in a thread, exposes the results as duet futures. + + This lets us bridge between an asyncio event loop (which is what async grpc + code uses) and duet (which is what cirq uses for asynchrony). + """ + def __init__(self) -> None: loop_future: duet.AwaitableFuture[asyncio.AbstractEventLoop] = duet.AwaitableFuture() thread = threading.Thread(target=asyncio.run, args=(self._main(loop_future),), daemon=True) @@ -68,7 +74,14 @@ async def _main(loop_future: duet.AwaitableFuture) -> None: await asyncio.sleep(1) def submit(self, func, *args, **kw) -> duet.AwaitableFuture: - """Dispatch the given function to be run in a duet coroutine.""" + """Dispatch the given function to be run in an asyncio coroutine. + + Args: + func: asyncio function which will be run in a separate thread. + Will be called with *args and **kw and should return an asyncio. + *args: Positional args to pass to func. + **kw: Keyword args to pass to func. + """ future = asyncio.run_coroutine_threadsafe(func(*args, **kw), self.loop) return duet.AwaitableFuture.wrap(future) From c0bbc68b0ea4d9ea4d868d3edd3b764e0eed286f Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 14 Jun 2022 19:28:35 -0700 Subject: [PATCH 7/9] Revert engine.py changes --- cirq-google/cirq_google/engine/engine.py | 46 +++++++++++------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine.py b/cirq-google/cirq_google/engine/engine.py index ac98a957d8e..f779c526b32 100644 --- a/cirq-google/cirq_google/engine/engine.py +++ b/cirq-google/cirq_google/engine/engine.py @@ -209,7 +209,7 @@ def __str__(self) -> str: return f'Engine(project_id={self.project_id!r})' @util.deprecated_gate_set_parameter - async def run_async( + def run( self, program: cirq.AbstractCircuit, program_id: Optional[str] = None, @@ -255,25 +255,23 @@ async def run_async( Raises: ValueError: If no gate set is provided. """ - job = await self.run_sweep_async( - program=program, - program_id=program_id, - job_id=job_id, - params=[param_resolver], - repetitions=repetitions, - processor_ids=processor_ids, - program_description=program_description, - program_labels=program_labels, - job_description=job_description, - job_labels=job_labels, - ) - results = await job.results_async() - return results[0] - - run = duet.sync(run_async) + return list( + self.run_sweep( + program=program, + program_id=program_id, + job_id=job_id, + params=[param_resolver], + repetitions=repetitions, + processor_ids=processor_ids, + program_description=program_description, + program_labels=program_labels, + job_description=job_description, + job_labels=job_labels, + ) + )[0] @util.deprecated_gate_set_parameter - async def run_sweep_async( + def run_sweep( self, program: cirq.AbstractCircuit, program_id: Optional[str] = None, @@ -323,10 +321,10 @@ async def run_sweep_async( Raises: ValueError: If no gate set is provided. """ - engine_program = await self.create_program_async( + engine_program = self.create_program( program, program_id, description=program_description, labels=program_labels ) - return await engine_program.run_sweep_async( + return engine_program.run_sweep( job_id=job_id, params=params, repetitions=repetitions, @@ -335,10 +333,8 @@ async def run_sweep_async( labels=job_labels, ) - run_sweep = duet.sync(run_sweep_async) - @util.deprecated_gate_set_parameter - async def run_batch_async( + def run_batch( self, programs: Sequence[cirq.AbstractCircuit], program_id: Optional[str] = None, @@ -410,7 +406,7 @@ async def run_batch_async( engine_program = self.create_batch_program( programs, program_id, description=program_description, labels=program_labels ) - return await engine_program.run_batch_async( + return engine_program.run_batch( job_id=job_id, params_list=params_list, repetitions=repetitions, @@ -419,8 +415,6 @@ async def run_batch_async( labels=job_labels, ) - run_batch = duet.sync(run_batch_async) - @util.deprecated_gate_set_parameter def run_calibration( self, From 1ab5b83e17fc2e54e7963ba0fc495a7ea81b8e90 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Wed, 15 Jun 2022 08:06:36 -0700 Subject: [PATCH 8/9] Rename util_test.py -> test_utils.py --- cirq-google/cirq_google/engine/engine_client_test.py | 2 +- cirq-google/cirq_google/engine/engine_job_test.py | 2 +- cirq-google/cirq_google/engine/engine_processor_test.py | 2 +- cirq-google/cirq_google/engine/engine_program_test.py | 2 +- cirq-google/cirq_google/engine/engine_test.py | 2 +- cirq-google/cirq_google/engine/{util_test.py => test_utils.py} | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) rename cirq-google/cirq_google/engine/{util_test.py => test_utils.py} (91%) diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index 03432e1031b..357e7dbe82b 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -24,7 +24,7 @@ import duet from cirq_google.engine.engine_client import EngineClient, EngineException -from cirq_google.engine.util_test import uses_async_mock +from cirq_google.engine.test_utils import uses_async_mock from cirq_google.cloud import quantum diff --git a/cirq-google/cirq_google/engine/engine_job_test.py b/cirq-google/cirq_google/engine/engine_job_test.py index e0e1d768d3c..8a6b45b3fda 100644 --- a/cirq-google/cirq_google/engine/engine_job_test.py +++ b/cirq-google/cirq_google/engine/engine_job_test.py @@ -25,7 +25,7 @@ from cirq_google.engine import util from cirq_google.cloud import quantum from cirq_google.engine.engine import EngineContext -from cirq_google.engine.util_test import uses_async_mock +from cirq_google.engine.test_utils import uses_async_mock @pytest.fixture(scope='session', autouse=True) diff --git a/cirq-google/cirq_google/engine/engine_processor_test.py b/cirq-google/cirq_google/engine/engine_processor_test.py index 2d73a8d9a7a..bd832db0bf3 100644 --- a/cirq-google/cirq_google/engine/engine_processor_test.py +++ b/cirq-google/cirq_google/engine/engine_processor_test.py @@ -28,7 +28,7 @@ from cirq_google.api import v2 from cirq_google.engine import util from cirq_google.engine.engine import EngineContext -from cirq_google.engine.util_test import uses_async_mock +from cirq_google.engine.test_utils import uses_async_mock from cirq_google.cloud import quantum diff --git a/cirq-google/cirq_google/engine/engine_program_test.py b/cirq-google/cirq_google/engine/engine_program_test.py index 1887fe8a0c3..8bc54ae8e6c 100644 --- a/cirq-google/cirq_google/engine/engine_program_test.py +++ b/cirq-google/cirq_google/engine/engine_program_test.py @@ -27,7 +27,7 @@ from cirq_google.cloud import quantum from cirq_google.engine.engine import EngineContext from cirq_google.engine.result_type import ResultType -from cirq_google.engine.util_test import uses_async_mock +from cirq_google.engine.test_utils import uses_async_mock _BATCH_PROGRAM_V2 = util.pack_any( diff --git a/cirq-google/cirq_google/engine/engine_test.py b/cirq-google/cirq_google/engine/engine_test.py index 21e59ee25e4..d01e47f1db5 100644 --- a/cirq-google/cirq_google/engine/engine_test.py +++ b/cirq-google/cirq_google/engine/engine_test.py @@ -29,7 +29,7 @@ from cirq_google.engine import util from cirq_google.cloud import quantum from cirq_google.engine.engine import EngineContext -from cirq_google.engine.util_test import uses_async_mock +from cirq_google.engine.test_utils import uses_async_mock _CIRCUIT = cirq.Circuit( diff --git a/cirq-google/cirq_google/engine/util_test.py b/cirq-google/cirq_google/engine/test_utils.py similarity index 91% rename from cirq-google/cirq_google/engine/util_test.py rename to cirq-google/cirq_google/engine/test_utils.py index fcf7516b82f..9bdc5b48596 100644 --- a/cirq-google/cirq_google/engine/util_test.py +++ b/cirq-google/cirq_google/engine/test_utils.py @@ -17,6 +17,6 @@ import pytest # Annotation for tests that use unittest.mock.AsyncMock, added in python 3.8. -# Tests using AsyncMock are expected to fail in 3.6 and 3.7. +# Tests using AsyncMock are expected to fail in earlier versions of python. # See: https://docs.python.org/3/library/unittest.mock.html#unittest.mock.AsyncMock uses_async_mock = pytest.mark.xfail(sys.version_info < (3, 8, 0), reason='') From 4bee996bdb71e019839201d82552e1a4928ab122 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Wed, 15 Jun 2022 22:43:38 -0700 Subject: [PATCH 9/9] Revert change to engine_sampler_test --- cirq-google/cirq_google/engine/engine_sampler_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-google/cirq_google/engine/engine_sampler_test.py b/cirq-google/cirq_google/engine/engine_sampler_test.py index 4f5067052ac..427d3d73aac 100644 --- a/cirq-google/cirq_google/engine/engine_sampler_test.py +++ b/cirq-google/cirq_google/engine/engine_sampler_test.py @@ -144,6 +144,6 @@ def test_get_engine_sampler_explicit_project_id(): def test_get_engine_sampler(): with mock.patch.object(cirq_google.cloud.quantum, 'QuantumEngineServiceClient', autospec=True): - with mock.patch('google.auth.default', lambda *args, **kwargs: (None, 'myproj')): + with mock.patch('google.auth.default', lambda: (None, 'myproj')): sampler = cg.get_engine_sampler(processor_id='hi mom') assert hasattr(sampler, 'run_sweep')