Skip to content

Allow ability to plug in custom (de)serializers for cirq_google protos #7059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 16, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 66 additions & 23 deletions cirq-google/cirq_google/serialization/circuit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,23 @@ class CircuitSerializer(serializer.Serializer):
serialization of duplicate operations as entries in the constant table.
This flag will soon become the default and disappear as soon as
deserialization of this field is deployed.
op_serializer: Optional custom serializer for serializing unknown gates.
op_deserializer: Optional custom deserializer for deserializing unknown gates.
"""

def __init__(
self, USE_CONSTANTS_TABLE_FOR_MOMENTS=False, USE_CONSTANTS_TABLE_FOR_OPERATIONS=False
self,
USE_CONSTANTS_TABLE_FOR_MOMENTS=False,
USE_CONSTANTS_TABLE_FOR_OPERATIONS=False,
op_serializer: Optional[op_serializer.OpSerializer] = None,
op_deserializer: Optional[op_deserializer.OpDeserializer] = None,
):
"""Construct the circuit serializer object."""
super().__init__(gate_set_name=_SERIALIZER_NAME)
self.use_constants_table_for_moments = USE_CONSTANTS_TABLE_FOR_MOMENTS
self.use_constants_table_for_operations = USE_CONSTANTS_TABLE_FOR_OPERATIONS
self.op_serializer = op_serializer
self.op_deserializer = op_deserializer

def serialize(
self,
Expand Down Expand Up @@ -144,26 +152,44 @@ def _serialize_circuit(
moment_proto.operation_indices.append(op_index)
else:
op_pb = v2.program_pb2.Operation()
if self.op_serializer and self.op_serializer.can_serialize_operation(op):
self.op_serializer.to_proto(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
else:
self._serialize_gate_op(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
constants.append(v2.program_pb2.Constant(operation_value=op_pb))
op_index = len(constants) - 1
raw_constants[op] = op_index
moment_proto.operation_indices.append(op_index)
else:
op_pb = moment_proto.operations.add()
if self.op_serializer and self.op_serializer.can_serialize_operation(op):
self.op_serializer.to_proto(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
else:
self._serialize_gate_op(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
constants.append(v2.program_pb2.Constant(operation_value=op_pb))
op_index = len(constants) - 1
raw_constants[op] = op_index
moment_proto.operation_indices.append(op_index)
else:
op_pb = moment_proto.operations.add()
self._serialize_gate_op(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)

if self.use_constants_table_for_moments:
# Add this moment to the constants table
Expand Down Expand Up @@ -469,14 +495,23 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit:
elif which_const == 'qubit':
deserialized_constants.append(v2.qubit_from_proto_id(constant.qubit.id))
elif which_const == 'operation_value':
deserialized_constants.append(
self._deserialize_gate_op(
if self.op_deserializer and self.op_deserializer.can_deserialize_proto(
constant.operation_value
):
op_pb = self.op_deserializer.from_proto(
constant.operation_value,
arg_function_language=arg_func_language,
constants=proto.constants,
deserialized_constants=deserialized_constants,
)
)
else:
op_pb = self._deserialize_gate_op(
constant.operation_value,
arg_function_language=arg_func_language,
constants=proto.constants,
deserialized_constants=deserialized_constants,
)
deserialized_constants.append(op_pb)
elif which_const == 'moment_value':
deserialized_constants.append(
self._deserialize_moment(
Expand Down Expand Up @@ -541,12 +576,20 @@ def _deserialize_moment(
) -> cirq.Moment:
moment_ops = []
for op in moment_proto.operations:
gate_op = self._deserialize_gate_op(
op,
arg_function_language=arg_function_language,
constants=constants,
deserialized_constants=deserialized_constants,
)
if self.op_deserializer and self.op_deserializer.can_deserialize_proto(op):
gate_op = self.op_deserializer.from_proto(
op,
arg_function_language=arg_function_language,
constants=constants,
deserialized_constants=deserialized_constants,
)
else:
gate_op = self._deserialize_gate_op(
op,
arg_function_language=arg_function_language,
constants=constants,
deserialized_constants=deserialized_constants,
)
if op.tag_indices:
tags = [
deserialized_constants[tag_index]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Any, Dict, List, Optional
import pytest

import numpy as np
Expand All @@ -25,6 +25,8 @@
import cirq_google as cg
from cirq_google.api import v2
from cirq_google.serialization.circuit_serializer import _SERIALIZER_NAME
from cirq_google.serialization.op_deserializer import OpDeserializer
from cirq_google.serialization.op_serializer import OpSerializer


class FakeDevice(cirq.Device):
Expand Down Expand Up @@ -856,6 +858,7 @@ def test_circuit_with_tag(tag):
assert nc[0].operations[0].tags == (tag,)


@pytest.mark.filterwarnings('ignore:Unknown tag msg=phase_match')
def test_unknown_tag_is_ignored():
class DingDongTag:
pass
Expand All @@ -866,6 +869,7 @@ class DingDongTag:
assert cirq.Circuit(cirq.X(cirq.q(0))) == nc


@pytest.mark.filterwarnings('ignore:Unrecognized Tag .*DingDongTag')
def test_unrecognized_tag_is_ignored():
op_tag = v2.program_pb2.Operation()
op_tag.xpowgate.exponent.float_value = 1.0
Expand Down Expand Up @@ -917,3 +921,90 @@ def test_circuit_with_units():
)
msg = cg.CIRCUIT_SERIALIZER.serialize(c)
assert c == cg.CIRCUIT_SERIALIZER.deserialize(msg)


class BingBongGate(cirq.Gate):

def __init__(self, param: float):
self.param = param

def _num_qubits_(self) -> int:
return 1


class BingBongSerializer(OpSerializer):
"""Describes how to serialize CircuitOperations."""

def can_serialize_operation(self, op):
return isinstance(op.gate, BingBongGate)

def to_proto(
self,
op: cirq.CircuitOperation,
msg: Optional[v2.program_pb2.CircuitOperation] = None,
*,
arg_function_language: Optional[str] = '',
constants: List[v2.program_pb2.Constant],
raw_constants: Dict[Any, int],
) -> v2.program_pb2.CircuitOperation:
assert isinstance(op.gate, BingBongGate)
if msg is None:
msg = v2.program_pb2.Operation() # pragma: nocover
msg.internalgate.name = 'bingbong'
msg.internalgate.module = 'test'
msg.internalgate.num_qubits = 1
msg.internalgate.gate_args['param'].arg_value.float_value = op.gate.param

for qubit in op.qubits:
if qubit not in raw_constants:
constants.append(
v2.program_pb2.Constant(
qubit=v2.program_pb2.Qubit(id=v2.qubit_to_proto_id(qubit))
)
)
raw_constants[qubit] = len(constants) - 1
msg.qubit_constant_index.append(raw_constants[qubit])
return msg


class BingBongDeserializer(OpDeserializer):
"""Describes how to serialize CircuitOperations."""

def can_deserialize_proto(self, proto):
return (
isinstance(proto, v2.program_pb2.Operation)
and proto.WhichOneof("gate_value") == "internalgate"
and proto.internalgate.name == 'bingbong'
and proto.internalgate.module == 'test'
)

def from_proto(
self,
proto: v2.program_pb2.Operation,
*,
arg_function_language: str = '',
constants: List[v2.program_pb2.Constant],
deserialized_constants: List[Any],
) -> cirq.Operation:
return BingBongGate(param=proto.internalgate.gate_args["param"].arg_value.float_value).on(
deserialized_constants[proto.qubit_constant_index[0]]
)


@pytest.mark.parametrize('use_constants_table', [True, False])
def test_custom_serializer(use_constants_table: bool):
c = cirq.Circuit(BingBongGate(param=2.5)(cirq.q(0, 0)))
serializer = cg.CircuitSerializer(
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
USE_CONSTANTS_TABLE_FOR_OPERATIONS=use_constants_table,
op_serializer=BingBongSerializer(),
op_deserializer=BingBongDeserializer(),
)
msg = serializer.serialize(c)
deserialized_circuit = serializer.deserialize(msg)
moment = deserialized_circuit[0]
assert len(moment) == 1
op = moment[cirq.q(0, 0)]
assert isinstance(op.gate, BingBongGate)
assert op.gate.param == 2.5
assert op.qubits == (cirq.q(0, 0),)
17 changes: 5 additions & 12 deletions cirq-google/cirq_google/serialization/op_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,12 @@ class OpDeserializer(abc.ABC):
"""Generic supertype for operation deserializers.

