-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Create a generalized uniform superposition state gate #6506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
93cb5cc
45d204b
879c804
4026261
6d8a4d7
ab2e2b6
eb9f19f
3303a94
24ef8a8
1145e11
5b4ecc3
659adf2
01a5fe2
9ca34bf
8ce754b
a24cdec
e811f3a
af83138
8bcfdd0
998a765
0930e1d
3d3abb5
4b31c3b
803f178
f370dad
4df3af8
90ceb82
2fec2ba
09e5a67
46e035c
8708d04
06b35ab
e3a1ff6
f63207c
1386270
b3d1f34
599c94e
a4e80a5
72b109b
c113212
fb755d1
c2765a9
cfc4561
318f76a
6dc054d
4244373
1018cc5
9472a77
7638852
a71fe41
519f46d
fab4973
d1ec56a
9bceb57
ff7f884
673d5fb
a34dca8
7036c1b
9e61adb
883209a
a00e5d7
c2def1e
ff8bcd5
607cb40
8279b9a
77c951a
c789426
24cd113
b4a8207
8def6df
cfcf363
41aa66e
41803bf
0fcf60b
ac491bc
389260e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -332,6 +332,7 @@ | |
ZPowGate, | ||
ZZ, | ||
ZZPowGate, | ||
UniformSuperpositionGate, | ||
) | ||
|
||
from cirq.transformers import ( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright 2024 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Sequence, Any, Dict, TYPE_CHECKING | ||
|
||
import numpy as np | ||
from cirq.ops.common_gates import H, ry | ||
from cirq.ops.pauli_gates import X | ||
from cirq.ops import raw_types | ||
|
||
|
||
if TYPE_CHECKING: | ||
import cirq | ||
|
||
|
||
class UniformSuperpositionGate(raw_types.Gate): | ||
r"""Creates a uniform superposition state on the states $[0, M)$ | ||
The gate creates the state $\frac{1}{\sqrt{M}}\sum_{j=0}^{M-1}\ket{j}$ | ||
(where $1\leq M \leq 2^n$), using n qubits, according to the Shukla-Vedula algorithm [SV24]. | ||
References: | ||
[SV24] | ||
[An efficient quantum algorithm for preparation of uniform quantum superposition | ||
states](https://arxiv.org/abs/2306.11747) | ||
""" | ||
|
||
def __init__(self, m_value: int, num_qubits: int) -> None: | ||
"""Initializes UniformSuperpositionGate. | ||
|
||
Args: | ||
m_value: The number of computational basis states. | ||
num_qubits: The number of qubits used. | ||
|
||
Raises: | ||
ValueError: If `m_value` is not a positive integer, or | ||
if `num_qubits` is not an integer greater than or equal to log2(m_value). | ||
""" | ||
if not (isinstance(m_value, int) and (m_value > 0)): | ||
raise ValueError("m_value must be a positive integer.") | ||
log_two_m_value = m_value.bit_length() | ||
|
||
if (m_value & (m_value - 1)) == 0: | ||
log_two_m_value = log_two_m_value - 1 | ||
if not (isinstance(num_qubits, int) and (num_qubits >= log_two_m_value)): | ||
raise ValueError( | ||
"num_qubits must be an integer greater than or equal to log2(m_value)." | ||
) | ||
self._m_value = m_value | ||
self._num_qubits = num_qubits | ||
|
||
def _decompose_(self, qubits: Sequence["cirq.Qid"]) -> "cirq.OP_TREE": | ||
"""Decomposes the gate into a sequence of standard gates. | ||
Implements the construction from https://arxiv.org/pdf/2306.11747. | ||
""" | ||
qreg = list(qubits) | ||
qreg.reverse() | ||
|
||
if self._m_value == 1: # if m_value is 1, do nothing | ||
return | ||
if (self._m_value & (self._m_value - 1)) == 0: # if m_value is an integer power of 2 | ||
m = self._m_value.bit_length() - 1 | ||
yield H.on_each(qreg[:m]) | ||
return | ||
k = self._m_value.bit_length() | ||
l_value = [] | ||
for i in range(self._m_value.bit_length()): | ||
if (self._m_value >> i) & 1: | ||
l_value.append(i) # Locations of '1's | ||
|
||
yield X.on_each(qreg[q_bit] for q_bit in l_value[1:k]) | ||
m_current = 2 ** (l_value[0]) | ||
theta = -2 * np.arccos(np.sqrt(m_current / self._m_value)) | ||
if l_value[0] > 0: # if m_value is even | ||
yield H.on_each(qreg[: l_value[0]]) | ||
|
||
yield ry(theta).on(qreg[l_value[1]]) | ||
|
||
for i in range(l_value[0], l_value[1]): | ||
yield H(qreg[i]).controlled_by(qreg[l_value[1]], control_values=[False]) | ||
|
||
for m in range(1, len(l_value) - 1): | ||
theta = -2 * np.arccos(np.sqrt(2 ** l_value[m] / (self._m_value - m_current))) | ||
yield ry(theta).on(qreg[l_value[m + 1]]).controlled_by( | ||
qreg[l_value[m]], control_values=[0] | ||
) | ||
for i in range(l_value[m], l_value[m + 1]): | ||
yield H.on(qreg[i]).controlled_by(qreg[l_value[m + 1]], control_values=[0]) | ||
|
||
m_current = m_current + 2 ** (l_value[m]) | ||
|
||
def num_qubits(self) -> int: | ||
return self._num_qubits | ||
|
||
@property | ||
def m_value(self) -> int: | ||
return self._m_value | ||
|
||
def __eq__(self, other): | ||
if isinstance(other, UniformSuperpositionGate): | ||
return (self._m_value == other._m_value) and (self._num_qubits == other._num_qubits) | ||
return False | ||
|
||
def __repr__(self) -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does any test fail with the default repr? if not please remove this function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function was removed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the |
||
return f'UniformSuperpositionGate(m_value={self._m_value}, num_qubits={self._num_qubits})' | ||
|
||
def _json_dict_(self) -> Dict[str, Any]: | ||
d = {} | ||
d['m_value'] = self._m_value | ||
d['num_qubits'] = self._num_qubits | ||
return d | ||
|
||
def __str__(self) -> str: | ||
return f'UniformSuperpositionGate(m_value={self._m_value}, num_qubits={self._num_qubits})' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright 2024 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
import pytest | ||
import cirq | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["m", "n"], | ||
[[int(m), n] for n in range(3, 7) for m in np.random.randint(1, 1 << n, size=3)] | ||
+ [(1, 2), (4, 2), (6, 3), (7, 3)], | ||
) | ||
def test_generated_unitary_is_uniform(m: int, n: int) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tests should test for one thing. please split this test into 2 tests. one for correctness (checking the unitary) and another for argument validation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestions. I have split the test into two separate tests as suggested. |
||
r"""The code checks that the unitary matrix corresponds to the generated uniform superposition | ||
states (see uniform_superposition_gate.py). It is enough to check that the | ||
first colum of the unitary matrix (which corresponds to the action of the gate on | ||
$\ket{0}^n$ is $\frac{1}{\sqrt{M}} [1 1 \cdots 1 0 \cdots 0]^T$, where the first $M$ | ||
entries are all "1"s (excluding the normalization factor of $\frac{1}{\sqrt{M}}$ and the | ||
remaining $2^n-M$ entries are all "0"s. | ||
""" | ||
gate = cirq.UniformSuperpositionGate(m, n) | ||
matrix = np.array(cirq.unitary(gate)) | ||
np.testing.assert_allclose( | ||
matrix[:, 0], (1 / np.sqrt(m)) * np.array([1] * m + [0] * (2**n - m)), atol=1e-8 | ||
) | ||
|
||
|
||
@pytest.mark.parametrize(["m", "n"], [(1, 1), (-2, 1), (-3.1, 2), (6, -4), (5, 6.1)]) | ||
def test_incompatible_m_value_and_qubit_args(m: int, n: int) -> None: | ||
r"""The code checks that test errors are raised if the arguments m (number of | ||
superposition states and n (number of qubits) are positive integers and are compatible | ||
(i.e., n >= log2(m)). | ||
""" | ||
|
||
if not (isinstance(m, int)): | ||
with pytest.raises(ValueError, match="m_value must be a positive integer."): | ||
cirq.UniformSuperpositionGate(m, n) | ||
elif not (isinstance(n, int)): | ||
with pytest.raises( | ||
ValueError, | ||
match="num_qubits must be an integer greater than or equal to log2\\(m_value\\).", | ||
): | ||
cirq.UniformSuperpositionGate(m, n) | ||
elif m < 1: | ||
with pytest.raises(ValueError, match="m_value must be a positive integer."): | ||
cirq.UniformSuperpositionGate(int(m), int(n)) | ||
elif n < np.log2(m): | ||
with pytest.raises( | ||
ValueError, | ||
match="num_qubits must be an integer greater than or equal to log2\\(m_value\\).", | ||
): | ||
cirq.UniformSuperpositionGate(m, n) | ||
|
||
|
||
def test_repr(): | ||
assert ( | ||
repr(cirq.UniformSuperpositionGate(7, 3)) | ||
== 'UniformSuperpositionGate(m_value=7, num_qubits=3)' | ||
) | ||
|
||
|
||
def test_uniform_superposition_gate_json_dict(): | ||
assert cirq.UniformSuperpositionGate(7, 3)._json_dict_() == {'m_value': 7, 'num_qubits': 3} | ||
|
||
|
||
def test_str(): | ||
assert ( | ||
str(cirq.UniformSuperpositionGate(7, 3)) | ||
== 'UniformSuperpositionGate(m_value=7, num_qubits=3)' | ||
) | ||
|
||
|
||
@pytest.mark.parametrize(["m", "n"], [(5, 3), (10, 4)]) | ||
def test_eq(m: int, n: int) -> None: | ||
a = cirq.UniformSuperpositionGate(m, n) | ||
b = cirq.UniformSuperpositionGate(m, n) | ||
c = cirq.UniformSuperpositionGate(m + 1, n) | ||
d = cirq.X | ||
assert a.m_value == b.m_value | ||
assert a.__eq__(b) | ||
assert not (a.__eq__(c)) | ||
assert not (a.__eq__(d)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
"cirq_type": "UniformSuperpositionGate", | ||
"m_value": 7, | ||
"num_qubits": 3 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
cirq.UniformSuperpositionGate(m_value=7, num_qubits=3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: no need for docstring here or in repr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstring removed.