Skip to content

cirq.measure - accept list arguments #5411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 63 additions & 9 deletions cirq-core/cirq/ops/measure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Callable, Iterable, List, overload, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np

Expand Down Expand Up @@ -81,8 +81,30 @@ def measure_paulistring_terms(
return [PauliMeasurementGate([pauli_basis[q]], key=key_func(q)).on(q) for q in pauli_basis]


# pylint: disable=function-redefined


@overload
def measure(
*target: raw_types.Qid,
key: Optional[Union[str, 'cirq.MeasurementKey']] = None,
invert_mask: Tuple[bool, ...] = (),
) -> raw_types.Operation:
pass


@overload
def measure(
__target: Iterable[raw_types.Qid],
*,
key: Optional[Union[str, 'cirq.MeasurementKey']] = None,
invert_mask: Tuple[bool, ...] = (),
) -> raw_types.Operation:
pass


def measure(
*target: 'cirq.Qid',
*target,
key: Optional[Union[str, 'cirq.MeasurementKey']] = None,
invert_mask: Tuple[bool, ...] = (),
) -> raw_types.Operation:
Expand All @@ -92,6 +114,8 @@ def measure(

Args:
*target: The qubits that the measurement gate should measure.
These can be specified as separate function arguments or
with a single argument for an iterable of qubits.
key: The string key of the measurement. If this is None, it defaults
to a comma-separated list of the target qubits' str values.
invert_mask: A list of Truthy or Falsey values indicating whether
Expand All @@ -104,7 +128,13 @@ def measure(
Raises:
ValueError: If the qubits are not instances of Qid.
"""
for qubit in target:
one_iterable_arg: bool = (
len(target) == 1
and isinstance(target[0], Iterable)
and not isinstance(target[0], (bytes, str, np.ndarray))
)
targets = tuple(target[0]) if one_iterable_arg else target
for qubit in targets:
if isinstance(qubit, np.ndarray):
raise ValueError(
'measure() was called a numpy ndarray. Perhaps you meant '
Expand All @@ -114,24 +144,48 @@ def measure(
raise ValueError('measure() was called with type different than Qid.')

if key is None:
key = _default_measurement_key(target)
qid_shape = protocols.qid_shape(target)
return MeasurementGate(len(target), key, invert_mask, qid_shape).on(*target)
key = _default_measurement_key(targets)
qid_shape = protocols.qid_shape(targets)
return MeasurementGate(len(targets), key, invert_mask, qid_shape).on(*targets)


@overload
def measure_each(
*qubits: raw_types.Qid, key_func: Callable[[raw_types.Qid], str] = str
) -> List[raw_types.Operation]:
pass


@overload
def measure_each(
*qubits: 'cirq.Qid', key_func: Callable[[raw_types.Qid], str] = str
__qubits: Iterable[raw_types.Qid], *, key_func: Callable[[raw_types.Qid], str] = str
) -> List[raw_types.Operation]:
pass


def measure_each(
*qubits, key_func: Callable[[raw_types.Qid], str] = str
) -> List[raw_types.Operation]:
"""Returns a list of operations individually measuring the given qubits.

The qubits are measured in the computational basis.

Args:
*qubits: The qubits to measure.
*qubits: The qubits to measure. These can be passed as separate
function arguments or as a one-argument iterable of qubits.
key_func: Determines the key of the measurements of each qubit. Takes
the qubit and returns the key for that qubit. Defaults to str.

Returns:
A list of operations individually measuring the given qubits.
"""
return [MeasurementGate(1, key_func(q), qid_shape=(q.dimension,)).on(q) for q in qubits]
one_iterable_arg: bool = (
len(qubits) == 1
and isinstance(qubits[0], Iterable)
and not isinstance(qubits[0], (bytes, str))
)
qubitsequence = qubits[0] if one_iterable_arg else qubits
return [MeasurementGate(1, key_func(q), qid_shape=(q.dimension,)).on(q) for q in qubitsequence]


# pylint: enable=function-redefined
23 changes: 22 additions & 1 deletion cirq-core/cirq/ops/measure_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@ def test_measure_qubits():
with pytest.raises(ValueError, match='empty set of qubits'):
_ = cirq.measure()

with pytest.raises(ValueError, match='empty set of qubits'):
_ = cirq.measure([])

assert cirq.measure(a) == cirq.MeasurementGate(num_qubits=1, key='a').on(a)
assert cirq.measure([a]) == cirq.MeasurementGate(num_qubits=1, key='a').on(a)
assert cirq.measure(a, b) == cirq.MeasurementGate(num_qubits=2, key='a,b').on(a, b)
assert cirq.measure([a, b]) == cirq.MeasurementGate(num_qubits=2, key='a,b').on(a, b)
qubit_generator = (q for q in (a, b))
assert cirq.measure(qubit_generator) == cirq.MeasurementGate(num_qubits=2, key='a,b').on(a, b)
assert cirq.measure(b, a) == cirq.MeasurementGate(num_qubits=2, key='b,a').on(b, a)
assert cirq.measure(a, key='b') == cirq.MeasurementGate(num_qubits=1, key='b').on(a)
assert cirq.measure(a, invert_mask=(True,)) == cirq.MeasurementGate(
Expand All @@ -36,21 +43,35 @@ def test_measure_qubits():
assert cirq.measure(*cirq.LineQid.for_qid_shape((1, 2, 3)), key='a') == cirq.MeasurementGate(
num_qubits=3, key='a', qid_shape=(1, 2, 3)
).on(*cirq.LineQid.for_qid_shape((1, 2, 3)))
assert cirq.measure(cirq.LineQid.for_qid_shape((1, 2, 3)), key='a') == cirq.MeasurementGate(
num_qubits=3, key='a', qid_shape=(1, 2, 3)
).on(*cirq.LineQid.for_qid_shape((1, 2, 3)))

with pytest.raises(ValueError, match='ndarray'):
_ = cirq.measure(np.ndarray([1, 0]))
_ = cirq.measure(np.array([1, 0]))

with pytest.raises(ValueError, match='Qid'):
_ = cirq.measure("bork")

with pytest.raises(ValueError, match='Qid'):
_ = cirq.measure([a, [b]])

with pytest.raises(ValueError, match='Qid'):
_ = cirq.measure([a], [b])


def test_measure_each():
a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')

assert cirq.measure_each() == []
assert cirq.measure_each([]) == []
assert cirq.measure_each(a) == [cirq.measure(a)]
assert cirq.measure_each([a]) == [cirq.measure(a)]
assert cirq.measure_each(a, b) == [cirq.measure(a), cirq.measure(b)]
assert cirq.measure_each([a, b]) == [cirq.measure(a), cirq.measure(b)]
qubit_generator = (q for q in (a, b))
assert cirq.measure_each(qubit_generator) == [cirq.measure(a), cirq.measure(b)]
assert cirq.measure_each(a.with_dimension(3), b.with_dimension(3)) == [
cirq.measure(a.with_dimension(3)),
cirq.measure(b.with_dimension(3)),
Expand Down