Skip to content

Add json serialization to diagonal gates #5356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def _symmetricalqidpair(qids):
'CXPowGate': cirq.CXPowGate,
'CZPowGate': cirq.CZPowGate,
'CZTargetGateset': cirq.CZTargetGateset,
'DiagonalGate': cirq.DiagonalGate,
'DensePauliString': cirq.DensePauliString,
'DepolarizingChannel': cirq.DepolarizingChannel,
'DeviceMetadata': cirq.DeviceMetadata,
Expand Down Expand Up @@ -165,7 +166,6 @@ def _symmetricalqidpair(qids):
'QuantumFourierTransformGate': cirq.QuantumFourierTransformGate,
'QubitPermutationGate': cirq.QubitPermutationGate,
'RandomGateChannel': cirq.RandomGateChannel,
'TensoredConfusionMatrices': cirq.TensoredConfusionMatrices,
'RepetitionsStoppingCriteria': cirq.work.RepetitionsStoppingCriteria,
'ResetChannel': cirq.ResetChannel,
'Result': cirq.ResultDict, # Keep support for Cirq < 0.14.
Expand All @@ -182,8 +182,11 @@ def _symmetricalqidpair(qids):
'SwapPowGate': cirq.SwapPowGate,
'SympyCondition': cirq.SympyCondition,
'TaggedOperation': cirq.TaggedOperation,
'TensoredConfusionMatrices': cirq.TensoredConfusionMatrices,
'TiltedSquareLattice': cirq.TiltedSquareLattice,
'ThreeQubitDiagonalGate': cirq.ThreeQubitDiagonalGate,
'TrialResult': cirq.ResultDict, # keep support for Cirq < 0.11.
'TwoQubitDiagonalGate': cirq.TwoQubitDiagonalGate,
'TwoQubitGateTabulation': cirq.TwoQubitGateTabulation,
'_UnconstrainedDevice': cirq.devices.unconstrained_device._UnconstrainedDevice,
'VarianceStoppingCriteria': cirq.work.VarianceStoppingCriteria,
Expand Down
20 changes: 19 additions & 1 deletion cirq-core/cirq/ops/diagonal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,18 @@
passed as a list.
"""

from typing import AbstractSet, Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import (
AbstractSet,
Any,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)

import numpy as np
import sympy
Expand Down Expand Up @@ -83,6 +94,10 @@ def __init__(self, diag_angles_radians: Sequence['cirq.TParamVal']) -> None:
"""
self._diag_angles_radians: Tuple['cirq.TParamVal', ...] = tuple(diag_angles_radians)

@property
def diag_angles_radians(self) -> Tuple['cirq.TParamVal', ...]:
return self._diag_angles_radians

def _num_qubits_(self):
return int(np.log2(len(self._diag_angles_radians)))

Expand Down Expand Up @@ -194,6 +209,9 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits))
return decomposed_circ

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, attribute_names=["diag_angles_radians"])

