Skip to content

Commit 5ffb3ad

Browse files
authored
CircuitOperation: change use_repetition_ids default to False (#6910)
Review: @pavoljuhas
1 parent b840178 commit 5ffb3ad

File tree

5 files changed

+150
-107
lines changed

5 files changed

+150
-107
lines changed

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
repetition_ids: Optional[Sequence[str]] = None,
9090
parent_path: Tuple[str, ...] = (),
9191
extern_keys: FrozenSet['cirq.MeasurementKey'] = frozenset(),
92-
use_repetition_ids: bool = True,
92+
use_repetition_ids: Optional[bool] = None,
9393
repeat_until: Optional['cirq.Condition'] = None,
9494
):
9595
"""Initializes a CircuitOperation.
@@ -120,7 +120,8 @@ def __init__(
120120
use_repetition_ids: When True, any measurement key in the subcircuit
121121
will have its path prepended with the repetition id for each
122122
repetition. When False, this will not happen and the measurement
123-
key will be repeated.
123+
key will be repeated. When None, default to False unless the caller
124+
passes `repetition_ids` explicitly.
124125
repeat_until: A condition that will be tested after each iteration of
125126
the subcircuit. The subcircuit will repeat until condition returns
126127
True, but will always run at least once, and the measurement key
@@ -156,6 +157,8 @@ def __init__(
156157
# Ensure that the circuit is invertible if the repetitions are negative.
157158
self._repetitions = repetitions
158159
self._repetition_ids = None if repetition_ids is None else list(repetition_ids)
160+
if use_repetition_ids is None:
161+
use_repetition_ids = repetition_ids is not None
159162
self._use_repetition_ids = use_repetition_ids
160163
if isinstance(self._repetitions, float):
161164
if math.isclose(self._repetitions, round(self._repetitions)):
@@ -263,7 +266,7 @@ def replace(self, **changes) -> 'cirq.CircuitOperation':
263266
'repetition_ids': self.repetition_ids,
264267
'parent_path': self.parent_path,
265268
'extern_keys': self._extern_keys,
266-
'use_repetition_ids': self.use_repetition_ids,
269+
'use_repetition_ids': True if 'repetition_ids' in changes else self.use_repetition_ids,
267270
'repeat_until': self.repeat_until,
268271
**changes,
269272
}
@@ -448,11 +451,9 @@ def __repr__(self):
448451
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
449452
if self.parent_path:
450453
args += f'parent_path={proper_repr(self.parent_path)},\n'
451-
if self.repetition_ids != self._default_repetition_ids():
454+
if self.use_repetition_ids:
452455
# Default repetition_ids need not be specified.
453456
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
454-
if not self.use_repetition_ids:
455-
args += 'use_repetition_ids=False,\n'
456457
if self.repeat_until:
457458
args += f'repeat_until={self.repeat_until!r},\n'
458459
indented_args = args.replace('\n', '\n ')
@@ -477,14 +478,15 @@ def dict_str(d: Mapping) -> str:
477478
args.append(f'params={self.param_resolver.param_dict}')
478479
if self.parent_path:
479480
args.append(f'parent_path={self.parent_path}')
480-
if self.repetition_ids != self._default_repetition_ids():
481-
# Default repetition_ids need not be specified.
482-
args.append(f'repetition_ids={self.repetition_ids}')
481+
if self.use_repetition_ids:
482+
if self.repetition_ids != self._default_repetition_ids():
483+
args.append(f'repetition_ids={self.repetition_ids}')
484+
else:
485+
# Default repetition_ids need not be specified.
486+
args.append(f'loops={self.repetitions}, use_repetition_ids=True')
483487
elif self.repetitions != 1:
484-
# Only add loops if we haven't added repetition_ids.
488+
# Add loops if not using repetition_ids.
485489
args.append(f'loops={self.repetitions}')
486-
if not self.use_repetition_ids:
487-
args.append('no_rep_ids')
488490
if self.repeat_until:
489491
args.append(f'until={self.repeat_until}')
490492
if not args:
@@ -529,10 +531,9 @@ def _json_dict_(self):
529531
'measurement_key_map': self.measurement_key_map,
530532
'param_resolver': self.param_resolver,
531533
'repetition_ids': self.repetition_ids,
534+
'use_repetition_ids': self.use_repetition_ids,
532535
'parent_path': self.parent_path,
533536
}
534-
if not self.use_repetition_ids:
535-
resp['use_repetition_ids'] = False
536537
if self.repeat_until:
537538
resp['repeat_until'] = self.repeat_until
538539
return resp
@@ -566,7 +567,10 @@ def _from_json_dict_(
566567
# Methods for constructing a similar object with one field modified.
567568

568569
def repeat(
569-
self, repetitions: Optional[IntParam] = None, repetition_ids: Optional[Sequence[str]] = None
570+
self,
571+
repetitions: Optional[IntParam] = None,
572+
repetition_ids: Optional[Sequence[str]] = None,
573+
use_repetition_ids: Optional[bool] = None,
570574
) -> 'CircuitOperation':
571575
"""Returns a copy of this operation repeated 'repetitions' times.
572576
Each repetition instance will be identified by a single repetition_id.
@@ -577,6 +581,10 @@ def repeat(
577581
defaults to the length of `repetition_ids`.
578582
repetition_ids: List of IDs, one for each repetition. If unset,
579583
defaults to `default_repetition_ids(repetitions)`.
584+
use_repetition_ids: If given, this specifies the value for `use_repetition_ids`
585+
of the resulting circuit operation. If not given, we enable ids if
586+
`repetition_ids` is not None, and otherwise fall back to
587+
`self.use_repetition_ids`.
580588
581589
Returns:
582590
A copy of this operation repeated `repetitions` times with the
@@ -591,6 +599,9 @@ def repeat(
591599
ValueError: Unexpected length of `repetition_ids`.
592600
ValueError: Both `repetitions` and `repetition_ids` are None.
593601
"""
602+
if use_repetition_ids is None:
603+
use_repetition_ids = True if repetition_ids is not None else self.use_repetition_ids
604+
594605
if repetitions is None:
595606
if repetition_ids is None:
596607
raise ValueError('At least one of repetitions and repetition_ids must be set')
@@ -604,7 +615,7 @@ def repeat(
604615
expected_repetition_id_length: int = np.abs(repetitions)
605616

606617
if repetition_ids is None:
607-
if self.use_repetition_ids:
618+
if use_repetition_ids:
608619
repetition_ids = default_repetition_ids(expected_repetition_id_length)
609620
elif len(repetition_ids) != expected_repetition_id_length:
610621
raise ValueError(
@@ -617,7 +628,11 @@ def repeat(
617628

618629
# The eventual number of repetitions of the returned CircuitOperation.
619630
final_repetitions = protocols.mul(self.repetitions, repetitions)
620-
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)
631+
return self.replace(
632+
repetitions=final_repetitions,
633+
repetition_ids=repetition_ids,
634+
use_repetition_ids=use_repetition_ids,
635+
)
621636

622637
def __pow__(self, power: IntParam) -> 'cirq.CircuitOperation':
623638
return self.repeat(power)

cirq-core/cirq/circuits/circuit_operation_test.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -294,15 +294,15 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
294294
op_with_reps: Optional[cirq.CircuitOperation] = None
295295
rep_ids = []
296296
if use_default_ids_for_initial_rep:
297-
op_with_reps = op_base.repeat(initial_repetitions)
298297
rep_ids = ['0', '1', '2']
299-
assert op_base**initial_repetitions == op_with_reps
298+
op_with_reps = op_base.repeat(initial_repetitions, use_repetition_ids=True)
300299
else:
301300
rep_ids = ['a', 'b', 'c']
302301
op_with_reps = op_base.repeat(initial_repetitions, rep_ids)
303-
assert op_base**initial_repetitions != op_with_reps
304-
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
302+
assert op_base**initial_repetitions != op_with_reps
303+
assert (op_base**initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
305304
assert op_with_reps.repetitions == initial_repetitions
305+
assert op_with_reps.use_repetition_ids
306306
assert op_with_reps.repetition_ids == rep_ids
307307
assert op_with_reps.repeat(1) is op_with_reps
308308

@@ -436,6 +436,7 @@ def test_parameterized_repeat_side_effects():
436436
op = cirq.CircuitOperation(
437437
cirq.FrozenCircuit(cirq.X(q).with_classical_controls('c'), cirq.measure(q, key='m')),
438438
repetitions=sympy.Symbol('a'),
439+
use_repetition_ids=True,
439440
)
440441

441442
# Control keys can be calculated because they only "lift" if there's a matching
@@ -689,7 +690,6 @@ def test_string_format():
689690
),
690691
),
691692
]),
692-
use_repetition_ids=False,
693693
)"""
694694
)
695695
op7 = cirq.CircuitOperation(
@@ -706,7 +706,6 @@ def test_string_format():
706706
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
707707
),
708708
]),
709-
use_repetition_ids=False,
710709
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
711710
)"""
712711
)
@@ -737,6 +736,7 @@ def test_json_dict():
737736
'param_resolver': op.param_resolver,
738737
'parent_path': op.parent_path,
739738
'repetition_ids': None,
739+
'use_repetition_ids': False,
740740
}
741741

742742

@@ -843,6 +843,26 @@ def test_decompose_loops_with_measurements():
843843
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
844844
base_op = cirq.CircuitOperation(circuit)
845845

846+
op = base_op.with_qubits(b, a).repeat(3)
847+
expected_circuit = cirq.Circuit(
848+
cirq.H(b),
849+
cirq.CX(b, a),
850+
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
851+
cirq.H(b),
852+
cirq.CX(b, a),
853+
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
854+
cirq.H(b),
855+
cirq.CX(b, a),
856+
cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('m')),
857+
)
858+
assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit
859+
860+
861+
def test_decompose_loops_with_measurements_use_rep_ids():
862+
a, b = cirq.LineQubit.range(2)
863+
circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b), cirq.measure(a, b, key='m'))
864+
base_op = cirq.CircuitOperation(circuit, use_repetition_ids=True)
865+
846866
op = base_op.with_qubits(b, a).repeat(3)
847867
expected_circuit = cirq.Circuit(
848868
cirq.H(b),
@@ -999,7 +1019,9 @@ def test_keys_under_parent_path():
9991019
op3 = cirq.with_key_path_prefix(op2, ('C',))
10001020
assert cirq.measurement_key_names(op3) == {'C:B:A'}
10011021
op4 = op3.repeat(2)
1002-
assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'}
1022+
assert cirq.measurement_key_names(op4) == {'C:B:A'}
1023+
op4_rep = op3.repeat(2).replace(use_repetition_ids=True)
1024+
assert cirq.measurement_key_names(op4_rep) == {'C:B:0:A', 'C:B:1:A'}
10031025

10041026

10051027
def test_mapped_circuit_preserves_moments():
@@ -1077,12 +1099,8 @@ def test_mapped_circuit_allows_repeated_keys():
10771099
def test_simulate_no_repetition_ids_both_levels(sim):
10781100
q = cirq.LineQubit(0)
10791101
inner = cirq.Circuit(cirq.measure(q, key='a'))
1080-
middle = cirq.Circuit(
1081-
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
1082-
)
1083-
outer_subcircuit = cirq.CircuitOperation(
1084-
middle.freeze(), repetitions=2, use_repetition_ids=False
1085-
)
1102+
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
1103+
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
10861104
circuit = cirq.Circuit(outer_subcircuit)
10871105
result = sim.run(circuit)
10881106
assert result.records['a'].shape == (1, 4, 1)
@@ -1092,10 +1110,10 @@ def test_simulate_no_repetition_ids_both_levels(sim):
10921110
def test_simulate_no_repetition_ids_outer(sim):
10931111
q = cirq.LineQubit(0)
10941112
inner = cirq.Circuit(cirq.measure(q, key='a'))
1095-
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
1096-
outer_subcircuit = cirq.CircuitOperation(
1097-
middle.freeze(), repetitions=2, use_repetition_ids=False
1113+
middle = cirq.Circuit(
1114+
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=True)
10981115
)
1116+
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
10991117
circuit = cirq.Circuit(outer_subcircuit)
11001118
result = sim.run(circuit)
11011119
assert result.records['0:a'].shape == (1, 2, 1)
@@ -1106,10 +1124,10 @@ def test_simulate_no_repetition_ids_outer(sim):
11061124
def test_simulate_no_repetition_ids_inner(sim):
11071125
q = cirq.LineQubit(0)
11081126
inner = cirq.Circuit(cirq.measure(q, key='a'))
1109-
middle = cirq.Circuit(
1110-
cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False)
1127+
middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2))
1128+
outer_subcircuit = cirq.CircuitOperation(
1129+
middle.freeze(), repetitions=2, use_repetition_ids=True
11111130
)
1112-
outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2)
11131131
circuit = cirq.Circuit(outer_subcircuit)
11141132
result = sim.run(circuit)
11151133
assert result.records['0:a'].shape == (1, 2, 1)
@@ -1124,7 +1142,6 @@ def test_repeat_until(sim):
11241142
cirq.X(q),
11251143
cirq.CircuitOperation(
11261144
cirq.FrozenCircuit(cirq.X(q), cirq.measure(q, key=key)),
1127-
use_repetition_ids=False,
11281145
repeat_until=cirq.KeyCondition(key),
11291146
),
11301147
)
@@ -1139,7 +1156,6 @@ def test_repeat_until_sympy(sim):
11391156
q1, q2 = cirq.LineQubit.range(2)
11401157
circuitop = cirq.CircuitOperation(
11411158
cirq.FrozenCircuit(cirq.X(q2), cirq.measure(q2, key='b')),
1142-
use_repetition_ids=False,
11431159
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))),
11441160
)
11451161
c = cirq.Circuit(cirq.measure(q1, key='a'), circuitop)
@@ -1159,7 +1175,6 @@ def test_post_selection(sim):
11591175
c = cirq.Circuit(
11601176
cirq.CircuitOperation(
11611177
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
1162-
use_repetition_ids=False,
11631178
repeat_until=cirq.KeyCondition(key),
11641179
)
11651180
)
@@ -1175,14 +1190,13 @@ def test_repeat_until_diagram():
11751190
c = cirq.Circuit(
11761191
cirq.CircuitOperation(
11771192
cirq.FrozenCircuit(cirq.X(q) ** 0.2, cirq.measure(q, key=key)),
1178-
use_repetition_ids=False,
11791193
repeat_until=cirq.KeyCondition(key),
11801194
)
11811195
)
11821196
cirq.testing.assert_has_diagram(
11831197
c,
11841198
"""
1185-
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
1199+
0: ───[ 0: ───X^0.2───M('m')─── ](until=m)───
11861200
""",
11871201
use_unicode_characters=True,
11881202
)
@@ -1199,7 +1213,6 @@ def test_repeat_until_error():
11991213
with pytest.raises(ValueError, match='Infinite loop'):
12001214
cirq.CircuitOperation(
12011215
cirq.FrozenCircuit(cirq.measure(q, key='m')),
1202-
use_repetition_ids=False,
12031216
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
12041217
)
12051218

0 commit comments

Comments
 (0)