Skip to content

Fix decompose for controlled CZ gates with phase shift #7071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 26, 2025

Conversation

daxfohl
Copy link
Collaborator

@daxfohl daxfohl commented Feb 17, 2025

Fixes #6488.
Fixes #6517.

Fixes the way global phase is handled in decomposition of controlled CZ gates (which are an intermediate step for many decompositions). This PR removes the phase from the Z sub-gate where the existing code had erroneously placed it. Instead, it creates a new, independent controlled-global-phase op as part of the decomposition, with the correct control values, and adds that into the decomposition to compensate.

The PR also adds a handler to allow the controlled global phase op to be further decomposed into Rz's and Y's, via an equivalent diagonal gate, as desired. (decompose will do this automatically by default, as is checked by the should_decompose_to_target in tests).

Finally, it hardens the consistency check in the controlled-gate consistency tests, to ensure that the full decomposition of the gates has a consistent unitary. The existing consistency checks in cirq.testing only check single decomposition steps, which is why these bugs were missed. Several of the existing checks would have failed if they had done full decomposition.

Tests have been added to check this against different dimensions, different control values including non-product like xor, and with symbolic parameters. Minor bugs were found as part of these tests, in the reprs of global phase gates with symbolic parameters and controlled gates with single-qudit dimensions, and a missing qasm check for qid dimension, and these have been fixed here too.

@CirqBot CirqBot added the size: M 50< lines changed <250 label Feb 17, 2025
@daxfohl daxfohl marked this pull request as ready for review February 17, 2025 08:06
@daxfohl daxfohl requested review from vtomole and a team as code owners February 17, 2025 08:06
Copy link

codecov bot commented Feb 17, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 98.18%. Comparing base (ca6ceb3) to head (47d539c).
Report is 22 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #7071   +/-   ##
=======================================
  Coverage   98.18%   98.18%           
=======================================
  Files        1089     1089           
  Lines       95237    95267   +30     
=======================================
+ Hits        93508    93538   +30     
  Misses       1729     1729           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@daxfohl daxfohl marked this pull request as draft February 18, 2025 15:54
@daxfohl daxfohl marked this pull request as ready for review February 18, 2025 17:13
Copy link
Collaborator

@dstrain115 dstrain115 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Very nice to see this fixed!

@daxfohl daxfohl added this pull request to the merge queue Feb 26, 2025
Merged via the queue into quantumlib:main with commit 61548e0 Feb 26, 2025
38 checks passed
@daxfohl daxfohl deleted the fix-controlled-phase-decompose branch February 26, 2025 22:29
@pavoljuhas
Copy link
Collaborator

@daxfohl - starting with this PR I am getting the following pytest failure on my Debian-linux desktop -

check/pytest -n0 -vv \
  "cirq-core/cirq/ops/controlled_gate_test.py::test_controlled_gate_is_consistent[gate13-True]"
Output
...
================================================== FAILURES ==================================================
______________________________ test_controlled_gate_is_consistent[gate13-True] _______________________________

gate = cirq.ControlledGate(sub_gate=cirq.XXPowGate(exponent=0.25, global_shift=-0.5), control_values=cirq.ProductOfSums(((1,), (0, 1))),control_qid_shape=(2, 2))
should_decompose_to_target = True

    @pytest.mark.parametrize(
        'gate, should_decompose_to_target',
        [
            (cirq.X, True),
            (cirq.X**0.5, True),
            (cirq.rx(np.pi), True),
            (cirq.rx(np.pi / 2), True),
            (cirq.Z, True),
            (cirq.H, True),
            (cirq.CNOT, True),
            (cirq.SWAP, True),
            (cirq.CCZ, True),
            (cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ)), True),
            (GateUsingWorkspaceForApplyUnitary(), True),
            (GateAllocatingNewSpaceForResult(), True),
            (cirq.IdentityGate(qid_shape=(3, 4)), True),
            (
                cirq.ControlledGate(
                    cirq.XXPowGate(exponent=0.25, global_shift=-0.5),
                    num_controls=2,
                    control_values=(1, (1, 0)),
                ),
                True,
            ),
            (cirq.GlobalPhaseGate(-1), True),
            (cirq.GlobalPhaseGate(1j**0.7), True),
            (cirq.GlobalPhaseGate(sympy.Symbol("s")), False),
            (cirq.CZPowGate(exponent=1.2, global_shift=0.3), True),
            (cirq.CZPowGate(exponent=sympy.Symbol("s"), global_shift=0.3), False),
            # Single qudit gate with dimension 4.
            (cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)), False),
            (cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)), False),
            (cirq.XX ** sympy.Symbol("s"), True),
            (cirq.CZ ** sympy.Symbol("s"), True),
            # Non-trivial `cirq.ProductOfSum` controls.
            (C_01_10_11H, False),
            (C_xorH, False),
            (C0C_xorH, False),
        ],
    )
    def test_controlled_gate_is_consistent(gate: cirq.Gate, should_decompose_to_target):
