@@ -294,15 +294,15 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
294
294
op_with_reps : Optional [cirq .CircuitOperation ] = None
295
295
rep_ids = []
296
296
if use_default_ids_for_initial_rep :
297
- op_with_reps = op_base .repeat (initial_repetitions )
298
297
rep_ids = ['0' , '1' , '2' ]
299
- assert op_base ** initial_repetitions == op_with_reps
298
+ op_with_reps = op_base . repeat ( initial_repetitions , use_repetition_ids = True )
300
299
else :
301
300
rep_ids = ['a' , 'b' , 'c' ]
302
301
op_with_reps = op_base .repeat (initial_repetitions , rep_ids )
303
- assert op_base ** initial_repetitions != op_with_reps
304
- assert (op_base ** initial_repetitions ).replace (repetition_ids = rep_ids ) == op_with_reps
302
+ assert op_base ** initial_repetitions != op_with_reps
303
+ assert (op_base ** initial_repetitions ).replace (repetition_ids = rep_ids ) == op_with_reps
305
304
assert op_with_reps .repetitions == initial_repetitions
305
+ assert op_with_reps .use_repetition_ids
306
306
assert op_with_reps .repetition_ids == rep_ids
307
307
assert op_with_reps .repeat (1 ) is op_with_reps
308
308
@@ -436,6 +436,7 @@ def test_parameterized_repeat_side_effects():
436
436
op = cirq .CircuitOperation (
437
437
cirq .FrozenCircuit (cirq .X (q ).with_classical_controls ('c' ), cirq .measure (q , key = 'm' )),
438
438
repetitions = sympy .Symbol ('a' ),
439
+ use_repetition_ids = True ,
439
440
)
440
441
441
442
# Control keys can be calculated because they only "lift" if there's a matching
@@ -689,7 +690,6 @@ def test_string_format():
689
690
),
690
691
),
691
692
]),
692
- use_repetition_ids=False,
693
693
)"""
694
694
)
695
695
op7 = cirq .CircuitOperation (
@@ -706,7 +706,6 @@ def test_string_format():
706
706
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
707
707
),
708
708
]),
709
- use_repetition_ids=False,
710
709
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
711
710
)"""
712
711
)
@@ -737,6 +736,7 @@ def test_json_dict():
737
736
'param_resolver' : op .param_resolver ,
738
737
'parent_path' : op .parent_path ,
739
738
'repetition_ids' : None ,
739
+ 'use_repetition_ids' : False ,
740
740
}
741
741
742
742
@@ -843,6 +843,26 @@ def test_decompose_loops_with_measurements():
843
843
circuit = cirq .FrozenCircuit (cirq .H (a ), cirq .CX (a , b ), cirq .measure (a , b , key = 'm' ))
844
844
base_op = cirq .CircuitOperation (circuit )
845
845
846
+ op = base_op .with_qubits (b , a ).repeat (3 )
847
+ expected_circuit = cirq .Circuit (
848
+ cirq .H (b ),
849
+ cirq .CX (b , a ),
850
+ cirq .measure (b , a , key = cirq .MeasurementKey .parse_serialized ('m' )),
851
+ cirq .H (b ),
852
+ cirq .CX (b , a ),
853
+ cirq .measure (b , a , key = cirq .MeasurementKey .parse_serialized ('m' )),
854
+ cirq .H (b ),
855
+ cirq .CX (b , a ),
856
+ cirq .measure (b , a , key = cirq .MeasurementKey .parse_serialized ('m' )),
857
+ )
858
+ assert cirq .Circuit (cirq .decompose_once (op )) == expected_circuit
859
+
860
+
861
+ def test_decompose_loops_with_measurements_use_rep_ids ():
862
+ a , b = cirq .LineQubit .range (2 )
863
+ circuit = cirq .FrozenCircuit (cirq .H (a ), cirq .CX (a , b ), cirq .measure (a , b , key = 'm' ))
864
+ base_op = cirq .CircuitOperation (circuit , use_repetition_ids = True )
865
+
846
866
op = base_op .with_qubits (b , a ).repeat (3 )
847
867
expected_circuit = cirq .Circuit (
848
868
cirq .H (b ),
@@ -999,7 +1019,9 @@ def test_keys_under_parent_path():
999
1019
op3 = cirq .with_key_path_prefix (op2 , ('C' ,))
1000
1020
assert cirq .measurement_key_names (op3 ) == {'C:B:A' }
1001
1021
op4 = op3 .repeat (2 )
1002
- assert cirq .measurement_key_names (op4 ) == {'C:B:0:A' , 'C:B:1:A' }
1022
+ assert cirq .measurement_key_names (op4 ) == {'C:B:A' }
1023
+ op4_rep = op3 .repeat (2 ).replace (use_repetition_ids = True )
1024
+ assert cirq .measurement_key_names (op4_rep ) == {'C:B:0:A' , 'C:B:1:A' }
1003
1025
1004
1026
1005
1027
def test_mapped_circuit_preserves_moments ():
@@ -1077,12 +1099,8 @@ def test_mapped_circuit_allows_repeated_keys():
1077
1099
def test_simulate_no_repetition_ids_both_levels (sim ):
1078
1100
q = cirq .LineQubit (0 )
1079
1101
inner = cirq .Circuit (cirq .measure (q , key = 'a' ))
1080
- middle = cirq .Circuit (
1081
- cirq .CircuitOperation (inner .freeze (), repetitions = 2 , use_repetition_ids = False )
1082
- )
1083
- outer_subcircuit = cirq .CircuitOperation (
1084
- middle .freeze (), repetitions = 2 , use_repetition_ids = False
1085
- )
1102
+ middle = cirq .Circuit (cirq .CircuitOperation (inner .freeze (), repetitions = 2 ))
1103
+ outer_subcircuit = cirq .CircuitOperation (middle .freeze (), repetitions = 2 )
1086
1104
circuit = cirq .Circuit (outer_subcircuit )
1087
1105
result = sim .run (circuit )
1088
1106
assert result .records ['a' ].shape == (1 , 4 , 1 )
@@ -1092,10 +1110,10 @@ def test_simulate_no_repetition_ids_both_levels(sim):
1092
1110
def test_simulate_no_repetition_ids_outer (sim ):
1093
1111
q = cirq .LineQubit (0 )
1094
1112
inner = cirq .Circuit (cirq .measure (q , key = 'a' ))
1095
- middle = cirq .Circuit (cirq .CircuitOperation (inner .freeze (), repetitions = 2 ))
1096
- outer_subcircuit = cirq .CircuitOperation (
1097
- middle .freeze (), repetitions = 2 , use_repetition_ids = False
1113
+ middle = cirq .Circuit (
1114
+ cirq .CircuitOperation (inner .freeze (), repetitions = 2 , use_repetition_ids = True )
1098
1115
)
1116
+ outer_subcircuit = cirq .CircuitOperation (middle .freeze (), repetitions = 2 )
1099
1117
circuit = cirq .Circuit (outer_subcircuit )
1100
1118
result = sim .run (circuit )
1101
1119
assert result .records ['0:a' ].shape == (1 , 2 , 1 )
@@ -1106,10 +1124,10 @@ def test_simulate_no_repetition_ids_outer(sim):
1106
1124
def test_simulate_no_repetition_ids_inner (sim ):
1107
1125
q = cirq .LineQubit (0 )
1108
1126
inner = cirq .Circuit (cirq .measure (q , key = 'a' ))
1109
- middle = cirq .Circuit (
1110
- cirq .CircuitOperation (inner .freeze (), repetitions = 2 , use_repetition_ids = False )
1127
+ middle = cirq .Circuit (cirq .CircuitOperation (inner .freeze (), repetitions = 2 ))
1128
+ outer_subcircuit = cirq .CircuitOperation (
1129
+ middle .freeze (), repetitions = 2 , use_repetition_ids = True
1111
1130
)
1112
- outer_subcircuit = cirq .CircuitOperation (middle .freeze (), repetitions = 2 )
1113
1131
circuit = cirq .Circuit (outer_subcircuit )
1114
1132
result = sim .run (circuit )
1115
1133
assert result .records ['0:a' ].shape == (1 , 2 , 1 )
@@ -1124,7 +1142,6 @@ def test_repeat_until(sim):
1124
1142
cirq .X (q ),
1125
1143
cirq .CircuitOperation (
1126
1144
cirq .FrozenCircuit (cirq .X (q ), cirq .measure (q , key = key )),
1127
- use_repetition_ids = False ,
1128
1145
repeat_until = cirq .KeyCondition (key ),
1129
1146
),
1130
1147
)
@@ -1139,7 +1156,6 @@ def test_repeat_until_sympy(sim):
1139
1156
q1 , q2 = cirq .LineQubit .range (2 )
1140
1157
circuitop = cirq .CircuitOperation (
1141
1158
cirq .FrozenCircuit (cirq .X (q2 ), cirq .measure (q2 , key = 'b' )),
1142
- use_repetition_ids = False ,
1143
1159
repeat_until = cirq .SympyCondition (sympy .Eq (sympy .Symbol ('a' ), sympy .Symbol ('b' ))),
1144
1160
)
1145
1161
c = cirq .Circuit (cirq .measure (q1 , key = 'a' ), circuitop )
@@ -1159,7 +1175,6 @@ def test_post_selection(sim):
1159
1175
c = cirq .Circuit (
1160
1176
cirq .CircuitOperation (
1161
1177
cirq .FrozenCircuit (cirq .X (q ) ** 0.2 , cirq .measure (q , key = key )),
1162
- use_repetition_ids = False ,
1163
1178
repeat_until = cirq .KeyCondition (key ),
1164
1179
)
1165
1180
)
@@ -1175,14 +1190,13 @@ def test_repeat_until_diagram():
1175
1190
c = cirq .Circuit (
1176
1191
cirq .CircuitOperation (
1177
1192
cirq .FrozenCircuit (cirq .X (q ) ** 0.2 , cirq .measure (q , key = key )),
1178
- use_repetition_ids = False ,
1179
1193
repeat_until = cirq .KeyCondition (key ),
1180
1194
)
1181
1195
)
1182
1196
cirq .testing .assert_has_diagram (
1183
1197
c ,
1184
1198
"""
1185
- 0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
1199
+ 0: ───[ 0: ───X^0.2───M('m')─── ](until=m)───
1186
1200
""" ,
1187
1201
use_unicode_characters = True ,
1188
1202
)
@@ -1199,7 +1213,6 @@ def test_repeat_until_error():
1199
1213
with pytest .raises (ValueError , match = 'Infinite loop' ):
1200
1214
cirq .CircuitOperation (
1201
1215
cirq .FrozenCircuit (cirq .measure (q , key = 'm' )),
1202
- use_repetition_ids = False ,
1203
1216
repeat_until = cirq .KeyCondition (cirq .MeasurementKey ('a' )),
1204
1217
)
1205
1218
0 commit comments