Skip to content

Improve stimcirq serialization #7192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 23 additions & 41 deletions cirq-google/cirq_google/serialization/circuit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
op_deserializer,
op_serializer,
serializer,
stimcirq_deserializer,
stimcirq_serializer,
tag_deserializer,
tag_serializer,
)
Expand All @@ -47,9 +49,6 @@
# CircuitSerializer is the dedicated serializer for the v2.5 format.
_SERIALIZER_NAME = 'v2_5'

# Package name for stimcirq
_STIMCIRQ_MODULE = "stimcirq"


class CircuitSerializer(serializer.Serializer):
"""A class for serializing and deserializing programs and operations.
Expand Down Expand Up @@ -93,6 +92,8 @@ def __init__(
self.op_deserializer = op_deserializer
self.tag_serializer = tag_serializer
self.tag_deserializer = tag_deserializer
self.stimcirq_serializer = stimcirq_serializer.StimCirqSerializer()
self.stimcirq_deserializer = stimcirq_deserializer.StimCirqDeserializer()

def serialize(
self, program: cirq.AbstractCircuit, msg: Optional[v2.program_pb2.Program] = None
Expand Down Expand Up @@ -160,6 +161,10 @@ def _serialize_circuit(
self.op_serializer.to_proto(
op, op_pb, constants=constants, raw_constants=raw_constants
)
elif self.stimcirq_serializer.can_serialize_operation(op):
self.stimcirq_serializer.to_proto(
op, op_pb, constants=constants, raw_constants=raw_constants
)
else:
self._serialize_gate_op(
op, op_pb, constants=constants, raw_constants=raw_constants
Expand All @@ -174,6 +179,10 @@ def _serialize_circuit(
self.op_serializer.to_proto(
op, op_pb, constants=constants, raw_constants=raw_constants
)
elif self.stimcirq_serializer.can_serialize_operation(op):
self.stimcirq_serializer.to_proto(
op, op_pb, constants=constants, raw_constants=raw_constants
)
else:
self._serialize_gate_op(
op, op_pb, constants=constants, raw_constants=raw_constants
Expand Down Expand Up @@ -277,30 +286,6 @@ def _serialize_gate_op(
arg_func_langs.float_arg_to_proto(
gate.q1_detune_mhz, out=msg.couplerpulsegate.q1_detune_mhz
)
elif getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) or getattr(
gate, "__module__", ""
).startswith(_STIMCIRQ_MODULE):
# Special handling for stimcirq objects, which can be both operations and gates.
stimcirq_obj = (
op if getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) else gate
)
if stimcirq_obj is not None and hasattr(stimcirq_obj, '_json_dict_'):
# All stimcirq gates currently have _json_dict_defined
msg.internalgate.name = type(stimcirq_obj).__name__
msg.internalgate.module = _STIMCIRQ_MODULE
if isinstance(stimcirq_obj, cirq.Gate):
msg.internalgate.num_qubits = stimcirq_obj.num_qubits()
else:
msg.internalgate.num_qubits = len(stimcirq_obj.qubits)

# Store json_dict objects in gate_args
for k, v in stimcirq_obj._json_dict_().items():
arg_func_langs.arg_to_proto(value=v, out=msg.internalgate.gate_args[k])
else:
# New stimcirq op without a json dict has been introduced
raise ValueError(
f'Cannot serialize stimcirq {op!r}:{type(gate)}'
) # pragma: no cover
else:
raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}')

Expand Down Expand Up @@ -438,6 +423,12 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit:
constants=proto.constants,
deserialized_constants=deserialized_constants,
)
elif self.stimcirq_deserializer.can_deserialize_proto(constant.operation_value):
op_pb = self.stimcirq_deserializer.from_proto(
constant.operation_value,
constants=proto.constants,
deserialized_constants=deserialized_constants,
)
else:
op_pb = self._deserialize_gate_op(
constant.operation_value,
Expand Down Expand Up @@ -517,6 +508,10 @@ def _deserialize_moment(
gate_op = self.op_deserializer.from_proto(
op, constants=constants, deserialized_constants=deserialized_constants
)
elif self.stimcirq_deserializer.can_deserialize_proto(op):
gate_op = self.stimcirq_deserializer.from_proto(
op, constants=constants, deserialized_constants=deserialized_constants
)
else:
gate_op = self._deserialize_gate_op(
op, constants=constants, deserialized_constants=deserialized_constants
Expand Down Expand Up @@ -718,20 +713,7 @@ def _deserialize_gate_op(
op = cirq.ResetChannel(dimension=dimensions)(*qubits)
elif which_gate_type == 'internalgate':
msg = operation_proto.internalgate
if msg.module == _STIMCIRQ_MODULE and msg.name in _stimcirq_json_resolvers():
# special handling for stimcirq
# Use JSON resolver to instantiate the object
kwargs = {}
for k, v in msg.gate_args.items():
arg = arg_func_langs.arg_from_proto(v)
if arg is not None:
kwargs[k] = arg
op = _stimcirq_json_resolvers()[msg.name](**kwargs)
if qubits:
op = op(*qubits)
else:
# all other internal gates
op = arg_func_langs.internal_gate_from_proto(msg)(*qubits)
op = arg_func_langs.internal_gate_from_proto(msg)(*qubits)
elif which_gate_type == 'couplerpulsegate':
gate = CouplerPulse(
hold_time=cirq.Duration(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1162,12 +1162,41 @@ def test_reset_gate_with_no_dimension():
assert reset_circuit == cirq.Circuit(cirq.R(cirq.q(1, 2)))


def test_stimcirq_gates():
@pytest.mark.parametrize('use_constants_table', [True, False])
def test_stimcirq_gates(use_constants_table: bool):
stimcirq = pytest.importorskip("stimcirq")
serializer = cg.CircuitSerializer()
serializer = cg.CircuitSerializer(
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
USE_CONSTANTS_TABLE_FOR_OPERATIONS=use_constants_table,
)
q = cirq.q(1, 2)
q2 = cirq.q(2, 2)
c = cirq.Circuit(
cirq.Moment(
stimcirq.CumulativeObservableAnnotation(parity_keys=["m"], observable_index=123)
),
cirq.Moment(
stimcirq.MeasureAndOrResetGate(
measure=True,
reset=False,
basis='Z',
invert_measure=True,
key='mmm',
measure_flip_probability=0.125,
)(q2)
),
cirq.Moment(stimcirq.ShiftCoordsAnnotation([1.0, 2.0])),
cirq.Moment(
stimcirq.SweepPauli(stim_sweep_bit_index=2, cirq_sweep_symbol='t', pauli=cirq.X)(q)
),
cirq.Moment(
stimcirq.SweepPauli(stim_sweep_bit_index=3, cirq_sweep_symbol='y', pauli=cirq.Y)(q)
),
cirq.Moment(
stimcirq.SweepPauli(stim_sweep_bit_index=4, cirq_sweep_symbol='t', pauli=cirq.Z)(q)
),
cirq.Moment(stimcirq.TwoQubitAsymmetricDepolarizingChannel([0.05] * 15)(q, q2)),
cirq.Moment(stimcirq.CZSwapGate()(q, q2)),
cirq.Moment(stimcirq.CXSwapGate(inverted=True)(q, q2)),
cirq.Moment(cirq.measure(q, key="m")),
cirq.Moment(stimcirq.DetAnnotation(parity_keys=["m"])),
Expand Down
101 changes: 101 additions & 0 deletions cirq-google/cirq_google/serialization/stimcirq_deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2025 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, List

import cirq
from cirq_google.api import v2
from cirq_google.serialization import arg_func_langs
from cirq_google.serialization.op_deserializer import OpDeserializer


@functools.cache
def _stimcirq_json_resolvers():
"""Retrieves stimcirq JSON resolvers if stimcirq is installed.
Returns an empty dict if not installed."""
try:
import stimcirq

return stimcirq.JSON_RESOLVERS_DICT
except ModuleNotFoundError: # pragma: no cover
return {} # pragma: no cover


class StimCirqDeserializer(OpDeserializer):
"""Describes how to serialize CircuitOperations."""

def can_deserialize_proto(self, proto: v2.program_pb2.Operation):
return (
proto.WhichOneof('gate_value') == 'internalgate'
and proto.internalgate.module == 'stimcirq'
)

def from_proto(
self,
proto: v2.program_pb2.Operation,
*,
constants: List[v2.program_pb2.Constant],
deserialized_constants: List[Any],
) -> cirq.Operation:
"""Turns a cirq_google Operation proto into a stimcirq object.

