diff --git a/cirq-google/cirq_google/devices/serializable_device.py b/cirq-google/cirq_google/devices/serializable_device.py index 84758afc3b6..c5214cb3bb7 100644 --- a/cirq-google/cirq_google/devices/serializable_device.py +++ b/cirq-google/cirq_google/devices/serializable_device.py @@ -13,7 +13,7 @@ # limitations under the License. """Device object for converting from device specification protos""" -from typing import Any, Callable, cast, Dict, Iterable, Optional, List, Set, Tuple, Type +from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import cirq from cirq_google.serialization import serializable_gate_set from cirq_google.api import v2 @@ -63,6 +63,9 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ +_GateOrFrozenCircuitTypes = Union[Type[cirq.Gate], Type[cirq.FrozenCircuit]] + + class SerializableDevice(cirq.Device): """Device object generated from a device specification proto. @@ -79,7 +82,9 @@ class SerializableDevice(cirq.Device): """ def __init__( - self, qubits: List[cirq.Qid], gate_definitions: Dict[Type[cirq.Gate], List[_GateDefinition]] + self, + qubits: List[cirq.Qid], + gate_definitions: Dict[_GateOrFrozenCircuitTypes, List[_GateDefinition]], ): """Constructor for SerializableDevice using python objects. @@ -93,6 +98,7 @@ def __init__( """ self.qubits = qubits self.gate_definitions = gate_definitions + has_subcircuit_support: bool = cirq.FrozenCircuit in gate_definitions self._metadata = cirq.GridDeviceMetadata( qubit_pairs=[ (pair[0], pair[1]) @@ -103,12 +109,9 @@ def __init__( if len(pair) == 2 and pair[0] < pair[1] ], gateset=cirq.Gateset( - *[ - g - for g in gate_definitions.keys() - if isinstance(g, (cirq.Gate, type(cirq.Gate))) - ], + *(g for g in gate_definitions.keys() if issubclass(g, cirq.Gate)), cirq.GlobalPhaseGate, + unroll_circuit_op=has_subcircuit_support, ), gate_durations=None, ) @@ -174,7 +177,7 @@ def from_proto( ) # Loop through serializers and map gate_definitions to type - gates_by_type: Dict[Type[cirq.Gate], List[_GateDefinition]] = {} + gates_by_type: Dict[_GateOrFrozenCircuitTypes, List[_GateDefinition]] = {} for gate_set in gate_sets: for internal_type in gate_set.supported_internal_types(): for serializer in gate_set.serializers[internal_type]: diff --git a/cirq-google/cirq_google/devices/serializable_device_test.py b/cirq-google/cirq_google/devices/serializable_device_test.py index 4fa11b496aa..1ed37823877 100644 --- a/cirq-google/cirq_google/devices/serializable_device_test.py +++ b/cirq-google/cirq_google/devices/serializable_device_test.py @@ -398,6 +398,16 @@ def test_serializable_device_str_named_qubits(): assert device.__class__.__name__ in str(device) +def test_serializable_device_gate_definitions_filter(): + """Ignore items in gate_definitions dictionary with invalid keys.""" + device = cg.SerializableDevice( + qubits=[cirq.NamedQubit('a'), cirq.NamedQubit('b')], + gate_definitions={cirq.FSimGate: [], cirq.NoiseModel: []}, + ) + # Two gates for cirq.FSimGate and the cirq.GlobalPhaseGate default + assert len(device.metadata.gateset.gates) == 2 + + def test_sycamore23_str(): assert ( str(cg.Sycamore23) diff --git a/cirq-google/cirq_google/serialization/serializable_gate_set.py b/cirq-google/cirq_google/serialization/serializable_gate_set.py index 0c4e92506c1..3a76b729ea2 100644 --- a/cirq-google/cirq_google/serialization/serializable_gate_set.py +++ b/cirq-google/cirq_google/serialization/serializable_gate_set.py @@ -21,6 +21,9 @@ from cirq_google.serialization import serializer, op_deserializer, op_serializer, arg_func_langs +_GateOrFrozenCircuitTypes = Union[Type[cirq.Gate], Type[cirq.FrozenCircuit]] + + class SerializableGateSet(serializer.Serializer): """A class for serializing and deserializing programs and operations. @@ -47,7 +50,7 @@ def __init__( forms of gates or circuits into Operations. """ super().__init__(gate_set_name) - self.serializers: Dict[Type, List[op_serializer.OpSerializer]] = {} + self.serializers: Dict[_GateOrFrozenCircuitTypes, List[op_serializer.OpSerializer]] = {} for s in serializers: self.serializers.setdefault(s.internal_type, []).append(s) self.deserializers = {d.serialized_id: d for d in deserializers} @@ -77,7 +80,7 @@ def with_added_types( deserializers=[*self.deserializers.values(), *deserializers], ) - def supported_internal_types(self) -> Tuple: + def supported_internal_types(self) -> Tuple[_GateOrFrozenCircuitTypes, ...]: return tuple(self.serializers.keys()) def is_supported(self, op_tree: cirq.OP_TREE) -> bool: @@ -194,8 +197,8 @@ def serialize_gate_op( if gate_type_mro in self.serializers: # Check each serializer in turn, if serializer proto returns # None, then skip. - for serializer in self.serializers[gate_type_mro]: - proto_msg = serializer.to_proto( + for mro_serializer in self.serializers[gate_type_mro]: + proto_msg = mro_serializer.to_proto( op, msg, arg_function_language=arg_function_language,