From 1e8f898a40b4b135e9170118d6511d985740ed9d Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 16 Mar 2022 20:48:03 -0700 Subject: [PATCH 1/3] Add default decomposition for cirq.QubitPermutationGate in terms of adjacent swaps --- cirq-core/cirq/ops/permutation_gate.py | 11 ++++++++++- cirq-core/cirq/ops/permutation_gate_test.py | 11 +++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/ops/permutation_gate.py b/cirq-core/cirq/ops/permutation_gate.py index e3fb9bfff02..d94b42a203f 100644 --- a/cirq-core/cirq/ops/permutation_gate.py +++ b/cirq-core/cirq/ops/permutation_gate.py @@ -16,7 +16,7 @@ from cirq import protocols, value from cirq._compat import deprecated -from cirq.ops import raw_types +from cirq.ops import raw_types, swap_gates if TYPE_CHECKING: import cirq @@ -74,6 +74,15 @@ def num_qubits(self): def _has_unitary_(self): return True + def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': + qubit_ids = [*range(len(qubits))] + for i in range(len(qubits)): + q_i = qubit_ids.index(self._permutation.index(i)) + for j in range(q_i, i, -1): + yield swap_gates.SWAP(qubits[j], qubits[j - 1]) + qubit_ids[j], qubit_ids[j - 1] = qubit_ids[j - 1], qubit_ids[j] + assert self._permutation[qubit_ids[i]] == i + def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'): # Compute the permutation index list. permuted_axes = list(range(len(args.target_tensor.shape))) diff --git a/cirq-core/cirq/ops/permutation_gate_test.py b/cirq-core/cirq/ops/permutation_gate_test.py index ed0b27c1d3b..7b74992f2d7 100644 --- a/cirq-core/cirq/ops/permutation_gate_test.py +++ b/cirq-core/cirq/ops/permutation_gate_test.py @@ -15,6 +15,7 @@ import pytest import cirq +import numpy as np from cirq.ops import QubitPermutationGate @@ -30,8 +31,12 @@ def test_permutation_gate_repr(): cirq.testing.assert_equivalent_repr(QubitPermutationGate([0, 1])) -def test_permutation_gate_consistent_protocols(): - gate = QubitPermutationGate([1, 0, 2, 3]) +rs = np.random.RandomState(seed=1234) + + +@pytest.mark.parametrize('permutation', [rs.permutation(i) for i in range(3, 7)]) +def test_permutation_gate_consistent_protocols(permutation): + gate = QubitPermutationGate(list(permutation)) cirq.testing.assert_implements_consistent_protocols(gate) @@ -98,6 +103,8 @@ def test_permutation_gate_maps(maps, permutation): permutationOp = cirq.QubitPermutationGate(permutation).on(*qs) circuit = cirq.Circuit(permutationOp) cirq.testing.assert_equivalent_computational_basis_map(maps, circuit) + circuit = cirq.Circuit(cirq.I.on_each(*qs), cirq.decompose(permutationOp)) + cirq.testing.assert_equivalent_computational_basis_map(maps, circuit) def test_setters_deprecated(): From 68c8ce9b945c4c696c1f4c19975157e86285a115 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 17 Mar 2022 11:22:52 -0700 Subject: [PATCH 2/3] Change sorting to use odd-even sort --- cirq-core/cirq/ops/permutation_gate.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/ops/permutation_gate.py b/cirq-core/cirq/ops/permutation_gate.py index 5394763d57a..8fba79bcb37 100644 --- a/cirq-core/cirq/ops/permutation_gate.py +++ b/cirq-core/cirq/ops/permutation_gate.py @@ -75,13 +75,23 @@ def _has_unitary_(self): return True def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': - qubit_ids = [*range(len(qubits))] - for i in range(len(qubits)): - q_i = qubit_ids.index(self._permutation.index(i)) - for j in range(q_i, i, -1): - yield swap_gates.SWAP(qubits[j], qubits[j - 1]) - qubit_ids[j], qubit_ids[j - 1] = qubit_ids[j - 1], qubit_ids[j] - assert self._permutation[qubit_ids[i]] == i + n = len(qubits) + qubit_ids = [*range(n)] + is_sorted = False + + def _swap_if_out_of_order(idx: int) -> bool: + nonlocal is_sorted + if self._permutation[qubit_ids[idx]] > self._permutation[qubit_ids[idx + 1]]: + yield swap_gates.SWAP(qubits[idx], qubits[idx + 1]) + qubit_ids[idx + 1], qubit_ids[idx] = qubit_ids[idx], qubit_ids[idx + 1] + is_sorted = False + + while not is_sorted: + is_sorted = True + for i in range(0, n - 1, 2): + yield from _swap_if_out_of_order(i) + for i in range(1, n - 1, 2): + yield from _swap_if_out_of_order(i) def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'): # Compute the permutation index list. From ed002f4935d915f487cb8e59c6426be8ebe250f5 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 17 Mar 2022 11:30:54 -0700 Subject: [PATCH 3/3] Fix mypy types --- cirq-core/cirq/ops/permutation_gate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/permutation_gate.py b/cirq-core/cirq/ops/permutation_gate.py index 8fba79bcb37..919e84a3a62 100644 --- a/cirq-core/cirq/ops/permutation_gate.py +++ b/cirq-core/cirq/ops/permutation_gate.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING +from typing import Any, Dict, Iterable, Sequence, Tuple, TYPE_CHECKING from cirq import protocols, value from cirq._compat import deprecated @@ -79,7 +79,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': qubit_ids = [*range(n)] is_sorted = False - def _swap_if_out_of_order(idx: int) -> bool: + def _swap_if_out_of_order(idx: int) -> Iterable['cirq.Operation']: nonlocal is_sorted if self._permutation[qubit_ids[idx]] > self._permutation[qubit_ids[idx + 1]]: yield swap_gates.SWAP(qubits[idx], qubits[idx + 1])