Skip to content

Commit 75b3f40

Browse files
Adds tools for appending randomized measurement bases and processing renyi entropy from bitstring (#6664)
* add utilities for processing renyi entropy and appending randomized measurements + tests * nit: update test comments * address comments * fix np array shape which changes test solution * Address comments and fix bug * address comments * use zip and itertools and update to transformer * rm print * type check * Update cirq-core/cirq/qis/entropy.py Co-authored-by: Noureldin <[email protected]> * Update cirq-core/cirq/transformers/randomized_measurements.py Co-authored-by: Noureldin <[email protected]> * comments * line too long --------- Co-authored-by: Noureldin <[email protected]>
1 parent 3922a63 commit 75b3f40

File tree

6 files changed

+324
-0
lines changed

6 files changed

+324
-0
lines changed

cirq-core/cirq/qis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,4 @@
6060
average_error,
6161
decoherence_pauli_error,
6262
)
63+
from cirq.qis.entropy import process_renyi_entropy_from_bitstrings

cirq-core/cirq/qis/entropy.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from concurrent.futures import ThreadPoolExecutor
16+
from collections.abc import Sequence
17+
from itertools import product
18+
from typing import Any, Optional
19+
20+
import numpy as np
21+
import numpy.typing as npt
22+
23+
24+
def _get_hamming_distance(
25+
bitstring_1: npt.NDArray[np.int8], bitstring_2: npt.NDArray[np.int8]
26+
) -> int:
27+
"""Calculates the Hamming distance between two bitstrings.
28+
Args:
29+
bitstring_1: Bitstring 1
30+
bitstring_2: Bitstring 2
31+
Returns: The Hamming distance
32+
"""
33+
return (bitstring_1 ^ bitstring_2).sum().item()
34+
35+
36+
def _bitstrings_to_probs(
37+
bitstrings: npt.NDArray[np.int8],
38+
) -> tuple[npt.NDArray[np.int8], npt.NDArray[Any]]:
39+
"""Given a list of bitstrings from different measurements returns a probability distribution.
40+
Args:
41+
bitstrings: The bitstring
42+
Returns:
43+
A tuple of bitstrings and their corresponding probabilities.
44+
"""
45+
46+
num_shots = bitstrings.shape[0]
47+
unique_bitstrings, counts = np.unique(bitstrings, return_counts=True, axis=0)
48+
probs = counts / num_shots
49+
50+
return (unique_bitstrings, probs)
51+
52+
53+
def _bitstring_format_helper(
54+
measured_bitstrings: npt.NDArray[np.int8], subsystem: Sequence[int] | None = None
55+
) -> npt.NDArray[np.int8]:
56+
"""Formats the bitstring for analysis based on the selected subsystem.
57+
Args:
58+
measured_bitstrings: List of sampled measurement outcomes as a numpy array of bitstrings.
59+
subsystem: Subsystem of interest
60+
Returns: The bitstring string for the subsystem
61+
"""
62+
if subsystem is None:
63+
return measured_bitstrings
64+
65+
return measured_bitstrings[:, :, subsystem]
66+
67+
68+
def _compute_bitstrings_contribution_to_purity(bitstrings: npt.NDArray[np.int8]) -> float:
69+
"""Computes the contribution to the purity of the bitstrings.
70+
Args:
71+
bitstrings: The bitstrings measured using the same unitary operators
72+
Returns: The purity of the bitstring
73+
"""
74+
75+
bitstrings, probs = _bitstrings_to_probs(bitstrings)
76+
purity = 0
77+
for (s, p), (s_prime, p_prime) in product(zip(bitstrings, probs), repeat=2):
78+
purity += (-2.0) ** float(-_get_hamming_distance(s, s_prime)) * p * p_prime
79+
80+
return purity * 2 ** (bitstrings.shape[-1])
81+
82+
83+
def process_renyi_entropy_from_bitstrings(
84+
measured_bitstrings: npt.NDArray[np.int8],
85+
subsystem: tuple[int] | None = None,
86+
pool: Optional[ThreadPoolExecutor] = None,
87+
) -> float:
88+
"""Compute the Rényi entropy of an array of bitstrings.
89+
Args:
90+
measured_bitstrings: List of sampled measurement outcomes as a numpy array of bitstrings.
91+
subsystem: Subsystem of interest
92+
pool: ThreadPoolExecutor used to paralelleize the computation.
93+
94+
Returns:
95+
A float indicating the computed entropy.
96+
"""
97+
bitstrings = _bitstring_format_helper(measured_bitstrings, subsystem)
98+
num_shots = bitstrings.shape[1]
99+
num_qubits = bitstrings.shape[-1]
100+
101+
if num_shots == 1:
102+
return 0
103+
104+
if pool is not None:
105+
purities = list(pool.map(_compute_bitstrings_contribution_to_purity, list(bitstrings)))
106+
purity = np.mean(purities)
107+
108+
else:
109+
purity = np.mean(
110+
[_compute_bitstrings_contribution_to_purity(bitstring) for bitstring in bitstrings]
111+
)
112+
113+
purity_unbiased = purity * num_shots / (num_shots - 1) - (2**num_qubits) / (num_shots - 1)
114+
115+
return -np.log2(purity_unbiased)

