Skip to content

Commit 6bfc809

Browse files
committed
Roll forward quantumlib#6910 - CircuitOperation: change use_repetition_ids default to False
Revert "Flip back to default `use_repetition_ids=True` in CircuitOperation (quantumlib#7237)" This reverts commit 58d9619. This also finalizes quantumlib#7232
1 parent 92f71b6 commit 6bfc809

File tree

5 files changed

+152
-127
lines changed

5 files changed

+152
-127
lines changed

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from __future__ import annotations
2323

2424
import math
25-
import warnings
2625
from functools import cached_property
2726
from typing import (
2827
Any,
@@ -49,6 +48,7 @@
4948
if TYPE_CHECKING:
5049
import cirq
5150

51+
5252
INT_CLASSES = (int, np.integer)
5353
INT_TYPE = Union[int, np.integer]
5454
IntParam = Union[INT_TYPE, sympy.Expr]
@@ -123,9 +123,8 @@ def __init__(
123123
use_repetition_ids: When True, any measurement key in the subcircuit
124124
will have its path prepended with the repetition id for each
125125
repetition. When False, this will not happen and the measurement
126-
key will be repeated. The default is True, but it will be changed
127-
to False in the next release. Please pass an explicit argument
128-
``use_repetition_ids=True`` to preserve the current behavior.
126+
key will be repeated. When None, default to False unless the caller
127+
passes `repetition_ids` explicitly.
129128
repeat_until: A condition that will be tested after each iteration of
130129
the subcircuit. The subcircuit will repeat until condition returns
131130
True, but will always run at least once, and the measurement key
@@ -162,18 +161,8 @@ def __init__(
162161
self._repetitions = repetitions
163162
self._repetition_ids = None if repetition_ids is None else list(repetition_ids)
164163
if use_repetition_ids is None:
165-
if repetition_ids is None:
166-
msg = (
167-
"In cirq 1.6 the default value of `use_repetition_ids` will change to\n"
168-
"`use_repetition_ids=False`. To make this warning go away, please pass\n"
169-
"explicit `use_repetition_ids`, e.g., to preserve current behavior, use\n"
170-
"\n"
171-
" CircuitOperations(..., use_repetition_ids=True)"
172-
)
173-
warnings.warn(msg, FutureWarning)
174-
self._use_repetition_ids = True
175-
else:
176-
self._use_repetition_ids = use_repetition_ids
164+
use_repetition_ids = repetition_ids is not None
165+
self._use_repetition_ids = use_repetition_ids
177166
if isinstance(self._repetitions, float):
178167
if math.isclose(self._repetitions, round(self._repetitions)):
179168
self._repetitions = round(self._repetitions)
@@ -281,7 +270,9 @@ def replace(self, **changes) -> cirq.CircuitOperation:
281270
'repetition_ids': self.repetition_ids,
282271
'parent_path': self.parent_path,
283272
'extern_keys': self._extern_keys,
284-
'use_repetition_ids': self.use_repetition_ids,
273+
'use_repetition_ids': (
274+
True if changes.get('repetition_ids') is not None else self.use_repetition_ids
275+
),
285276
'repeat_until': self.repeat_until,
286277
**changes,
287278
}
@@ -485,11 +476,9 @@ def __repr__(self):
485476
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
486477
if self.parent_path:
487478
args += f'parent_path={proper_repr(self.parent_path)},\n'
488-
if self.repetition_ids != self._default_repetition_ids():
479+
if self.use_repetition_ids:
489480
# Default repetition_ids need not be specified.
490481
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
491-
if not self.use_repetition_ids:
492-
args += 'use_repetition_ids=False,\n'
493482
if self.repeat_until:
494483
args += f'repeat_until={self.repeat_until!r},\n'
495484
indented_args = args.replace('\n', '\n ')
@@ -514,14 +503,15 @@ def dict_str(d: Mapping) -> str:
514503
args.append(f'params={self.param_resolver.param_dict}')
515504
if self.parent_path:
516505
args.append(f'parent_path={self.parent_path}')
517-
if self.repetition_ids != self._default_repetition_ids():
518-
# Default repetition_ids need not be specified.
519-
args.append(f'repetition_ids={self.repetition_ids}')
506+
if self.use_repetition_ids:
507+
if self.repetition_ids != self._default_repetition_ids():
508+
args.append(f'repetition_ids={self.repetition_ids}')
509+
else:
510+
# Default repetition_ids need not be specified.
511+
args.append(f'loops={self.repetitions}, use_repetition_ids=True')
520512
elif self.repetitions != 1:
521-
# Only add loops if we haven't added repetition_ids.
513+
# Add loops if not using repetition_ids.
522514
args.append(f'loops={self.repetitions}')
523-
if not self.use_repetition_ids:
524-
args.append('no_rep_ids')
525515
if self.repeat_until:
526516
args.append(f'until={self.repeat_until}')
527517
if not args:
@@ -566,10 +556,9 @@ def _json_dict_(self):
566556
'measurement_key_map': self.measurement_key_map,
567557
'param_resolver': self.param_resolver,
568558
'repetition_ids': self.repetition_ids,
559+
'use_repetition_ids': self.use_repetition_ids,
569560
'parent_path': self.parent_path,
570561
}
571-
if not self.use_repetition_ids:
572-
resp['use_repetition_ids'] = False
573562
if self.repeat_until:
574563
resp['repeat_until'] = self.repeat_until
575564
return resp
@@ -603,7 +592,10 @@ def _from_json_dict_(
603592
# Methods for constructing a similar object with one field modified.
604593

605594
def repeat(
606-
self, repetitions: Optional[IntParam] = None, repetition_ids: Optional[Sequence[str]] = None
595+
self,
596+
repetitions: Optional[IntParam] = None,
597+
repetition_ids: Optional[Sequence[str]] = None,
598+
use_repetition_ids: Optional[bool] = None,
607599
) -> CircuitOperation:
608600
"""Returns a copy of this operation repeated 'repetitions' times.
609601
Each repetition instance will be identified by a single repetition_id.
@@ -614,6 +606,10 @@ def repeat(
614606
defaults to the length of `repetition_ids`.
615607
repetition_ids: List of IDs, one for each repetition. If unset,
616608
defaults to `default_repetition_ids(repetitions)`.
609+
use_repetition_ids: If given, this specifies the value for `use_repetition_ids`
610+
of the resulting circuit operation. If not given, we enable ids if
611+
`repetition_ids` is not None, and otherwise fall back to
612+
`self.use_repetition_ids`.
617613
618614
Returns:
619615
A copy of this operation repeated `repetitions` times with the
@@ -628,6 +624,9 @@ def repeat(
628624
ValueError: Unexpected length of `repetition_ids`.
629625
ValueError: Both `repetitions` and `repetition_ids` are None.
630626
"""
627+
if use_repetition_ids is None:
628+
use_repetition_ids = True if repetition_ids is not None else self.use_repetition_ids
629+
631630
if repetitions is None:
632631
if repetition_ids is None:
633632
raise ValueError('At least one of repetitions and repetition_ids must be set')
@@ -641,7 +640,7 @@ def repeat(
641640
expected_repetition_id_length: int = np.abs(repetitions)
642641

643642
if repetition_ids is None:
644-
if self.use_repetition_ids:
643+
if use_repetition_ids:
645644
repetition_ids = default_repetition_ids(expected_repetition_id_length)
646645
elif len(repetition_ids) != expected_repetition_id_length:
647646
raise ValueError(
@@ -654,7 +653,11 @@ def repeat(
654653

655654
# The eventual number of repetitions of the returned CircuitOperation.
656655
final_repetitions = protocols.mul(self.repetitions, repetitions)
657-
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)
656+
return self.replace(
657+
repetitions=final_repetitions,
658+
repetition_ids=repetition_ids,
659+
use_repetition_ids=use_repetition_ids,
660+
)
658661

659662
def __pow__(self, power: IntParam) -> cirq.CircuitOperation:
660663
return self.repeat(power)

cirq-core/cirq/circuits/circuit_operation_test.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -294,15 +294,15 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
294294
op_with_reps: Optional[cirq.CircuitOperation] = None
295295
rep_ids = []
296296
if use_default_ids_for_initial_rep:
297-
op_with_reps = op_base.repeat(initial_repetitions)
298297
rep_ids = ['0', '1', '2']
299-
assert op_base**initial_repetitions == op_with_reps
298+
op_with_reps = op_base.repeat(initial_repetitions, use_repetition_ids=True)
300299
else:
301300
rep_ids = ['a', 'b', 'c']
302301
op_with_reps = op_base.repeat(initial_repetitions, rep_ids)
303-
assert op_base**initial_repetitions != op_with_reps
304-
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
302+
assert op_base**initial_repetitions != op_with_reps
303+
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
305304
assert op_with_reps.repetitions == initial_repetitions
305+
assert op_with_reps.use_repetition_ids
306306
assert op_with_reps.repetition_ids == rep_ids
307307
assert op_with_reps.repeat(1) is op_with_reps
308308

@@ -332,8 +332,6 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
332332
assert op_base.repeat(2.99999999999).repetitions == 3
333333

334334

335-
# TODO: #7232 - enable and fix immediately after the 1.5.0 release
336-
@pytest.mark.xfail(reason='broken by rollback of use_repetition_ids for #7232')
337335
def test_replace_repetition_ids() -> None:
338336
a, b = cirq.LineQubit.range(2)
339337
circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b), cirq.M(b, key='mb'), cirq.M(a, key='ma'))
@@ -460,6 +458,7 @@ def test_parameterized_repeat_side_effects():
460458
op = cirq.CircuitOperation(
461459
cirq.FrozenCircuit(cirq.X(q).with_classical_controls('c'), cirq.measure(q, key='m')),
462460
repetitions=sympy.Symbol('a'),
461+
use_repetition_ids=True,
463462
)
464463

465464
# Control keys can be calculated because they only "lift" if there's a matching
@@ -713,7 +712,6 @@ def test_string_format():
713712
),
714713
),
715714
]),
716-
use_repetition_ids=False,
717715
)"""
718716
)
719717
op7 = cirq.CircuitOperation(
@@ -730,7 +728,6 @@ def test_string_format():
730728
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
731729
),
732730
]),
733-
use_repetition_ids=False,
734731
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
735732
)"""
736733
)
@@ -761,6 +758,7 @@ def test_json_dict():
761758
'param_resolver': op.param_resolver,
762759
'parent_path': op.parent_path,
763760
'repetition_ids': None,
761+
'use_repetition_ids': False,
764762
}
765763

