Skip to content

Commit 43d033b

Browse files
authored
Add ability to add custom tag (de)serializers (#7169)
* Add ability to add custom tag (de)serializers - This adds custom tag serializers for serializing tags. - This allows greater flexibility and a much easier way to specify a serializer to serialize and deserialize internal tags. * Fix types. * Fix coverage * Address comments. * Fix unassigned variable.
1 parent 3ae4ddb commit 43d033b

File tree

4 files changed

+271
-31
lines changed

4 files changed

+271
-31
lines changed

cirq-google/cirq_google/serialization/circuit_serializer.py

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,14 @@
3232
)
3333
from cirq_google.ops.calibration_tag import CalibrationTag
3434
from cirq_google.experimental.ops import CouplerPulse
35-
from cirq_google.serialization import serializer, op_deserializer, op_serializer, arg_func_langs
35+
from cirq_google.serialization import (
36+
serializer,
37+
op_deserializer,
38+
op_serializer,
39+
arg_func_langs,
40+
tag_serializer,
41+
tag_deserializer,
42+
)
3643

3744
# The name used in program.proto to identify the serializer as CircuitSerializer.
3845
# "v2.5" refers to the most current v2.Program proto format.
@@ -64,6 +71,8 @@ class CircuitSerializer(serializer.Serializer):
6471
deserialization of this field is deployed.
6572
op_serializer: Optional custom serializer for serializing unknown gates.
6673
op_deserializer: Optional custom deserializer for deserializing unknown gates.
74+
tag_serializer: Optional custom serializer for serializing unknown tags.
75+
tag_deserializer: Optional custom deserializer for deserializing unknown tags.
6776
"""
6877

6978
def __init__(
@@ -72,13 +81,17 @@ def __init__(
7281
USE_CONSTANTS_TABLE_FOR_OPERATIONS=False,
7382
op_serializer: Optional[op_serializer.OpSerializer] = None,
7483
op_deserializer: Optional[op_deserializer.OpDeserializer] = None,
84+
tag_serializer: Optional[tag_serializer.TagSerializer] = None,
85+
tag_deserializer: Optional[tag_deserializer.TagDeserializer] = None,
7586
):
7687
"""Construct the circuit serializer object."""
7788
super().__init__(gate_set_name=_SERIALIZER_NAME)
7889
self.use_constants_table_for_moments = USE_CONSTANTS_TABLE_FOR_MOMENTS
7990
self.use_constants_table_for_operations = USE_CONSTANTS_TABLE_FOR_OPERATIONS
8091
self.op_serializer = op_serializer
8192
self.op_deserializer = op_deserializer
93+
self.tag_serializer = tag_serializer
94+
self.tag_deserializer = tag_deserializer
8295

8396
def serialize(
8497
self, program: cirq.AbstractCircuit, msg: Optional[v2.program_pb2.Program] = None
@@ -301,8 +314,8 @@ def _serialize_gate_op(
301314
msg.qubit_constant_index.append(raw_constants[qubit])
302315

303316
for tag in op.tags:
317+
constant = v2.program_pb2.Constant()
304318
if isinstance(tag, CalibrationTag):
305-
constant = v2.program_pb2.Constant()
306319
constant.string_value = tag.token
307320
if tag.token in raw_constants:
308321
msg.token_constant_index = raw_constants[tag.token]
@@ -317,16 +330,22 @@ def _serialize_gate_op(
317330
# TODO(dstrain): Remove this once we are deserializing tag indices everywhere.
318331
tag.to_proto(msg=msg.tags.add())
319332
if (tag_index := raw_constants.get(tag, None)) is None:
320-
constant = v2.program_pb2.Constant()
321-
tag_index = len(constants)
322-
if getattr(tag, 'to_proto', None) is not None:
333+
if self.tag_serializer and self.tag_serializer.can_serialize_tag(tag):
334+
self.tag_serializer.to_proto(
335+
tag,
336+
msg=constant.tag_value,
337+
constants=constants,
338+
raw_constants=raw_constants,
339+
)
340+
elif getattr(tag, 'to_proto', None) is not None:
323341
tag.to_proto(constant.tag_value) # type: ignore
324-
constants.append(constant)
325-
if raw_constants is not None:
326-
raw_constants[tag] = tag_index
327-
msg.tag_indices.append(tag_index)
328342
else:
329343
warnings.warn(f'Unrecognized Tag {tag}, not serializing.')
344+
if constant.WhichOneof('const_value'):
345+
constants.append(constant)
346+
if raw_constants is not None:
347+
raw_constants[tag] = len(constants) - 1
348+
msg.tag_indices.append(len(constants) - 1)
330349
else:
331350
msg.tag_indices.append(tag_index)
332351
return msg
@@ -434,7 +453,18 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit:
434453
)
435454
)
436455
elif which_const == 'tag_value':
437-
deserialized_constants.append(self._deserialize_tag(constant.tag_value))
456+
if self.tag_deserializer and self.tag_deserializer.can_deserialize_proto(
457+
constant.tag_value
458+
):
459+
deserialized_constants.append(
460+
self.tag_deserializer.from_proto(
461+
constant.tag_value,
462+
constants=proto.constants,
463+
deserialized_constants=deserialized_constants,
464+
)
465+
)
466+
else:
467+
deserialized_constants.append(self._deserialize_tag(constant.tag_value))
438468
else:
439469
msg = f'Unrecognized constant type {which_const}, ignoring.' # pragma: no cover
440470
warnings.warn(msg) # pragma: no cover
@@ -490,22 +520,7 @@ def _deserialize_moment(
490520
gate_op = self._deserialize_gate_op(
491521
op, constants=constants, deserialized_constants=deserialized_constants
492522
)
493-
if op.tag_indices:
494-
tags = [
495-
deserialized_constants[tag_index]
496-
for tag_index in op.tag_indices
497-
if deserialized_constants[tag_index] not in gate_op.tags
498-
and deserialized_constants[tag_index] is not None
499-
]
500-
else:
501-
tags = []
502-
for tag in op.tags:
503-
if (
504-
tag not in gate_op.tags
505-
and (new_tag := self._deserialize_tag(tag)) is not None
506-
):
507-
tags.append(new_tag)
508-
moment_ops.append(gate_op.with_tags(*tags))
523+
moment_ops.append(gate_op)
509524
for op in moment_proto.circuit_operations:
510525
moment_ops.append(
511526
self._deserialize_circuit_op(
@@ -768,7 +783,30 @@ def _deserialize_gate_op(
768783
elif which == 'token_value':
769784
op = op.with_tags(CalibrationTag(operation_proto.token_value))
770785

771-
return op
786+
# Add tags to op
787+
if operation_proto.tag_indices and deserialized_constants is not None:
788+
tags = [
789+
deserialized_constants[tag_index]
790+
for tag_index in operation_proto.tag_indices
791+
if deserialized_constants[tag_index] not in op.tags
792+
and deserialized_constants[tag_index] is not None
793+
]
794+
else:
795+
tags = []
796+
for tag in operation_proto.tags:
797+
if tag not in op.tags:
798+
if self.tag_deserializer and self.tag_deserializer.can_deserialize_proto(tag):
799+
tags.append(
800+
self.tag_deserializer.from_proto(
801+
tag,
802+
constants=constants or [],
803+
deserialized_constants=deserialized_constants or [],
804+
)
805+
)
806+
elif (new_tag := self._deserialize_tag(tag)) is not None:
807+
tags.append(new_tag)
808+
809+
return op.with_tags(*tags)
772810

773811
def _deserialize_circuit_op(
774812
self,

cirq-google/cirq_google/serialization/circuit_serializer_test.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Any, Dict, List, Optional
1616
import pytest
1717

18+
import attrs
1819
import numpy as np
1920
import sympy
2021
from google.protobuf import json_format
@@ -27,6 +28,8 @@
2728
from cirq_google.serialization.circuit_serializer import _SERIALIZER_NAME
2829
from cirq_google.serialization.op_deserializer import OpDeserializer
2930
from cirq_google.serialization.op_serializer import OpSerializer
31+
from cirq_google.serialization.tag_deserializer import TagDeserializer
32+
from cirq_google.serialization.tag_serializer import TagSerializer
3033

3134

3235
class FakeDevice(cirq.Device):
@@ -916,10 +919,10 @@ def test_backwards_compatibility_with_old_tags():
916919
),
917920
constants=[v2.program_pb2.Constant(qubit=v2.program_pb2.Qubit(id='1_1'))],
918921
)
919-
expected_circuit_no_tag = cirq.Circuit(
922+
expected_circuit = cirq.Circuit(
920923
cirq.X(cirq.GridQubit(1, 1)).with_tags(cg.ops.DynamicalDecouplingTag(protocol='X'))
921924
)
922-
assert cg.CIRCUIT_SERIALIZER.deserialize(circuit_proto) == expected_circuit_no_tag
925+
assert cg.CIRCUIT_SERIALIZER.deserialize(circuit_proto) == expected_circuit
923926

924927

925928
def test_circuit_with_units():
@@ -949,7 +952,7 @@ def can_serialize_operation(self, op):
949952

950953
def to_proto(
951954
self,
952-
op: cirq.CircuitOperation,
955+
op: cirq.Operation,
953956
msg: Optional[v2.program_pb2.CircuitOperation] = None,
954957
*,
955958
constants: List[v2.program_pb2.Constant],
@@ -1008,7 +1011,7 @@ def test_serdes_preserves_syc():
10081011

10091012

10101013
@pytest.mark.parametrize('use_constants_table', [True, False])
1011-
def test_custom_serializer(use_constants_table: bool):
1014+
def test_custom_op_serializer(use_constants_table: bool):
10121015
c = cirq.Circuit(BingBongGate(param=2.5)(cirq.q(0, 0)))
10131016
serializer = cg.CircuitSerializer(
10141017
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
@@ -1026,6 +1029,97 @@ def test_custom_serializer(use_constants_table: bool):
10261029
assert op.qubits == (cirq.q(0, 0),)
10271030

10281031

1032+
@attrs.frozen
1033+
class DiscountTag:
1034+
discount: float
1035+
1036+
1037+
class DiscountTagSerializer(TagSerializer):
1038+
"""Describes how to serialize DiscountTag."""
1039+
1040+
def can_serialize_tag(self, tag):
1041+
return isinstance(tag, DiscountTag)
1042+
1043+
def to_proto(
1044+
self,
1045+
tag: Any,
1046+
msg: Optional[v2.program_pb2.Tag] = None,
1047+
*,
1048+
constants: List[v2.program_pb2.Constant],
1049+
raw_constants: Dict[Any, int],
1050+
) -> v2.program_pb2.Tag:
1051+
assert isinstance(tag, DiscountTag)
1052+
if msg is None:
1053+
msg = v2.program_pb2.Tag() # pragma: nocover
1054+
msg.internal_tag.tag_name = 'Discount'
1055+
msg.internal_tag.tag_package = 'test'
1056+
msg.internal_tag.tag_args['discount'].arg_value.float_value = tag.discount
1057+
return msg
1058+
1059+
1060+
class DiscountTagDeserializer(TagDeserializer):
1061+
"""Describes how to serialize CircuitOperations."""
1062+
1063+
def can_deserialize_proto(self, proto):
1064+
return (
1065+
proto.WhichOneof("tag") == "internal_tag"
1066+
and proto.internal_tag.tag_name == 'Discount'
1067+
and proto.internal_tag.tag_package == 'test'
1068+
)
1069+
1070+
def from_proto(
1071+
self,
1072+
proto: v2.program_pb2.Operation,
1073+
*,
1074+
constants: List[v2.program_pb2.Constant],
1075+
deserialized_constants: List[Any],
1076+
) -> DiscountTag:
1077+
return DiscountTag(discount=proto.internal_tag.tag_args["discount"].arg_value.float_value)
1078+
1079+
1080+
@pytest.mark.parametrize('use_constants_table', [True, False])
1081+
def test_custom_tag_serializer(use_constants_table: bool):
1082+
c = cirq.Circuit(cirq.X(cirq.q(0, 0)).with_tags(DiscountTag(0.25)))
1083+
serializer = cg.CircuitSerializer(
1084+
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
1085+
USE_CONSTANTS_TABLE_FOR_OPERATIONS=use_constants_table,
1086+
tag_serializer=DiscountTagSerializer(),
1087+
tag_deserializer=DiscountTagDeserializer(),
1088+
)
1089+
msg = serializer.serialize(c)
1090+
deserialized_circuit = serializer.deserialize(msg)
1091+
moment = deserialized_circuit[0]
1092+
assert len(moment) == 1
1093+
op = moment[cirq.q(0, 0)]
1094+
assert len(op.tags) == 1
1095+
assert isinstance(op.tags[0], DiscountTag)
1096+
assert op.tags[0].discount == 0.25
1097+
1098+
1099+
def test_custom_tag_serializer_with_tags_outside_constants():
1100+
op_tag = v2.program_pb2.Operation()
1101+
op_tag.xpowgate.exponent.float_value = 1.0
1102+
op_tag.qubit_constant_index.append(0)
1103+
tag = v2.program_pb2.Tag()
1104+
tag.internal_tag.tag_name = 'Discount'
1105+
tag.internal_tag.tag_package = 'test'
1106+
tag.internal_tag.tag_args['discount'].arg_value.float_value = 0.5
1107+
op_tag.tags.append(tag)
1108+
circuit_proto = v2.program_pb2.Program(
1109+
language=v2.program_pb2.Language(arg_function_language='exp', gate_set=_SERIALIZER_NAME),
1110+
circuit=v2.program_pb2.Circuit(
1111+
scheduling_strategy=v2.program_pb2.Circuit.MOMENT_BY_MOMENT,
1112+
moments=[v2.program_pb2.Moment(operations=[op_tag])],
1113+
),
1114+
constants=[v2.program_pb2.Constant(qubit=v2.program_pb2.Qubit(id='1_1'))],
1115+
)
1116+
expected_circuit = cirq.Circuit(cirq.X(cirq.GridQubit(1, 1)).with_tags(DiscountTag(0.50)))
1117+
serializer = cg.CircuitSerializer(
1118+
tag_serializer=DiscountTagSerializer(), tag_deserializer=DiscountTagDeserializer()
1119+
)
1120+
assert serializer.deserialize(circuit_proto) == expected_circuit
1121+
1122+
10291123
def test_reset_gate_with_improper_argument():
10301124
serializer = cg.CircuitSerializer()
10311125

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, List
16+
17+
import abc
18+
19+
from cirq_google.api import v2
20+
21+
22+
class TagDeserializer(abc.ABC):
23+
"""Generic supertype for tag deserializers.
24+
25+
Each tag deserializer describes how to deserialize a specific
26+
set of tag protos.
27+
"""
28+
29+
@abc.abstractmethod
30+
def can_deserialize_proto(self, proto: v2.program_pb2.Tag) -> bool:
31+
"""Whether the given tag can be serialized by this serializer."""
32+
33+
@abc.abstractmethod
34+
def from_proto(
35+
self,
36+
proto: v2.program_pb2.Tag,
37+
*,
38+
constants: List[v2.program_pb2.Constant],
39+
deserialized_constants: List[Any],
40+
) -> Any:
41+
"""Converts a proto-formatted operation into a Cirq operation.
42+
43+
Args:
44+
proto: The proto object to be deserialized.
45+
constants: The list of Constant protos referenced by constant
46+
table indices in `proto`.
47+
deserialized_constants: The deserialized contents of `constants`.
48+
49+
Returns:
50+
The deserialized operation represented by `proto`.
51+
"""

0 commit comments

Comments
 (0)