Skip to content

Commit e6c1101

Browse files
authored
Fix logic for handling repetition_ids in CircuitOperation.replace (#6984)
* Fix logic for handling `repetition_ids` in `CircuitOperation.replace` Previously, passing `repetition_ids` to `replace` would enable `use_repetition_ids`, even if the value passed was `None`. This fixes things so we only enable `use_repetition_ids` if a non-None value is passed, which is the same as what is done in the `CircuitOperation` constructor. * fmt
1 parent 116cf6e commit e6c1101

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,9 @@ def replace(self, **changes) -> 'cirq.CircuitOperation':
266266
'repetition_ids': self.repetition_ids,
267267
'parent_path': self.parent_path,
268268
'extern_keys': self._extern_keys,
269-
'use_repetition_ids': True if 'repetition_ids' in changes else self.use_repetition_ids,
269+
'use_repetition_ids': (
270+
True if changes.get('repetition_ids') is not None else self.use_repetition_ids
271+
),
270272
'repeat_until': self.repeat_until,
271273
**changes,
272274
}

cirq-core/cirq/circuits/circuit_operation_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,28 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
332332
assert op_base.repeat(2.99999999999).repetitions == 3
333333

334334

335+
def test_replace_repetition_ids() -> None:
336+
a, b = cirq.LineQubit.range(2)
337+
circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b), cirq.M(b, key='mb'), cirq.M(a, key='ma'))
338+
op = cirq.CircuitOperation(circuit.freeze())
339+
assert op.repetitions == 1
340+
assert not op.use_repetition_ids
341+
342+
op2 = op.replace(repetitions=2)
343+
assert op2.repetitions == 2
344+
assert not op2.use_repetition_ids
345+
346+
op3 = op.replace(repetitions=3, repetition_ids=None)
347+
assert op3.repetitions == 3
348+
assert not op3.use_repetition_ids
349+
350+
# Passing `repetition_ids` will also enable `use_repetition_ids`
351+
op4 = op.replace(repetitions=4, repetition_ids=['a', 'b', 'c', 'd'])
352+
assert op4.repetitions == 4
353+
assert op4.use_repetition_ids
354+
assert op4.repetition_ids == ['a', 'b', 'c', 'd']
355+
356+
335357
@pytest.mark.parametrize('add_measurements', [True, False])
336358
@pytest.mark.parametrize('use_repetition_ids', [True, False])
337359
@pytest.mark.parametrize('initial_reps', [0, 1, 2, 3])

0 commit comments

Comments
 (0)