32
32
)
33
33
from cirq_google .ops .calibration_tag import CalibrationTag
34
34
from cirq_google .experimental .ops import CouplerPulse
35
- from cirq_google .serialization import serializer , op_deserializer , op_serializer , arg_func_langs
35
+ from cirq_google .serialization import (
36
+ serializer ,
37
+ op_deserializer ,
38
+ op_serializer ,
39
+ arg_func_langs ,
40
+ tag_serializer ,
41
+ tag_deserializer ,
42
+ )
36
43
37
44
# The name used in program.proto to identify the serializer as CircuitSerializer.
38
45
# "v2.5" refers to the most current v2.Program proto format.
@@ -64,6 +71,8 @@ class CircuitSerializer(serializer.Serializer):
64
71
deserialization of this field is deployed.
65
72
op_serializer: Optional custom serializer for serializing unknown gates.
66
73
op_deserializer: Optional custom deserializer for deserializing unknown gates.
74
+ tag_serializer: Optional custom serializer for serializing unknown tags.
75
+ tag_deserializer: Optional custom deserializer for deserializing unknown tags.
67
76
"""
68
77
69
78
def __init__ (
@@ -72,13 +81,17 @@ def __init__(
72
81
USE_CONSTANTS_TABLE_FOR_OPERATIONS = False ,
73
82
op_serializer : Optional [op_serializer .OpSerializer ] = None ,
74
83
op_deserializer : Optional [op_deserializer .OpDeserializer ] = None ,
84
+ tag_serializer : Optional [tag_serializer .TagSerializer ] = None ,
85
+ tag_deserializer : Optional [tag_deserializer .TagDeserializer ] = None ,
75
86
):
76
87
"""Construct the circuit serializer object."""
77
88
super ().__init__ (gate_set_name = _SERIALIZER_NAME )
78
89
self .use_constants_table_for_moments = USE_CONSTANTS_TABLE_FOR_MOMENTS
79
90
self .use_constants_table_for_operations = USE_CONSTANTS_TABLE_FOR_OPERATIONS
80
91
self .op_serializer = op_serializer
81
92
self .op_deserializer = op_deserializer
93
+ self .tag_serializer = tag_serializer
94
+ self .tag_deserializer = tag_deserializer
82
95
83
96
def serialize (
84
97
self , program : cirq .AbstractCircuit , msg : Optional [v2 .program_pb2 .Program ] = None
@@ -301,8 +314,8 @@ def _serialize_gate_op(
301
314
msg .qubit_constant_index .append (raw_constants [qubit ])
302
315
303
316
for tag in op .tags :
317
+ constant = v2 .program_pb2 .Constant ()
304
318
if isinstance (tag , CalibrationTag ):
305
- constant = v2 .program_pb2 .Constant ()
306
319
constant .string_value = tag .token
307
320
if tag .token in raw_constants :
308
321
msg .token_constant_index = raw_constants [tag .token ]
@@ -317,16 +330,22 @@ def _serialize_gate_op(
317
330
# TODO(dstrain): Remove this once we are deserializing tag indices everywhere.
318
331
tag .to_proto (msg = msg .tags .add ())
319
332
if (tag_index := raw_constants .get (tag , None )) is None :
320
- constant = v2 .program_pb2 .Constant ()
321
- tag_index = len (constants )
322
- if getattr (tag , 'to_proto' , None ) is not None :
333
+ if self .tag_serializer and self .tag_serializer .can_serialize_tag (tag ):
334
+ self .tag_serializer .to_proto (
335
+ tag ,
336
+ msg = constant .tag_value ,
337
+ constants = constants ,
338
+ raw_constants = raw_constants ,
339
+ )
340
+ elif getattr (tag , 'to_proto' , None ) is not None :
323
341
tag .to_proto (constant .tag_value ) # type: ignore
324
- constants .append (constant )
325
- if raw_constants is not None :
326
- raw_constants [tag ] = tag_index
327
- msg .tag_indices .append (tag_index )
328
342
else :
329
343
warnings .warn (f'Unrecognized Tag { tag } , not serializing.' )
344
+ if constant .WhichOneof ('const_value' ):
345
+ constants .append (constant )
346
+ if raw_constants is not None :
347
+ raw_constants [tag ] = len (constants ) - 1
348
+ msg .tag_indices .append (len (constants ) - 1 )
330
349
else :
331
350
msg .tag_indices .append (tag_index )
332
351
return msg
@@ -434,7 +453,18 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit:
434
453
)
435
454
)
436
455
elif which_const == 'tag_value' :
437
- deserialized_constants .append (self ._deserialize_tag (constant .tag_value ))
456
+ if self .tag_deserializer and self .tag_deserializer .can_deserialize_proto (
457
+ constant .tag_value
458
+ ):
459
+ deserialized_constants .append (
460
+ self .tag_deserializer .from_proto (
461
+ constant .tag_value ,
462
+ constants = proto .constants ,
463
+ deserialized_constants = deserialized_constants ,
464
+ )
465
+ )
466
+ else :
467
+ deserialized_constants .append (self ._deserialize_tag (constant .tag_value ))
438
468
else :
439
469
msg = f'Unrecognized constant type { which_const } , ignoring.' # pragma: no cover
440
470
warnings .warn (msg ) # pragma: no cover
@@ -490,22 +520,7 @@ def _deserialize_moment(
490
520
gate_op = self ._deserialize_gate_op (
491
521
op , constants = constants , deserialized_constants = deserialized_constants
492
522
)
493
- if op .tag_indices :
494
- tags = [
495
- deserialized_constants [tag_index ]
496
- for tag_index in op .tag_indices
497
- if deserialized_constants [tag_index ] not in gate_op .tags
498
- and deserialized_constants [tag_index ] is not None
499
- ]
500
- else :
501
- tags = []
502
- for tag in op .tags :
503
- if (
504
- tag not in gate_op .tags
505
- and (new_tag := self ._deserialize_tag (tag )) is not None
506
- ):
507
- tags .append (new_tag )
508
- moment_ops .append (gate_op .with_tags (* tags ))
523
+ moment_ops .append (gate_op )
509
524
for op in moment_proto .circuit_operations :
510
525
moment_ops .append (
511
526
self ._deserialize_circuit_op (
@@ -768,7 +783,30 @@ def _deserialize_gate_op(
768
783
elif which == 'token_value' :
769
784
op = op .with_tags (CalibrationTag (operation_proto .token_value ))
770
785
771
- return op
786
+ # Add tags to op
787
+ if operation_proto .tag_indices and deserialized_constants is not None :
788
+ tags = [
789
+ deserialized_constants [tag_index ]
790
+ for tag_index in operation_proto .tag_indices
791
+ if deserialized_constants [tag_index ] not in op .tags
792
+ and deserialized_constants [tag_index ] is not None
793
+ ]
794
+ else :
795
+ tags = []
796
+ for tag in operation_proto .tags :
797
+ if tag not in op .tags :
798
+ if self .tag_deserializer and self .tag_deserializer .can_deserialize_proto (tag ):
799
+ tags .append (
800
+ self .tag_deserializer .from_proto (
801
+ tag ,
802
+ constants = constants or [],
803
+ deserialized_constants = deserialized_constants or [],
804
+ )
805
+ )
806
+ elif (new_tag := self ._deserialize_tag (tag )) is not None :
807
+ tags .append (new_tag )
808
+
809
+ return op .with_tags (* tags )
772
810
773
811
def _deserialize_circuit_op (
774
812
self ,
0 commit comments