Args:
proto: The proto object to be deserialized.
constants: The list of Constant protos referenced by constant
table indices in `proto`. This list should already have been
parsed to produce 'deserialized_constants'.
deserialized_constants: The deserialized contents of `constants`.

Returns:
The deserialized stimcirq object

Raises:
ValueError: If stimcirq is not installed or the object is not recognized.
"""
resolvers = _stimcirq_json_resolvers()
cls_name = proto.internalgate.name

if cls_name not in resolvers:
raise ValueError(f"stimcirq object {proto} not recognized. (Is stimcirq installed?)")

# Resolve each of the serialized arguments
kwargs: Dict[str, Any] = {}
for k, v in proto.internalgate.gate_args.items():
if k == "pauli":
# Special Handling for pauli gate
pauli = v.arg_value.string_value
if pauli == "X":
kwargs[k] = cirq.X
elif pauli == "Y":
kwargs[k] = cirq.Y
elif pauli == "Z":
kwargs[k] = cirq.Z
else:
raise ValueError(f"Unknown stimcirq pauli Gate {v}")
continue

arg = arg_func_langs.arg_from_proto(v)
if arg is not None:
kwargs[k] = arg

# Instantiate the class from the stimcirq resolvers
op = resolvers[cls_name](**kwargs)

# If this operation has qubits, add them
qubits = [deserialized_constants[q] for q in proto.qubit_constant_index]
if qubits:
op = op(*qubits)

return op
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2025 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from cirq_google.api import v2
from cirq_google.serialization.stimcirq_deserializer import StimCirqDeserializer


def test_bad_stimcirq_op():
proto = v2.program_pb2.Operation()
proto.internalgate.module = 'stimcirq'
proto.internalgate.name = 'WolfgangPauli'

with pytest.raises(ValueError, match='not recognized'):
_ = StimCirqDeserializer().from_proto(proto, constants=[], deserialized_constants=[])


def test_bad_pauli_gate():
proto = v2.program_pb2.Operation()
proto.internalgate.module = 'stimcirq'
proto.internalgate.name = 'SweepPauli'
proto.internalgate.gate_args['stim_sweep_bit_index'].arg_value.float_value = 1.0
proto.internalgate.gate_args['cirq_sweep_symbol'].arg_value.string_value = 't'
proto.internalgate.gate_args['pauli'].arg_value.string_value = 'Q'

with pytest.raises(ValueError, match='pauli'):
_ = StimCirqDeserializer().from_proto(proto, constants=[], deserialized_constants=[])
Loading