Skip to content

Commit e14f9ad

Browse files
authored
Fix type check of SerializableDevice gate_definitions (#5447)
The type(cirq.Gate) is cirq.ABCMetaImplementAnyOneOf which could match unrelated instances. Use subclass check instead and remove `isinstance()` check which should never happen. Clarify typing of the `gate_definitions` dictionary keys in SeriazableDevice. Finally, turn off Gateset unroll_circuit_op flag when gate_definitions indicate device does not support subcircuits.
1 parent 25388af commit e14f9ad

File tree

3 files changed

+28
-12
lines changed

3 files changed

+28
-12
lines changed

cirq-google/cirq_google/devices/serializable_device.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Device object for converting from device specification protos"""
1515

16-
from typing import Any, Callable, cast, Dict, Iterable, Optional, List, Set, Tuple, Type
16+
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
1717
import cirq
1818
from cirq_google.serialization import serializable_gate_set
1919
from cirq_google.api import v2
@@ -63,6 +63,9 @@ def __eq__(self, other):
6363
return self.__dict__ == other.__dict__
6464

6565

66+
_GateOrFrozenCircuitTypes = Union[Type[cirq.Gate], Type[cirq.FrozenCircuit]]
67+
68+
6669
class SerializableDevice(cirq.Device):
6770
"""Device object generated from a device specification proto.
6871
@@ -79,7 +82,9 @@ class SerializableDevice(cirq.Device):
7982
"""
8083

8184
def __init__(
82-
self, qubits: List[cirq.Qid], gate_definitions: Dict[Type[cirq.Gate], List[_GateDefinition]]
85+
self,
86+
qubits: List[cirq.Qid],
87+
gate_definitions: Dict[_GateOrFrozenCircuitTypes, List[_GateDefinition]],
8388
):
8489
"""Constructor for SerializableDevice using python objects.
8590
@@ -93,6 +98,7 @@ def __init__(
9398
"""
9499
self.qubits = qubits
95100
self.gate_definitions = gate_definitions
101+
has_subcircuit_support: bool = cirq.FrozenCircuit in gate_definitions
96102
self._metadata = cirq.GridDeviceMetadata(
97103
qubit_pairs=[
98104
(pair[0], pair[1])
@@ -103,12 +109,9 @@ def __init__(
103109
if len(pair) == 2 and pair[0] < pair[1]
104110
],
105111
gateset=cirq.Gateset(
106-
*[
107-
g
108-
for g in gate_definitions.keys()
109-
if isinstance(g, (cirq.Gate, type(cirq.Gate)))
110-
],
112+
*(g for g in gate_definitions.keys() if issubclass(g, cirq.Gate)),
111113
cirq.GlobalPhaseGate,
114+
unroll_circuit_op=has_subcircuit_support,
112115
),
113116
gate_durations=None,
114117
)
@@ -174,7 +177,7 @@ def from_proto(
174177
)
175178

176179
# Loop through serializers and map gate_definitions to type
177-
gates_by_type: Dict[Type[cirq.Gate], List[_GateDefinition]] = {}
180+
gates_by_type: Dict[_GateOrFrozenCircuitTypes, List[_GateDefinition]] = {}
178181
for gate_set in gate_sets:
179182
for internal_type in gate_set.supported_internal_types():
180183
for serializer in gate_set.serializers[internal_type]:

cirq-google/cirq_google/devices/serializable_device_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,16 @@ def test_serializable_device_str_named_qubits():
398398
assert device.__class__.__name__ in str(device)
399399

400400

401+
def test_serializable_device_gate_definitions_filter():
402+
"""Ignore items in gate_definitions dictionary with invalid keys."""
403+
device = cg.SerializableDevice(
404+
qubits=[cirq.NamedQubit('a'), cirq.NamedQubit('b')],
405+
gate_definitions={cirq.FSimGate: [], cirq.NoiseModel: []},
406+
)
407+
# Two gates for cirq.FSimGate and the cirq.GlobalPhaseGate default
408+
assert len(device.metadata.gateset.gates) == 2
409+
410+
401411
def test_sycamore23_str():
402412
assert (
403413
str(cg.Sycamore23)

cirq-google/cirq_google/serialization/serializable_gate_set.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from cirq_google.serialization import serializer, op_deserializer, op_serializer, arg_func_langs
2222

2323

24+
_GateOrFrozenCircuitTypes = Union[Type[cirq.Gate], Type[cirq.FrozenCircuit]]
25+
26+
2427
class SerializableGateSet(serializer.Serializer):
2528
"""A class for serializing and deserializing programs and operations.
2629
@@ -47,7 +50,7 @@ def __init__(
4750
forms of gates or circuits into Operations.
4851
"""
4952
super().__init__(gate_set_name)
50-
self.serializers: Dict[Type, List[op_serializer.OpSerializer]] = {}
53+
self.serializers: Dict[_GateOrFrozenCircuitTypes, List[op_serializer.OpSerializer]] = {}
5154
for s in serializers:
5255
self.serializers.setdefault(s.internal_type, []).append(s)
5356
self.deserializers = {d.serialized_id: d for d in deserializers}
@@ -77,7 +80,7 @@ def with_added_types(
7780
deserializers=[*self.deserializers.values(), *deserializers],
7881
)
7982

80-
def supported_internal_types(self) -> Tuple:
83+
def supported_internal_types(self) -> Tuple[_GateOrFrozenCircuitTypes, ...]:
8184
return tuple(self.serializers.keys())
8285

8386
def is_supported(self, op_tree: cirq.OP_TREE) -> bool:
@@ -194,8 +197,8 @@ def serialize_gate_op(
194197
if gate_type_mro in self.serializers:
195198
# Check each serializer in turn, if serializer proto returns
196199
# None, then skip.
197-
for serializer in self.serializers[gate_type_mro]:
198-
proto_msg = serializer.to_proto(
200+
for mro_serializer in self.serializers[gate_type_mro]:
201+
proto_msg = mro_serializer.to_proto(
199202
op,
200203
msg,
201204
arg_function_language=arg_function_language,

0 commit comments

Comments
 (0)