def __repr__(self) -> str:
return 'cirq.DiagonalGate([{}])'.format(
','.join(proper_repr(angle) for angle in self._diag_angles_radians)
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/ops/diagonal_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def test_consistent_protocols(gate):
cirq.testing.assert_implements_consistent_protocols(gate)


def test_property():
assert cirq.DiagonalGate([2, 3, 5, 7]).diag_angles_radians == (2, 3, 5, 7)


@pytest.mark.parametrize('n', [1, 2, 3, 4, 5, 6, 7, 8, 9])
def test_decomposition_unitary(n):
diagonal_angles = np.random.randn(2**n)
Expand Down
12 changes: 10 additions & 2 deletions cirq-core/cirq/ops/three_qubit_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AbstractSet,
Any,
Collection,
Dict,
List,
Optional,
Sequence,
Expand Down Expand Up @@ -221,7 +222,7 @@ class ThreeQubitDiagonalGate(raw_types.Gate):
elements are all phases.
"""

def __init__(self, diag_angles_radians: List[value.TParamVal]) -> None:
def __init__(self, diag_angles_radians: Sequence[value.TParamVal]) -> None:
r"""A three qubit gate with only diagonal elements.

This gate's off-diagonal elements are zero and its on diagonal
Expand All @@ -232,7 +233,11 @@ def __init__(self, diag_angles_radians: List[value.TParamVal]) -> None:
If these values are $(x_0, x_1, \ldots , x_7)$ then the unitary
has diagonal values $(e^{i x_0}, e^{i x_1}, \ldots, e^{i x_7})$.
"""
self._diag_angles_radians: List[value.TParamVal] = diag_angles_radians
self._diag_angles_radians: Tuple[value.TParamVal, ...] = tuple(diag_angles_radians)

@property
def diag_angles_radians(self) -> Tuple[value.TParamVal, ...]:
return self._diag_angles_radians

def _is_parameterized_(self) -> bool:
return any(protocols.is_parameterized(angle) for angle in self._diag_angles_radians)
Expand Down Expand Up @@ -367,6 +372,9 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
}
)

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, attribute_names=["diag_angles_radians"])

def __repr__(self) -> str:
return 'cirq.ThreeQubitDiagonalGate([{}])'.format(
','.join(proper_repr(angle) for angle in self._diag_angles_radians)
Expand Down
6 changes: 6 additions & 0 deletions cirq-core/cirq/ops/three_qubit_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ def test_decomposition_cost(op: cirq.Operation, max_two_cost: int):
assert two_cost == max_two_cost


def test_diagonal_gate_property():
assert cirq.ThreeQubitDiagonalGate([2, 3, 5, 7, 0, 0, 0, 1]).diag_angles_radians == (
(2, 3, 5, 7, 0, 0, 0, 1)
)


@pytest.mark.parametrize(
'gate',
[cirq.CCX, cirq.CSWAP, cirq.CCZ, cirq.ThreeQubitDiagonalGate([2, 3, 5, 7, 11, 13, 17, 19])],
Expand Down
9 changes: 8 additions & 1 deletion cirq-core/cirq/ops/two_qubit_diagonal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
passed as a list.
"""

from typing import AbstractSet, Any, Tuple, Optional, Sequence, TYPE_CHECKING
from typing import AbstractSet, Any, Dict, Tuple, Optional, Sequence, TYPE_CHECKING
import numpy as np
import sympy

Expand Down Expand Up @@ -62,6 +62,10 @@ def __init__(self, diag_angles_radians: Sequence[value.TParamVal]) -> None:
"""
self._diag_angles_radians: Tuple[value.TParamVal, ...] = tuple(diag_angles_radians)

@property
def diag_angles_radians(self) -> Tuple[value.TParamVal, ...]:
return self._diag_angles_radians

def _num_qubits_(self) -> int:
return 2

Expand Down Expand Up @@ -134,6 +138,9 @@ def __repr__(self) -> str:
','.join(proper_repr(angle) for angle in self._diag_angles_radians)
)

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, attribute_names=["diag_angles_radians"])

def _quil_(
self, qubits: Tuple['cirq.Qid', ...], formatter: 'cirq.QuilFormatter'
) -> Optional[str]:
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/ops/two_qubit_diagonal_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def test_consistent_protocols(gate):
cirq.testing.assert_implements_consistent_protocols(gate)


def test_property():
assert cirq.TwoQubitDiagonalGate([2, 3, 5, 7]).diag_angles_radians == (2, 3, 5, 7)


def test_parameterized_decompose():
angles = sympy.symbols('x0, x1, x2, x3')
parameterized_op = cirq.TwoQubitDiagonalGate(angles).on(*cirq.LineQubit.range(2))
Expand Down
9 changes: 9 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/DiagonalGate.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"cirq_type": "DiagonalGate",
"diag_angles_radians": [
0.0,
1.0,
-1.0,
0.0
]
}
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/DiagonalGate.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.DiagonalGate(diag_angles_radians=[0.0, 1.0, -1.0, 0.0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's strange that the repr turns the values to a list instead of keeping it a tuple, but I see that's how it's always been and isn't from this PR.

Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"cirq_type": "ThreeQubitDiagonalGate",
"diag_angles_radians": [
0.0,
1.0,
-1.0,
0.0,
0.5,
0.5,
0.5,
0.5
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.ThreeQubitDiagonalGate(diag_angles_radians=[0.0, 1.0, -1.0, 0.0, 0.5, 0.5, 0.5, 0.5])
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"cirq_type": "TwoQubitDiagonalGate",
"diag_angles_radians": [
0.0,
1.0,
-1.0,
0.0
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.TwoQubitDiagonalGate(diag_angles_radians=[0.0, 1.0, -1.0, 0.0])
3 changes: 0 additions & 3 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
'LinearCombinationOfOperations',
'Linspace',
'ListSweep',
'DiagonalGate',
'NeutralAtomDevice',
'PauliInteractionGate',
'PauliSumCollector',
Expand All @@ -65,9 +64,7 @@
'SparseSimulatorStep',
'StateVectorMixin',
'TextDiagramDrawer',
'ThreeQubitDiagonalGate',
'Timestamp',
'TwoQubitDiagonalGate',
'TwoQubitGateTabulationResult',
'UnitSweep',
'StateVectorSimulatorState',
Expand Down