Each operation deserializer describes how to deserialize operation protos
with a particular `serialized_id` to a specific type of Cirq operation.
to a specific type of Cirq operation.
"""

@property
@abc.abstractmethod
def serialized_id(self) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update class docstring which refers to serializer_id.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Also warnings suppressed.

"""Returns the string identifier for the accepted serialized objects.

This ID denotes the serialization format this deserializer consumes. For
example, one of the common deserializers converts objects with the id
'xy' into PhasedXPowGates.
"""
def can_deserialize_proto(self, proto) -> bool:
"""Whether the given operation can be serialized by this serializer."""

@abc.abstractmethod
def from_proto(
Expand Down Expand Up @@ -66,9 +60,8 @@ def from_proto(
class CircuitOpDeserializer(OpDeserializer):
"""Describes how to serialize CircuitOperations."""

@property
def serialized_id(self):
return 'circuit'
def can_deserialize_proto(self, proto):
return isinstance(proto, v2.program_pb2.CircuitOperation) # pragma: nocover

def from_proto(
self,
Expand Down
47 changes: 4 additions & 43 deletions cirq-google/cirq_google/serialization/op_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Type, TypeVar
from typing import Any, Dict, List, Optional, Union
import numbers

import abc
Expand All @@ -23,9 +23,6 @@
from cirq_google.api import v2
from cirq_google.serialization.arg_func_langs import arg_to_proto

# Type for variables that are subclasses of ops.Gate.
Gate = TypeVar('Gate', bound=cirq.Gate)


class OpSerializer(abc.ABC):
"""Generic supertype for operation serializers.
Expand All @@ -35,25 +32,6 @@ class OpSerializer(abc.ABC):
may serialize to the same format.
"""

@property
@abc.abstractmethod
def internal_type(self) -> Type:
"""Returns the type that the operation contains.

