Skip to content

Flip back to default use_repetition_ids=True in CircuitOperation #7237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 31 additions & 34 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from __future__ import annotations

import math
import warnings
from functools import cached_property
from typing import (
Any,
Expand All @@ -48,7 +49,6 @@
if TYPE_CHECKING:
import cirq


INT_CLASSES = (int, np.integer)
INT_TYPE = Union[int, np.integer]
IntParam = Union[INT_TYPE, sympy.Expr]
Expand Down Expand Up @@ -123,8 +123,9 @@ def __init__(
use_repetition_ids: When True, any measurement key in the subcircuit
will have its path prepended with the repetition id for each
repetition. When False, this will not happen and the measurement
key will be repeated. When None, default to False unless the caller
passes `repetition_ids` explicitly.
key will be repeated. The default is True, but it will be changed
to False in the next release. Please pass an explicit argument
``use_repetition_ids=True`` to preserve the current behavior.
repeat_until: A condition that will be tested after each iteration of
the subcircuit. The subcircuit will repeat until condition returns
True, but will always run at least once, and the measurement key
Expand Down Expand Up @@ -161,8 +162,18 @@ def __init__(
self._repetitions = repetitions
self._repetition_ids = None if repetition_ids is None else list(repetition_ids)
if use_repetition_ids is None:
use_repetition_ids = repetition_ids is not None
self._use_repetition_ids = use_repetition_ids
if repetition_ids is None:
msg = (
"In cirq 1.6 the default value of `use_repetition_ids` will change to\n"
"`use_repetition_ids=False`. To make this warning go away, please pass\n"
"explicit `use_repetition_ids`, e.g., to preserve current behavior, use\n"
"\n"
" CircuitOperations(..., use_repetition_ids=True)"
)
warnings.warn(msg, FutureWarning)
self._use_repetition_ids = True
else:
self._use_repetition_ids = use_repetition_ids
if isinstance(self._repetitions, float):
if math.isclose(self._repetitions, round(self._repetitions)):
self._repetitions = round(self._repetitions)
Expand Down Expand Up @@ -270,9 +281,7 @@ def replace(self, **changes) -> cirq.CircuitOperation:
'repetition_ids': self.repetition_ids,
'parent_path': self.parent_path,
'extern_keys': self._extern_keys,
'use_repetition_ids': (
True if changes.get('repetition_ids') is not None else self.use_repetition_ids
),
'use_repetition_ids': self.use_repetition_ids,
'repeat_until': self.repeat_until,
**changes,
}
Expand Down Expand Up @@ -476,9 +485,11 @@ def __repr__(self):
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
if self.parent_path:
args += f'parent_path={proper_repr(self.parent_path)},\n'
if self.use_repetition_ids:
if self.repetition_ids != self._default_repetition_ids():
# Default repetition_ids need not be specified.
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
if not self.use_repetition_ids:
args += 'use_repetition_ids=False,\n'
if self.repeat_until:
args += f'repeat_until={self.repeat_until!r},\n'
indented_args = args.replace('\n', '\n ')
Expand All @@ -503,15 +514,14 @@ def dict_str(d: Mapping) -> str:
args.append(f'params={self.param_resolver.param_dict}')
if self.parent_path:
args.append(f'parent_path={self.parent_path}')
if self.use_repetition_ids:
if self.repetition_ids != self._default_repetition_ids():
args.append(f'repetition_ids={self.repetition_ids}')
else:
# Default repetition_ids need not be specified.
args.append(f'loops={self.repetitions}, use_repetition_ids=True')
if self.repetition_ids != self._default_repetition_ids():
# Default repetition_ids need not be specified.
args.append(f'repetition_ids={self.repetition_ids}')
elif self.repetitions != 1:
# Add loops if not using repetition_ids.
# Only add loops if we haven't added repetition_ids.
args.append(f'loops={self.repetitions}')
if not self.use_repetition_ids:
args.append('no_rep_ids')
if self.repeat_until:
args.append(f'until={self.repeat_until}')
if not args:
Expand Down Expand Up @@ -556,9 +566,10 @@ def _json_dict_(self):
'measurement_key_map': self.measurement_key_map,
'param_resolver': self.param_resolver,
'repetition_ids': self.repetition_ids,
'use_repetition_ids': self.use_repetition_ids,
'parent_path': self.parent_path,
}
if not self.use_repetition_ids:
resp['use_repetition_ids'] = False
if self.repeat_until:
resp['repeat_until'] = self.repeat_until
return resp
Expand Down Expand Up @@ -592,10 +603,7 @@ def _from_json_dict_(
# Methods for constructing a similar object with one field modified.

def repeat(
self,
repetitions: Optional[IntParam] = None,
repetition_ids: Optional[Sequence[str]] = None,
use_repetition_ids: Optional[bool] = None,
self, repetitions: Optional[IntParam] = None, repetition_ids: Optional[Sequence[str]] = None
) -> CircuitOperation:
"""Returns a copy of this operation repeated 'repetitions' times.
Each repetition instance will be identified by a single repetition_id.
Expand All @@ -606,10 +614,6 @@ def repeat(
defaults to the length of `repetition_ids`.
repetition_ids: List of IDs, one for each repetition. If unset,
defaults to `default_repetition_ids(repetitions)`.
use_repetition_ids: If given, this specifies the value for `use_repetition_ids`
of the resulting circuit operation. If not given, we enable ids if
`repetition_ids` is not None, and otherwise fall back to
`self.use_repetition_ids`.

Returns:
A copy of this operation repeated `repetitions` times with the
Expand All @@ -624,9 +628,6 @@ def repeat(
ValueError: Unexpected length of `repetition_ids`.
ValueError: Both `repetitions` and `repetition_ids` are None.
"""
if use_repetition_ids is None:
use_repetition_ids = True if repetition_ids is not None else self.use_repetition_ids

if repetitions is None:
if repetition_ids is None:
raise ValueError('At least one of repetitions and repetition_ids must be set')
Expand All @@ -640,7 +641,7 @@ def repeat(
expected_repetition_id_length: int = np.abs(repetitions)

if repetition_ids is None:
if use_repetition_ids:
if self.use_repetition_ids:
repetition_ids = default_repetition_ids(expected_repetition_id_length)
elif len(repetition_ids) != expected_repetition_id_length:
raise ValueError(
Expand All @@ -653,11 +654,7 @@ def repeat(

# The eventual number of repetitions of the returned CircuitOperation.
final_repetitions = protocols.mul(self.repetitions, repetitions)
return self.replace(
repetitions=final_repetitions,
repetition_ids=repetition_ids,
use_repetition_ids=use_repetition_ids,
)
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)

def __pow__(self, power: IntParam) -> cirq.CircuitOperation:
return self.repeat(power)
Expand Down
69 changes: 31 additions & 38 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,15 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
op_with_reps: Optional[cirq.CircuitOperation] = None
rep_ids = []
if use_default_ids_for_initial_rep:
op_with_reps = op_base.repeat(initial_repetitions)
rep_ids = ['0', '1', '2']
op_with_reps = op_base.repeat(initial_repetitions, use_repetition_ids=True)
assert op_base**initial_repetitions == op_with_reps
else:
rep_ids = ['a', 'b', 'c']
op_with_reps = op_base.repeat(initial_repetitions, rep_ids)
assert op_base**initial_repetitions != op_with_reps
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
assert op_base**initial_repetitions != op_with_reps
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
assert op_with_reps.repetitions == initial_repetitions
assert op_with_reps.use_repetition_ids
assert op_with_reps.repetition_ids == rep_ids
assert op_with_reps.repeat(1) is op_with_reps

Expand Down Expand Up @@ -332,6 +332,8 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
assert op_base.repeat(2.99999999999).repetitions == 3


# TODO: #7232 - enable and fix immediately after the 1.5.0 release
@pytest.mark.xfail(reason='broken by rollback of use_repetition_ids for #7232')
def test_replace_repetition_ids() -> None:
a, b = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b), cirq.M(b, key='mb'), cirq.M(a, key='ma'))
Expand Down Expand Up @@ -458,7 +460,6 @@ def test_parameterized_repeat_side_effects():
op = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q).with_classical_controls('c'), cirq.measure(q, key='m')),
repetitions=sympy.Symbol('a'),
use_repetition_ids=True,
)