cirq-core/cirq/qis/entropy_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from concurrent.futures import ThreadPoolExecutor
16+
import pytest
17+
import numpy as np
18+
19+
from cirq.qis.entropy import process_renyi_entropy_from_bitstrings
20+
21+
22+
@pytest.mark.parametrize('pool', [None, ThreadPoolExecutor(max_workers=1)])
23+
def test_process_renyi_entropy_from_bitstrings(pool):
24+
bitstrings = np.array(
25+
[
26+
[[0, 1, 1, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 1]],
27+
[[0, 1, 1, 0], [0, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 1]],
28+
[[0, 0, 1, 1], [0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 1, 1]],
29+
[[1, 0, 1, 1], [0, 1, 0, 0], [0, 1, 0, 1], [0, 1, 0, 0]],
30+
[[1, 0, 1, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
31+
]
32+
)
33+
substsytem = (0, 1)
34+
entropy = process_renyi_entropy_from_bitstrings(bitstrings, substsytem, pool)
35+
assert entropy == 0.5145731728297583
36+
37+
38+
def test_process_renyi_entropy_from_bitstrings_safeguards_against_divide_by_0_error():
39+
bitstrings = np.array([[[0, 1, 1, 0]], [[0, 1, 1, 0]], [[0, 0, 1, 1]]])
40+
41+
entropy = process_renyi_entropy_from_bitstrings(bitstrings)
42+
assert entropy == 0

cirq-core/cirq/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,5 @@
134134
SqrtCZGaugeTransformer,
135135
SqrtISWAPGaugeTransformer,
136136
)
137+
138+
from cirq.transformers.randomized_measurements import RandomizedMeasurements
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from collections.abc import Sequence
16+
from typing import Any, Literal
17+
18+
import cirq
19+
import numpy as np
20+
from cirq.transformers import transformer_api
21+
22+
23+
@transformer_api.transformer
24+
class RandomizedMeasurements:
25+
"""A transformer that appends a moment of random rotations to map qubits to
26+
random pauli bases."""
27+
28+
def __init__(self, subsystem: Sequence[int] | None = None):
29+
"""Class structure for performing and analyzing a general randomized measurement protocol.
30+
For more details on the randomized measurement toolbox see https://arxiv.org/abs/2203.11374
31+
32+
Args:
33+
subsystem: The specific subsystem (e.g qubit index) to measure in random basis
34+
"""
35+
self.subsystem = subsystem
36+
37+
def __call__(
38+
self,
39+
circuit: 'cirq.AbstractCircuit',
40+
rng: np.random.Generator | None = None,
41+
*,
42+
context: transformer_api.TransformerContext | None = None,
43+
):
44+
"""Apply the transformer to the given circuit. Given an input circuit returns
45+
a list of circuits with the pre-measurement unitaries. If no arguments are specified,
46+
it will default to computing the entropy of the entire circuit.
47+
48+
Args:
49+
circuit: The circuit to add randomized measurements to.
50+
rng: Random number generator.
51+
context: Not used; to satisfy transformer API.
52+
53+
Returns:
54+
List of circuits with pre-measurement unitaries and measurements added
55+
"""
56+
if rng is None:
57+
rng = np.random.default_rng()
58+
59+
qubits = sorted(circuit.all_qubits())
60+
num_qubits = len(qubits)
61+
62+
pre_measurement_unitaries_list = self._generate_unitaries_list(rng, num_qubits)
63+
pre_measurement_moment = self.unitaries_to_moment(pre_measurement_unitaries_list, qubits)
64+
65+
return cirq.Circuit.from_moments(
66+
*circuit.moments, pre_measurement_moment, cirq.M(*qubits, key='m')
67+
)
68+
69+
def _generate_unitaries_list(self, rng: np.random.Generator, num_qubits: int) -> Sequence[Any]:
70+
"""Generates a list of pre-measurement unitaries."""
71+
72+
pauli_strings = rng.choice(["X", "Y", "Z"], size=num_qubits)
73+
74+
if self.subsystem is not None:
75+
for i in range(pauli_strings.shape[0]):
76+
if i not in self.subsystem:
77+
pauli_strings[i] = np.array("Z")
78+
79+
return pauli_strings.tolist()
80+
81+
def unitaries_to_moment(
82+
self, unitaries: Sequence[Literal["X", "Y", "Z"]], qubits: Sequence[Any]
83+
) -> 'cirq.Moment':
84+
"""Outputs the cirq moment associated with the pre-measurement rotations.
85+
Args:
86+
unitaries: List of pre-measurement unitaries
87+
qubits: List of qubits
88+
89+
Returns: The cirq moment associated with the pre-measurement rotations
90+
"""
91+
op_list: list[cirq.Operation] = []
92+
for idx, pauli in enumerate(unitaries):
93+
op_list.append(_pauli_basis_rotation(pauli).on(qubits[idx]))
94+
95+
return cirq.Moment.from_ops(*op_list)
96+
97+
98+
def _pauli_basis_rotation(basis: Literal["X", "Y", "Z"]) -> 'cirq.Gate':
99+
"""Given a measurement basis returns the associated rotation.
100+
Args:
101+
basis: Measurement basis
102+
Returns: The cirq gate for associated with measurement basis
103+
"""
104+
if basis == "X":
105+
return cirq.Ry(rads=-np.pi / 2)
106+
elif basis == "Y":
107+
return cirq.Rx(rads=np.pi / 2)
108+
elif basis == "Z":
109+
return cirq.I
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import cirq
16+
import cirq.transformers.randomized_measurements as rand_meas
17+
18+
19+
def test_randomized_measurements_appends_two_moments_on_returned_circuit():
20+
# Create a 4-qubit circuit
21+
q0, q1, q2, q3 = cirq.LineQubit.range(4)
22+
circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)])
23+
num_moments_pre = len(circuit.moments)
24+
25+
# Append randomized measurements to subsystem
26+
circuit = rand_meas.RandomizedMeasurements()(circuit)
27+
28+
num_moments_post = len(circuit.moments)
29+
assert num_moments_post == num_moments_pre + 2
30+
31+
32+
def test_append_randomized_measurements_leaves_qubits_not_in_specified_subsystem_unchanged():
33+
# Create a 4-qubit circuit
34+
q0, q1, q2, q3 = cirq.LineQubit.range(4)
35+
circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)])
36+
37+
# Append randomized measurements to subsystem
38+
circuit = rand_meas.RandomizedMeasurements(subsystem=(0, 1))(circuit)
39+
40+
# assert latter subsystems were not changed.
41+
assert circuit.operation_at(q2, 4) == cirq.I(q2)
42+
assert circuit.operation_at(q3, 4) == cirq.I(q3)
43+
44+
45+
def test_append_randomized_measurements_leaves_qubits_not_in_noncontinuous_subsystem_unchanged():
46+
# Create a 4-qubit circuit
47+
q0, q1, q2, q3 = cirq.LineQubit.range(4)
48+
circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)])
49+
50+
# Append randomized measurements to subsystem
51+
circuit = rand_meas.RandomizedMeasurements(subsystem=(0, 2))(circuit)
52+
53+
# assert latter subsystems were not changed.
54+
assert circuit.operation_at(q1, 4) == cirq.I(q1)
55+
assert circuit.operation_at(q3, 4) == cirq.I(q3)

0 commit comments

Comments
 (0)