Skip to content

Fix decompose for controlled CZ gates with phase shift #7071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@

from cirq import protocols, value, _import
from cirq.ops import (
raw_types,
control_values as cv,
controlled_operation as cop,
op_tree,
diagonal_gate as dg,
global_phase_op as gp,
matrix_gates,
control_values as cv,
op_tree,
raw_types,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -157,13 +159,13 @@ def _decompose_(
def _decompose_with_context_(
self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
control_qubits = list(qubits[: self.num_controls()])
if (
protocols.has_unitary(self.sub_gate)
and protocols.num_qubits(self.sub_gate) == 1
and self._qid_shape_() == (2,) * len(self._qid_shape_())
and isinstance(self.control_values, cv.ProductOfSums)
):
control_qubits = list(qubits[: self.num_controls()])
invert_ops: List['cirq.Operation'] = []
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
if set(cvals) == {0}:
Expand All @@ -174,11 +176,21 @@ def _decompose_with_context_(
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
)
return invert_ops + decomposed_ops + invert_ops

if isinstance(self.sub_gate, gp.GlobalPhaseGate):
# A controlled global phase is a diagonal gate (Z in the simplest case), where each
# active control set equal to the phase angle.
shape = self.control_qid_shape
if protocols.is_parameterized(self.sub_gate) or set(shape) != {2}:
return NotImplemented
angle = np.angle(complex(self.sub_gate.coefficient))
if shape == (2,):
return common_gates.Z(*qubits) ** (angle / np.pi)
radians = np.zeros(shape=shape)
for hot in self.control_values.expand():
radians[hot] = angle
return dg.DiagonalGate(list(radians.flatten())).on(*qubits)
if isinstance(self.sub_gate, common_gates.CZPowGate):
z_sub_gate = common_gates.ZPowGate(
exponent=self.sub_gate.exponent, global_shift=self.sub_gate.global_shift
)
z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent)
num_controls = self.num_controls() + 1
control_values = self.control_values & cv.ProductOfSums(((1,),))
control_qid_shape = self.control_qid_shape + (2,)
Expand All @@ -197,9 +209,18 @@ def _decompose_with_context_(
)
)
if self != controlled_z:
return protocols.decompose_once_with_qubits(
controlled_z, qubits, NotImplemented, context=context
)
result = controlled_z.on(*qubits)
if self.sub_gate.global_shift == 0:
return result
# Reconstruct the controlled global shift of the subgate.
total_shift = self.sub_gate.exponent * self.sub_gate.global_shift
phase_gate = gp.GlobalPhaseGate(1j ** (2 * total_shift))
controlled_phase_op = phase_gate.controlled(
num_controls=self.num_controls(),
control_values=self.control_values,
control_qid_shape=self.control_qid_shape,
).on(*control_qubits)
return [result, controlled_phase_op]

if isinstance(self.sub_gate, matrix_gates.MatrixGate):
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
Expand Down Expand Up @@ -328,7 +349,7 @@ def __str__(self) -> str:
return str(self.control_values) + str(self.sub_gate)

def __repr__(self) -> str:
if self.num_controls() == 1 and self.control_values.is_trivial:
if self.control_qid_shape == [2] and self.control_values.is_trivial:
return f'cirq.ControlledGate(sub_gate={self.sub_gate!r})'

if self.control_values.is_trivial and set(self.control_qid_shape) == {2}:
Expand Down
53 changes: 51 additions & 2 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from types import NotImplementedType
from typing import Union, Tuple, cast
from typing import Any, cast, Optional, Sequence, Tuple, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -408,6 +408,11 @@ def test_unitary():
),
True,
),
(cirq.GlobalPhaseGate(-1), True),
(cirq.GlobalPhaseGate(1j**0.7), True),
(cirq.GlobalPhaseGate(sympy.Symbol("s")), False),
(cirq.CZPowGate(exponent=1.2, global_shift=0.3), True),
(cirq.CZPowGate(exponent=sympy.Symbol("s"), global_shift=0.3), False),
# Single qudit gate with dimension 4.
(cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)), False),
(cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)), False),
Expand All @@ -420,11 +425,55 @@ def test_unitary():
],
)
def test_controlled_gate_is_consistent(gate: cirq.Gate, should_decompose_to_target):
cgate = cirq.ControlledGate(gate)
_test_controlled_gate_is_consistent(gate, should_decompose_to_target)


@pytest.mark.parametrize(
'gate, control_qid_shape, control_values, should_decompose_to_target',
[
(cirq.GlobalPhaseGate(1j**0.7), [2, 2], xor_control_values, True),
(cirq.GlobalPhaseGate(1j**0.7), [3], None, False),
(cirq.GlobalPhaseGate(1j**0.7), [3, 4], xor_control_values, False),
(cirq.CZPowGate(exponent=1.2, global_shift=0.3), [2, 2], None, True),
(cirq.CZPowGate(exponent=1.2, global_shift=0.3), [2, 2], xor_control_values, False),
(cirq.CZPowGate(exponent=1.2, global_shift=0.3), [3], None, False),
(cirq.CZPowGate(exponent=1.2, global_shift=0.3), [3, 4], xor_control_values, False),
],
)
def test_nontrivial_controlled_gate_is_consistent(
gate: cirq.Gate,
control_qid_shape: Sequence[int],
control_values: Any,
should_decompose_to_target: bool,
):
_test_controlled_gate_is_consistent(
gate, should_decompose_to_target, control_qid_shape, control_values
)


def _test_controlled_gate_is_consistent(
gate: cirq.Gate,
should_decompose_to_target: bool,
control_qid_shape: Optional[Sequence[int]] = None,
control_values: Any = None,
):
cgate = cirq.ControlledGate(
gate, control_qid_shape=control_qid_shape, control_values=control_values
)
cirq.testing.assert_implements_consistent_protocols(cgate)
cirq.testing.assert_decompose_ends_at_default_gateset(
cgate, ignore_known_gates=not should_decompose_to_target
)
# The above only decompose once, which doesn't check that the sub-gate's phase is handled.
# We need to check full decomposition here.
if not cirq.is_parameterized(gate):
shape = cirq.qid_shape(cgate)
qids = cirq.LineQid.for_qid_shape(shape)
decomposed = cirq.decompose(cgate.on(*qids))
if len(decomposed) < 3000: # CCCCCZ rounding error explodes
first_op = cirq.IdentityGate(qid_shape=shape).on(*qids) # To ensure same qid order
circuit = cirq.Circuit(first_op, *decomposed)
np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-8)


def test_pow_inverse():
Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/ops/global_phase_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import cirq
from cirq import value, protocols
from cirq._compat import proper_repr
from cirq.ops import raw_types, controlled_gate, control_values as cv


Expand Down Expand Up @@ -68,10 +69,10 @@ def __str__(self) -> str:
return str(self.coefficient)

def __repr__(self) -> str:
return f'cirq.GlobalPhaseGate({self.coefficient!r})'
return f'cirq.GlobalPhaseGate({proper_repr(self.coefficient)})'

def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
return f'cirq.global_phase_operation({self.coefficient!r})'
return f'cirq.global_phase_operation({proper_repr(self.coefficient)})'

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['coefficient'])
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/global_phase_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_init():


def test_protocols():
for p in [1, 1j, -1]:
for p in [1, 1j, -1, sympy.Symbol('s')]:
cirq.testing.assert_implements_consistent_protocols(cirq.global_phase_operation(p))

np.testing.assert_allclose(
Expand Down