766764

@@ -867,6 +865,26 @@ def test_decompose_loops_with_measurements():
867865
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
868866
base_op = cirq.CircuitOperation(circuit)
869867

868+
op = base_op.with_qubits(b, a).repeat(3)
869+
expected_circuit = cirq.Circuit(
870+
cirq.H(b),
871+
cirq.CX(b, a),
872+
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
873+
cirq.H(b),
874+
cirq.CX(b, a),
875+
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
876+
cirq.H(b),
877+
cirq.CX(b, a),
878+
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
879+
)
880+
assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit
881+
882+
883+
def test_decompose_loops_with_measurements_use_rep_ids():
884+
a, b = cirq.LineQubit.range(2)
885+
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
886+
base_op = cirq.CircuitOperation(circuit, use_repetition_ids=True)
887+
870888
op = base_op.with_qubits(b, a).repeat(3)
871889
expected_circuit = cirq.Circuit(
872890
cirq.H(b),
@@ -1023,7 +1041,9 @@ def test_keys_under_parent_path():
10231041
op3 = cirq.with_key_path_prefix(op2, ('C',))
10241042
assert cirq.measurement_key_names(op3) == {'C:B:A'}
10251043
op4 = op3.repeat(2)
1026-
assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'}
1044+
assert cirq.measurement_key_names(op4) == {'C:B:A'}
1045+
op4_rep = op3.repeat(2).replace(use_repetition_ids=True)
1046+
assert cirq.measurement_key_names(op4_rep) == {'C:B:0:A', 'C:B:1:A'}
10271047

10281048

10291049
def test_mapped_circuit_preserves_moments():
@@ -1101,12 +1121,8 @@ def test_mapped_circuit_allows_repeated_keys():
11011121
def test_simulate_no_repetition_ids_both_levels(sim):
11021122
q = cirq.LineQubit(0)
11031123
inner = cirq.Circuit(cirq.measure(q, key='a'))
1104-
middle = cirq.Circuit(
1105-
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
1106-
)
1107-
outer_subcircuit = cirq.CircuitOperation(
1108-
middle.freeze(), repetitions=2, use_repetition_ids=False
1109-
)
1124+
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
1125+
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
11101126
circuit = cirq.Circuit(outer_subcircuit)
11111127
result = sim.run(circuit)
11121128
assert result.records['a'].shape == (1, 4, 1)
@@ -1116,10 +1132,10 @@ def test_simulate_no_repetition_ids_both_levels(sim):
11161132
def test_simulate_no_repetition_ids_outer(sim):
11171133
q = cirq.LineQubit(0)
11181134
inner = cirq.Circuit(cirq.measure(q, key='a'))
1119-
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
1120-
outer_subcircuit = cirq.CircuitOperation(
1121-
middle.freeze(), repetitions=2, use_repetition_ids=False
1135+
middle = cirq.Circuit(
1136+
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=True)
11221137
)
1138+
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
11231139
circuit = cirq.Circuit(outer_subcircuit)
11241140
result = sim.run(circuit)
11251141
assert result.records['0:a'].shape == (1, 2, 1)
@@ -1130,10 +1146,10 @@ def test_simulate_no_repetition_ids_outer(sim):
11301146
def test_simulate_no_repetition_ids_inner(sim):
11311147
q = cirq.LineQubit(0)
11321148
inner = cirq.Circuit(cirq.measure(q, key='a'))
1133-
middle = cirq.Circuit(
1134-
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
1149+
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
1150+
outer_subcircuit = cirq.CircuitOperation(
1151+
middle.freeze(), repetitions=2, use_repetition_ids=True
11351152
)
1136-
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
11371153
circuit = cirq.Circuit(outer_subcircuit)
11381154
result = sim.run(circuit)
11391155
assert result.records['0:a'].shape == (1, 2, 1)
@@ -1148,7 +1164,6 @@ def test_repeat_until(sim):
11481164
cirq.X(q),
11491165
cirq.CircuitOperation(
11501166
cirq.FrozenCircuit(cirq.X(q), cirq.measure(q, key=key)),
1151-
use_repetition_ids=False,
11521167
repeat_until=cirq.KeyCondition(key),
11531168
),
11541169
)
@@ -1163,7 +1178,6 @@ def test_repeat_until_sympy(sim):
11631178
q1, q2 = cirq.LineQubit.range(2)
11641179
circuitop = cirq.CircuitOperation(
11651180
cirq.FrozenCircuit(cirq.X(q2), cirq.measure(q2, key='b')),
1166-
use_repetition_ids=False,
11671181
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))),
11681182
)
11691183
c = cirq.Circuit(cirq.measure(q1, key='a'), circuitop)
@@ -1183,7 +1197,6 @@ def test_post_selection(sim):
11831197
c = cirq.Circuit(
11841198
cirq.CircuitOperation(
11851199
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
1186-
use_repetition_ids=False,
11871200
repeat_until=cirq.KeyCondition(key),
11881201
)
11891202
)
@@ -1199,14 +1212,13 @@ def test_repeat_until_diagram():
11991212
c = cirq.Circuit(
12001213
cirq.CircuitOperation(
12011214
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
1202-
use_repetition_ids=False,
12031215
repeat_until=cirq.KeyCondition(key),
12041216
)
12051217
)
12061218
cirq.testing.assert_has_diagram(
12071219
c,
12081220
"""
1209-
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
1221+
0: ───[ 0: ───X^0.2───M('m')─── ](until=m)───
12101222
""",
12111223
use_unicode_characters=True,
12121224
)
@@ -1223,7 +1235,6 @@ def test_repeat_until_error():
12231235
with pytest.raises(ValueError, match='Infinite loop'):
12241236
cirq.CircuitOperation(
12251237
cirq.FrozenCircuit(cirq.measure(q, key='m')),
1226-
use_repetition_ids=False,
12271238
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
12281239
)
12291240

@@ -1233,8 +1244,6 @@ def test_repeat_until_protocols():
12331244
op = cirq.CircuitOperation(
12341245
cirq.FrozenCircuit(cirq.H(q) ** sympy.Symbol('p'), cirq.measure(q, key='a')),
12351246
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), 0)),
1236-
# TODO: #7232 - remove immediately after the 1.5.0 release
1237-
use_repetition_ids=False,
12381247
)
12391248
scoped = cirq.with_rescoped_keys(op, ('0',))
12401249
# Ensure the _repeat_until has been mapped, the measurement has been mapped to the same key,
@@ -1267,8 +1276,6 @@ def test_inner_repeat_until_simulate():
12671276
inner_loop = cirq.CircuitOperation(
12681277
cirq.FrozenCircuit(cirq.H(q), cirq.measure(q, key="inner_loop")),
12691278
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol("inner_loop"), 0)),
1270-
# TODO: #7232 - remove immediately after the 1.5.0 release
1271-
use_repetition_ids=False,
12721279
)
12731280
outer_loop = cirq.Circuit(inner_loop, cirq.X(q), cirq.measure(q, key="outer_loop"))
12741281
circuit = cirq.Circuit(

0 commit comments

Comments
 (0)