# Control keys can be calculated because they only "lift" if there's a matching
Expand Down Expand Up @@ -712,6 +713,7 @@ def test_string_format():
),
),
]),
use_repetition_ids=False,
)"""
)
op7 = cirq.CircuitOperation(
Expand All @@ -728,6 +730,7 @@ def test_string_format():
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
),
]),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
)"""
)
Expand Down Expand Up @@ -758,7 +761,6 @@ def test_json_dict():
'param_resolver': op.param_resolver,
'parent_path': op.parent_path,
'repetition_ids': None,
'use_repetition_ids': False,
}


Expand Down Expand Up @@ -865,26 +867,6 @@ def test_decompose_loops_with_measurements():
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
base_op = cirq.CircuitOperation(circuit)

op = base_op.with_qubits(b, a).repeat(3)
expected_circuit = cirq.Circuit(
cirq.H(b),
cirq.CX(b, a),
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
cirq.H(b),
cirq.CX(b, a),
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
cirq.H(b),
cirq.CX(b, a),
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
)
assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit


def test_decompose_loops_with_measurements_use_rep_ids():
a, b = cirq.LineQubit.range(2)
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
base_op = cirq.CircuitOperation(circuit, use_repetition_ids=True)

op = base_op.with_qubits(b, a).repeat(3)
expected_circuit = cirq.Circuit(
cirq.H(b),
Expand Down Expand Up @@ -1041,9 +1023,7 @@ def test_keys_under_parent_path():
op3 = cirq.with_key_path_prefix(op2, ('C',))
assert cirq.measurement_key_names(op3) == {'C:B:A'}
op4 = op3.repeat(2)
assert cirq.measurement_key_names(op4) == {'C:B:A'}
op4_rep = op3.repeat(2).replace(use_repetition_ids=True)
assert cirq.measurement_key_names(op4_rep) == {'C:B:0:A', 'C:B:1:A'}
assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'}


def test_mapped_circuit_preserves_moments():
Expand Down Expand Up @@ -1121,8 +1101,12 @@ def test_mapped_circuit_allows_repeated_keys():
def test_simulate_no_repetition_ids_both_levels(sim):
q = cirq.LineQubit(0)
inner = cirq.Circuit(cirq.measure(q, key='a'))
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
middle = cirq.Circuit(
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
)
outer_subcircuit = cirq.CircuitOperation(
middle.freeze(), repetitions=2, use_repetition_ids=False
)
circuit = cirq.Circuit(outer_subcircuit)
result = sim.run(circuit)
assert result.records['a'].shape == (1, 4, 1)
Expand All @@ -1132,10 +1116,10 @@ def test_simulate_no_repetition_ids_both_levels(sim):
def test_simulate_no_repetition_ids_outer(sim):
q = cirq.LineQubit(0)
inner = cirq.Circuit(cirq.measure(q, key='a'))
middle = cirq.Circuit(
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=True)
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
outer_subcircuit = cirq.CircuitOperation(
middle.freeze(), repetitions=2, use_repetition_ids=False
)
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
circuit = cirq.Circuit(outer_subcircuit)
result = sim.run(circuit)
assert result.records['0:a'].shape == (1, 2, 1)
Expand All @@ -1146,10 +1130,10 @@ def test_simulate_no_repetition_ids_outer(sim):
def test_simulate_no_repetition_ids_inner(sim):
q = cirq.LineQubit(0)
inner = cirq.Circuit(cirq.measure(q, key='a'))
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
outer_subcircuit = cirq.CircuitOperation(
middle.freeze(), repetitions=2, use_repetition_ids=True
middle = cirq.Circuit(
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
)
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
circuit = cirq.Circuit(outer_subcircuit)
result = sim.run(circuit)
assert result.records['0:a'].shape == (1, 2, 1)
Expand All @@ -1164,6 +1148,7 @@ def test_repeat_until(sim):
cirq.X(q),
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q), cirq.measure(q, key=key)),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
Expand All @@ -1178,6 +1163,7 @@ def test_repeat_until_sympy(sim):
q1, q2 = cirq.LineQubit.range(2)
circuitop = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q2), cirq.measure(q2, key='b')),
use_repetition_ids=False,
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))),
)
c = cirq.Circuit(cirq.measure(q1, key='a'), circuitop)
Expand All @@ -1197,6 +1183,7 @@ def test_post_selection(sim):
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
)
)
Expand All @@ -1212,13 +1199,14 @@ def test_repeat_until_diagram():
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
)
)
cirq.testing.assert_has_diagram(
c,
"""
0: ───[ 0: ───X^0.2───M('m')─── ](until=m)───
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
""",
use_unicode_characters=True,
)
Expand All @@ -1235,6 +1223,7 @@ def test_repeat_until_error():
with pytest.raises(ValueError, match='Infinite loop'):
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(q, key='m')),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)

Expand All @@ -1244,6 +1233,8 @@ def test_repeat_until_protocols():
op = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.H(q) ** sympy.Symbol('p'), cirq.measure(q, key='a')),
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), 0)),
# TODO: #7232 - remove immediately after the 1.5.0 release
use_repetition_ids=False,
)
scoped = cirq.with_rescoped_keys(op, ('0',))
# Ensure the _repeat_until has been mapped, the measurement has been mapped to the same key,
Expand Down Expand Up @@ -1276,6 +1267,8 @@ def test_inner_repeat_until_simulate():
inner_loop = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.H(q), cirq.measure(q, key="inner_loop")),
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol("inner_loop"), 0)),
# TODO: #7232 - remove immediately after the 1.5.0 release
use_repetition_ids=False,
)
outer_loop = cirq.Circuit(inner_loop, cirq.X(q), cirq.measure(q, key="outer_loop"))
circuit = cirq.Circuit(
Expand Down
Loading