>       _test_controlled_gate_is_consistent(gate, should_decompose_to_target)

cirq-core/cirq/ops/controlled_gate_test.py:428:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cirq-core/cirq/ops/controlled_gate_test.py:482: in _test_controlled_gate_is_consistent
    np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-1)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<function assert_allclose.<locals>.compare at 0x7fc878b1c670>, array([[1.        +0.j        , 0.        +0.j        ...,
         1.80411242e-16+2.22044605e-16j,  3.88578059e-16+1.66533454e-16j,
        -6.53281482e-01+6.53281482e-01j]]))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0.1', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError:
E           Not equal to tolerance rtol=1e-07, atol=0.1
E
E           Mismatched elements: 40 / 1024 (3.91%)
E           Max absolute difference: 1.84775907
E           Max relative difference: 1.84775907
E            x: array([[1.     +0.j      , 0.     +0.j      , 0.     +0.j      , ...,
E                   0.     +0.j      , 0.     +0.j      , 0.     +0.j      ],
E                  [0.     +0.j      , 1.     +0.j      , 0.     +0.j      , ...,...
E            y: array([[-7.071068e-01+7.071068e-01j, -1.665335e-16-5.273559e-16j,
E                   -5.702532e-16+1.618554e-16j, ...,  0.000000e+00+0.000000e+00j,
E                    0.000000e+00+0.000000e+00j,  0.000000e+00+0.000000e+00j],...

../../../../arch/pyenv-root/versions/3.10.14/lib/python3.10/contextlib.py:79: AssertionError
========================================== short test summary info ===========================================
FAILED cirq-core/cirq/ops/controlled_gate_test.py::test_controlled_gate_is_consistent[gate13-True] - AssertionError:
Not equal to tolerance rtol=1e-07, atol=0.1

Mismatched elements: 40 / 1024 (3.91%)
Max absolute difference: 1.84775907
Max relative difference: 1.84775907
 x: array([[1.     +0.j      , 0.     +0.j      , 0.     +0.j      , ...,
        0.     +0.j      , 0.     +0.j      , 0.     +0.j      ],
       [0.     +0.j      , 1.     +0.j      , 0.     +0.j      , ...,...
 y: array([[-7.071068e-01+7.071068e-01j, -1.665335e-16-5.273559e-16j,
        -5.702532e-16+1.618554e-16j, ...,  0.000000e+00+0.000000e+00j,
         0.000000e+00+0.000000e+00j,  0.000000e+00+0.000000e+00j],...
============================================= 1 failed in 1.39s ==============================================

The failing test test_controlled_gate_is_consistent[gate13-True] is for the controlled XXPowGate here -

(
cirq.ControlledGate(
cirq.XXPowGate(exponent=0.25, global_shift=-0.5),
num_controls=2,
control_values=(1, (1, 0)),
),
True,
),
and it fails comparison of unitaries here. The differing elements in the unitaries have the same absolute values,
but their phase is consistently different by 0.75pi - (phases in the decomposed-circuit unitary are 0.75pi or 0.25pi, phases in the gate unitary are 0 or -0.5pi respectively).

Any idea where this might be coming from?

The strange thing is the test passes on a cloud-hosted machine with the same Python version and same libraries in virtual environment.

cc @kmlau

@daxfohl
Copy link
Collaborator Author

daxfohl commented Feb 28, 2025

@pavoljuhas IDK offhand. Does it decompose to the same thing on both machines? On my machine (ubuntu WSL) it decomposes to

(cirq.Z**-0.12500000000000003).on(cirq.LineQid(1, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.12500000000000003).on(cirq.LineQid(1, dimension=2))
cirq.Z(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(3, dimension=2))
cirq.Z(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**-0.12500000000000003).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(3, dimension=2))
(cirq.Z**-0.12500000000000003).on(cirq.LineQid(1, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.12500000000000003).on(cirq.LineQid(1, dimension=2))
cirq.Z(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(4, dimension=2))
cirq.Z(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**-0.12500000000000003).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(4, dimension=2))
(cirq.Z**0.06250000000000001).on(cirq.LineQid(1, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Z**0.125).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**-0.06250000000000001).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.06250000000000006).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Z**0.06250000000000006).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Z**1.8749999999999998).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.06250000000000001).on(cirq.LineQid(0, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Z**0.125).on(cirq.LineQid(3, dimension=2))
(cirq.Z**0.06250000000000001).on(cirq.LineQid(1, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**0.125).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**-0.06250000000000001).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.06250000000000006).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**0.06250000000000006).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**1.8749999999999998).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.06250000000000001).on(cirq.LineQid(0, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**0.125).on(cirq.LineQid(4, dimension=2))
(cirq.Z**-0.125).on(cirq.LineQid(3, dimension=2))
(cirq.Z**0.125).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(3, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**0.125).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(3, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**1.75).on(cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
cirq.XPowGate(global_shift=-0.25).on(cirq.LineQid(3, dimension=2))
cirq.T(cirq.LineQid(1, dimension=2))
cirq.T(cirq.LineQid(0, dimension=2))
cirq.T(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.T**-1).on(cirq.LineQid(0, dimension=2))
cirq.T(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.T**-1).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.T**-1).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
cirq.XPowGate(global_shift=-0.25).on(cirq.LineQid(3, dimension=2))
(cirq.Z**0.125).on(cirq.LineQid(3, dimension=2))
(cirq.Z**1.8750000000000002).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(3, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**1.8750000000000002).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(3, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
cirq.T(cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
cirq.XPowGate(global_shift=-0.25).on(cirq.LineQid(3, dimension=2))
cirq.T(cirq.LineQid(1, dimension=2))
cirq.T(cirq.LineQid(0, dimension=2))
cirq.T(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.T**-1).on(cirq.LineQid(0, dimension=2))
cirq.T(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.T**-1).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.T**-1).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
cirq.XPowGate(global_shift=-0.25).on(cirq.LineQid(3, dimension=2))
(cirq.Z**-0.06250000000000001).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.06250000000000006).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**0.06250000000000006).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**1.8749999999999998).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.06250000000000001).on(cirq.LineQid(1, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**1.9375000000000002).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**0.125).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**-0.06250000000000001).on(cirq.LineQid(0, dimension=2))
(cirq.Z**0.06250000000000006).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**0.06250000000000006).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Z**1.8749999999999998).on(cirq.LineQid(4, dimension=2))
cirq.global_phase_operation((0.9951847266721969-0.09801714032956063j))
cirq.Rz(rads=-0.0).on(cirq.LineQid(2, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(2, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
cirq.Rz(rads=-0.0).on(cirq.LineQid(1, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(2, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
cirq.Rz(rads=-0.1963495408493621).on(cirq.LineQid(1, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
cirq.Rz(rads=0.1963495408493621).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(2, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
cirq.Rz(rads=-0.0).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
cirq.Rz(rads=-0.0).on(cirq.LineQid(0, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(0, dimension=2))
cirq.CZ(cirq.LineQid(2, dimension=2), cirq.LineQid(0, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(0, dimension=2))
cirq.Rz(rads=-0.1963495408493621).on(cirq.LineQid(0, dimension=2))
(cirq.Z**0.12500000000000003).on(cirq.LineQid(1, dimension=2))
cirq.Z(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(3, dimension=2))
cirq.Z(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**-0.12500000000000003).on(cirq.LineQid(1, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.12500000000000003).on(cirq.LineQid(0, dimension=2))
cirq.Z(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(3, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(3, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(3, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(3, dimension=2))
cirq.Z(cirq.LineQid(3, dimension=2))
(cirq.Z**0.12500000000000003).on(cirq.LineQid(1, dimension=2))
cirq.Z(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(4, dimension=2))
cirq.Z(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**-0.12500000000000003).on(cirq.LineQid(1, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(1, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(1, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(1, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(1, dimension=2))
(cirq.Z**0.12500000000000003).on(cirq.LineQid(0, dimension=2))
cirq.Z(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**0.12500000000000003).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.5).on(cirq.LineQid(4, dimension=2))
cirq.CZ(cirq.LineQid(0, dimension=2), cirq.LineQid(4, dimension=2))
(cirq.Y**0.5).on(cirq.LineQid(4, dimension=2))
(cirq.Y**-0.12500000000000003).on(cirq.LineQid(4, dimension=2))
cirq.Z(cirq.LineQid(4, dimension=2))

@pavoljuhas
Copy link
Collaborator

pavoljuhas commented Feb 28, 2025

@daxfohl - I used this patch to save the decomposed circuit in test 13 just before the failing assertion

diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py
index 1a9f44e2..018cf04b 100644
--- a/cirq-core/cirq/ops/controlled_gate_test.py
+++ b/cirq-core/cirq/ops/controlled_gate_test.py
@@ -424,3 +424,3 @@ def test_unitary():
         (C0C_xorH, False),
-    ],
+    ][13:14],
 )
@@ -481,2 +481,3 @@ def _test_controlled_gate_is_consistent(
             circuit = cirq.Circuit(first_op, *decomposed)
+            cirq.to_json(circuit, "decomposed_circuit.json")
             np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-1)

And here are the outputs on my desktop fail-decomposed_circuit.json and on the cloud host pass-decomposed_circuit.json, both outputs are reproducible. The first difference is in the exponent of the first ZPowGate of -0.125 (fail) vs -0.12500000000000003 (pass). Not sure if that first difference is significant, but it shows 2 machines have a bit different numeric round-offs.

@daxfohl
Copy link
Collaborator Author

daxfohl commented Mar 3, 2025

@pavoljuhas in the passing ones I see

cirq.Y**-0.12500000000000003

get changed into

cirq.Z**0.5000000000000002
cirq.X**0.12500000000000003
cirq.S**-1

on the failing ones.

This indeed is equal up to global phase, but has a phase difference.

My hunch is this is occurring somewhere in decompose_multi_controlled_rotation, called from ControlledGate._decompose_with_context_. Maybe try putting an assertion right after that call, that checks whether the decomposition is correct. np.testing.assert_allclose(protocols.unitary(circuits.Circuit(decomposed_ops)), protocols.unitary(self)), or something like that. If that fails, then figure out why it's returning different gates on your machine.

@daxfohl
Copy link
Collaborator Author

daxfohl commented Mar 3, 2025

Oh, I think I see. decompose_multi_controlled_rotation can include MatrixGates in the decomposition. And then the decompose algorithm proceeds to further decompose the MatrixGate into standard gates. And we know that MatrixGate decomposition does not necessarily preserve global phase. So, that's why global phases are different.

It's odd that it would decompose into entirely different gates on different machines. Maybe it's the atol in single_qubit_matrix_to_pauli_rotations?

@pavoljuhas
Copy link
Collaborator

It's odd that it would decompose into entirely different gates on different machines. Maybe it's the atol in single_qubit_matrix_to_pauli_rotations?

Changing the default atol to a nonzero value made no difference on the failing test (same output, same values).

@daxfohl daxfohl restored the fix-controlled-phase-decompose branch April 5, 2025 00:11
@daxfohl daxfohl deleted the fix-controlled-phase-decompose branch April 5, 2025 00:13
BichengYing pushed a commit to BichengYing/Cirq that referenced this pull request Jun 20, 2025
* Fix decompose for controlled gates with phase shift

* Fix test

* Fix type check, int != complex

* Remove decomposition to Z

* Fix param name

* Fix controlled op qasm, reformat tests

* Fix test

* Fix test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
size: M 50< lines changed <250
Projects
None yet
4 participants