diff --git a/cirq-core/cirq/transformers/synchronize_terminal_measurements.py b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py index f96f946009f..b88b5283c5f 100644 --- a/cirq-core/cirq/transformers/synchronize_terminal_measurements.py +++ b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py @@ -58,7 +58,7 @@ def find_terminal_measurements( return terminal_measurements -@transformer_api.transformer +@transformer_api.transformer(add_deep_support=True) def synchronize_terminal_measurements( circuit: 'cirq.AbstractCircuit', *, diff --git a/cirq-core/cirq/transformers/synchronize_terminal_measurements_test.py b/cirq-core/cirq/transformers/synchronize_terminal_measurements_test.py index 41097af07ea..76a7795c512 100644 --- a/cirq-core/cirq/transformers/synchronize_terminal_measurements_test.py +++ b/cirq-core/cirq/transformers/synchronize_terminal_measurements_test.py @@ -18,17 +18,35 @@ def assert_optimizes(before, after, measure_only_moment=True, with_context=False): - transformed_circuit = ( - cirq.synchronize_terminal_measurements(before, after_other_operations=measure_only_moment) - if not with_context - else cirq.synchronize_terminal_measurements( - before, - context=cirq.TransformerContext(tags_to_ignore=(NO_COMPILE_TAG,)), - after_other_operations=measure_only_moment, - ) + context = cirq.TransformerContext(tags_to_ignore=(NO_COMPILE_TAG,)) if with_context else None + transformed_circuit = cirq.synchronize_terminal_measurements( + before, context=context, after_other_operations=measure_only_moment ) cirq.testing.assert_same_circuits(transformed_circuit, after) + # Test nested circuit ops. + context = cirq.TransformerContext( + tags_to_ignore=("ignore",) + tuple([NO_COMPILE_TAG] if with_context else []), deep=True + ) + c_nested = cirq.Circuit( + before, + cirq.CircuitOperation(before.freeze()).repeat(5).with_tags("ignore"), + before, + cirq.CircuitOperation(before.freeze()).repeat(6).with_tags("preserve_tag"), + before, + ) + c_expected = cirq.Circuit( + before, + cirq.CircuitOperation(before.freeze()).repeat(5).with_tags("ignore"), + before, + cirq.CircuitOperation(after.freeze()).repeat(6).with_tags("preserve_tag"), + after, + ) + transformed_circuit = cirq.synchronize_terminal_measurements( + c_nested, context=context, after_other_operations=measure_only_moment + ) + cirq.testing.assert_same_circuits(transformed_circuit, c_expected) + def test_no_move(): q1 = cirq.NamedQubit('q1')