For GateOperations, this is the gate type.
For CircuitOperations, this is FrozenCircuit.
"""

@property
@abc.abstractmethod
def serialized_id(self) -> str:
"""Returns the string identifier for the resulting serialized object.

This ID denotes the serialization format this serializer produces. For
example, one of the common serializers assigns the id 'xy' to XPowGates,
as they serialize into a format also used by YPowGates.
"""

@abc.abstractmethod
def to_proto(
self,
Expand All @@ -63,7 +41,7 @@ def to_proto(
arg_function_language: Optional[str] = '',
constants: List[v2.program_pb2.Constant],
raw_constants: Dict[Any, int],
) -> Optional[v2.program_pb2.CircuitOperation]:
) -> Optional[Union[v2.program_pb2.CircuitOperation, v2.program_pb2.Operation]]:
"""Converts op to proto using this serializer.

If self.can_serialize_operation(op) == false, this should return None.
Expand All @@ -83,33 +61,16 @@ def to_proto(
the returned object.
"""

@property
@abc.abstractmethod
def can_serialize_predicate(self) -> Callable[[cirq.Operation], bool]:
"""The method used to determine if this can serialize an operation.

Depending on the serializer, additional checks may be required.
"""

def can_serialize_operation(self, op: cirq.Operation) -> bool:
"""Whether the given operation can be serialized by this serializer."""
return self.can_serialize_predicate(op)


class CircuitOpSerializer(OpSerializer):
"""Describes how to serialize CircuitOperations."""

@property
def internal_type(self):
return cirq.FrozenCircuit

@property
def serialized_id(self):
return 'circuit'

@property
def can_serialize_predicate(self):
return lambda op: isinstance(op.untagged, cirq.CircuitOperation)
def can_serialize_operation(self, op: cirq.Operation):
return isinstance(op.untagged, cirq.CircuitOperation)

def to_proto(
self,
Expand Down
6 changes: 0 additions & 6 deletions cirq-google/cirq_google/serialization/op_serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@ def default_circuit():
)


def test_circuit_op_serializer_properties():
serializer = cg.CircuitOpSerializer()
assert serializer.internal_type == cirq.FrozenCircuit
assert serializer.serialized_id == 'circuit'


def test_can_serialize_circuit_op():
serializer = cg.CircuitOpSerializer()
assert serializer.can_serialize_operation(cirq.CircuitOperation(default_circuit()))
Expand Down
Loading