diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index a663923f3f5..0a35b2b5687 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -332,6 +332,7 @@ ZPowGate, ZZ, ZZPowGate, + UniformSuperpositionGate, ) from cirq.transformers import ( diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 4880046618a..65dea9c7587 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -247,6 +247,7 @@ def _symmetricalqidpair(qids): 'ZipLongest': cirq.ZipLongest, 'ZPowGate': cirq.ZPowGate, 'ZZPowGate': cirq.ZZPowGate, + 'UniformSuperpositionGate': cirq.UniformSuperpositionGate, # Old types, only supported for backwards-compatibility 'BooleanHamiltonian': _boolean_hamiltonian_gate_op, # Removed in v0.15 'CrossEntropyResult': _cross_entropy_result, # Removed in v0.16 diff --git a/cirq-core/cirq/ops/__init__.py b/cirq-core/cirq/ops/__init__.py index 5cadb6ad9af..25db2fc710a 100644 --- a/cirq-core/cirq/ops/__init__.py +++ b/cirq-core/cirq/ops/__init__.py @@ -217,3 +217,5 @@ from cirq.ops.state_preparation_channel import StatePreparationChannel from cirq.ops.control_values import AbstractControlValues, ProductOfSums, SumOfProducts + +from cirq.ops.uniform_superposition_gate import UniformSuperpositionGate diff --git a/cirq-core/cirq/ops/uniform_superposition_gate.py b/cirq-core/cirq/ops/uniform_superposition_gate.py new file mode 100644 index 00000000000..87349482704 --- /dev/null +++ b/cirq-core/cirq/ops/uniform_superposition_gate.py @@ -0,0 +1,123 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Any, Dict, TYPE_CHECKING + +import numpy as np +from cirq.ops.common_gates import H, ry +from cirq.ops.pauli_gates import X +from cirq.ops import raw_types + + +if TYPE_CHECKING: + import cirq + + +class UniformSuperpositionGate(raw_types.Gate): + r"""Creates a uniform superposition state on the states $[0, M)$ + The gate creates the state $\frac{1}{\sqrt{M}}\sum_{j=0}^{M-1}\ket{j}$ + (where $1\leq M \leq 2^n$), using n qubits, according to the Shukla-Vedula algorithm [SV24]. + References: + [SV24] + [An efficient quantum algorithm for preparation of uniform quantum superposition + states](https://arxiv.org/abs/2306.11747) + """ + + def __init__(self, m_value: int, num_qubits: int) -> None: + """Initializes UniformSuperpositionGate. + + Args: + m_value: The number of computational basis states. + num_qubits: The number of qubits used. + + Raises: + ValueError: If `m_value` is not a positive integer, or + if `num_qubits` is not an integer greater than or equal to log2(m_value). + """ + if not (isinstance(m_value, int) and (m_value > 0)): + raise ValueError("m_value must be a positive integer.") + log_two_m_value = m_value.bit_length() + + if (m_value & (m_value - 1)) == 0: + log_two_m_value = log_two_m_value - 1 + if not (isinstance(num_qubits, int) and (num_qubits >= log_two_m_value)): + raise ValueError( + "num_qubits must be an integer greater than or equal to log2(m_value)." + ) + self._m_value = m_value + self._num_qubits = num_qubits + + def _decompose_(self, qubits: Sequence["cirq.Qid"]) -> "cirq.OP_TREE": + """Decomposes the gate into a sequence of standard gates. + Implements the construction from https://arxiv.org/pdf/2306.11747. + """ + qreg = list(qubits) + qreg.reverse() + + if self._m_value == 1: # if m_value is 1, do nothing + return + if (self._m_value & (self._m_value - 1)) == 0: # if m_value is an integer power of 2 + m = self._m_value.bit_length() - 1 + yield H.on_each(qreg[:m]) + return + k = self._m_value.bit_length() + l_value = [] + for i in range(self._m_value.bit_length()): + if (self._m_value >> i) & 1: + l_value.append(i) # Locations of '1's + + yield X.on_each(qreg[q_bit] for q_bit in l_value[1:k]) + m_current = 2 ** (l_value[0]) + theta = -2 * np.arccos(np.sqrt(m_current / self._m_value)) + if l_value[0] > 0: # if m_value is even + yield H.on_each(qreg[: l_value[0]]) + + yield ry(theta).on(qreg[l_value[1]]) + + for i in range(l_value[0], l_value[1]): + yield H(qreg[i]).controlled_by(qreg[l_value[1]], control_values=[False]) + + for m in range(1, len(l_value) - 1): + theta = -2 * np.arccos(np.sqrt(2 ** l_value[m] / (self._m_value - m_current))) + yield ry(theta).on(qreg[l_value[m + 1]]).controlled_by( + qreg[l_value[m]], control_values=[0] + ) + for i in range(l_value[m], l_value[m + 1]): + yield H.on(qreg[i]).controlled_by(qreg[l_value[m + 1]], control_values=[0]) + + m_current = m_current + 2 ** (l_value[m]) + + def num_qubits(self) -> int: + return self._num_qubits + + @property + def m_value(self) -> int: + return self._m_value + + def __eq__(self, other): + if isinstance(other, UniformSuperpositionGate): + return (self._m_value == other._m_value) and (self._num_qubits == other._num_qubits) + return False + + def __repr__(self) -> str: + return f'UniformSuperpositionGate(m_value={self._m_value}, num_qubits={self._num_qubits})' + + def _json_dict_(self) -> Dict[str, Any]: + d = {} + d['m_value'] = self._m_value + d['num_qubits'] = self._num_qubits + return d + + def __str__(self) -> str: + return f'UniformSuperpositionGate(m_value={self._m_value}, num_qubits={self._num_qubits})' diff --git a/cirq-core/cirq/ops/uniform_superposition_gate_test.py b/cirq-core/cirq/ops/uniform_superposition_gate_test.py new file mode 100644 index 00000000000..6f3e472d19a --- /dev/null +++ b/cirq-core/cirq/ops/uniform_superposition_gate_test.py @@ -0,0 +1,94 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import cirq + + +@pytest.mark.parametrize( + ["m", "n"], + [[int(m), n] for n in range(3, 7) for m in np.random.randint(1, 1 << n, size=3)] + + [(1, 2), (4, 2), (6, 3), (7, 3)], +) +def test_generated_unitary_is_uniform(m: int, n: int) -> None: + r"""The code checks that the unitary matrix corresponds to the generated uniform superposition + states (see uniform_superposition_gate.py). It is enough to check that the + first colum of the unitary matrix (which corresponds to the action of the gate on + $\ket{0}^n$ is $\frac{1}{\sqrt{M}} [1 1 \cdots 1 0 \cdots 0]^T$, where the first $M$ + entries are all "1"s (excluding the normalization factor of $\frac{1}{\sqrt{M}}$ and the + remaining $2^n-M$ entries are all "0"s. + """ + gate = cirq.UniformSuperpositionGate(m, n) + matrix = np.array(cirq.unitary(gate)) + np.testing.assert_allclose( + matrix[:, 0], (1 / np.sqrt(m)) * np.array([1] * m + [0] * (2**n - m)), atol=1e-8 + ) + + +@pytest.mark.parametrize(["m", "n"], [(1, 1), (-2, 1), (-3.1, 2), (6, -4), (5, 6.1)]) +def test_incompatible_m_value_and_qubit_args(m: int, n: int) -> None: + r"""The code checks that test errors are raised if the arguments m (number of + superposition states and n (number of qubits) are positive integers and are compatible + (i.e., n >= log2(m)). + """ + + if not (isinstance(m, int)): + with pytest.raises(ValueError, match="m_value must be a positive integer."): + cirq.UniformSuperpositionGate(m, n) + elif not (isinstance(n, int)): + with pytest.raises( + ValueError, + match="num_qubits must be an integer greater than or equal to log2\\(m_value\\).", + ): + cirq.UniformSuperpositionGate(m, n) + elif m < 1: + with pytest.raises(ValueError, match="m_value must be a positive integer."): + cirq.UniformSuperpositionGate(int(m), int(n)) + elif n < np.log2(m): + with pytest.raises( + ValueError, + match="num_qubits must be an integer greater than or equal to log2\\(m_value\\).", + ): + cirq.UniformSuperpositionGate(m, n) + + +def test_repr(): + assert ( + repr(cirq.UniformSuperpositionGate(7, 3)) + == 'UniformSuperpositionGate(m_value=7, num_qubits=3)' + ) + + +def test_uniform_superposition_gate_json_dict(): + assert cirq.UniformSuperpositionGate(7, 3)._json_dict_() == {'m_value': 7, 'num_qubits': 3} + + +def test_str(): + assert ( + str(cirq.UniformSuperpositionGate(7, 3)) + == 'UniformSuperpositionGate(m_value=7, num_qubits=3)' + ) + + +@pytest.mark.parametrize(["m", "n"], [(5, 3), (10, 4)]) +def test_eq(m: int, n: int) -> None: + a = cirq.UniformSuperpositionGate(m, n) + b = cirq.UniformSuperpositionGate(m, n) + c = cirq.UniformSuperpositionGate(m + 1, n) + d = cirq.X + assert a.m_value == b.m_value + assert a.__eq__(b) + assert not (a.__eq__(c)) + assert not (a.__eq__(d)) diff --git a/cirq-core/cirq/protocols/json_test_data/UniformSuperpositionGate.json b/cirq-core/cirq/protocols/json_test_data/UniformSuperpositionGate.json new file mode 100644 index 00000000000..52203d8538e --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/UniformSuperpositionGate.json @@ -0,0 +1,5 @@ +{ + "cirq_type": "UniformSuperpositionGate", + "m_value": 7, + "num_qubits": 3 +} diff --git a/cirq-core/cirq/protocols/json_test_data/UniformSuperpositionGate.repr b/cirq-core/cirq/protocols/json_test_data/UniformSuperpositionGate.repr new file mode 100644 index 00000000000..62b2bdac0f2 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/UniformSuperpositionGate.repr @@ -0,0 +1 @@ + cirq.UniformSuperpositionGate(m_value=7, num_qubits=3)