diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 18d3d035c5b..bf6239a56a1 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -60,6 +60,7 @@ def map_moments( circuit: CIRCUIT_TYPE, map_func: Callable[[circuits.Moment, int], Union[circuits.Moment, Sequence[circuits.Moment]]], *, + tags_to_ignore: Sequence[Hashable] = (), deep: bool = False, ) -> CIRCUIT_TYPE: """Applies local transformation on moments, by calling `map_func(moment)` for each moment. @@ -67,6 +68,9 @@ def map_moments( Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. map_func: Mapping function from (cirq.Moment, moment_index) to a sequence of moments. + tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be + ignored when recursively applying the transformer primitive to sub-circuits, given + deep=True. deep: If true, `map_func` will be recursively applied to circuits wrapped inside any circuit operations contained within `circuit`. @@ -79,6 +83,8 @@ def map_moments( for i, op in circuit.findall_operations( lambda o: isinstance(o.untagged, circuits.CircuitOperation) ): + if set(op.tags).intersection(tags_to_ignore): + continue op_untagged = cast(circuits.CircuitOperation, op.untagged) mapped_op = op_untagged.replace( circuit=map_moments(op_untagged.circuit, map_func, deep=deep) @@ -190,6 +196,7 @@ def merge_operations( merge_func: Callable[[ops.Operation, ops.Operation], Optional[ops.Operation]], *, tags_to_ignore: Sequence[Hashable] = (), + deep: bool = False, ) -> CIRCUIT_TYPE: """Merges operations in a circuit by calling `merge_func` iteratively on operations. @@ -226,6 +233,8 @@ def merge_operations( tags_to_ignore: Sequence of tags which should be ignored while applying `merge_func` on tagged operations -- i.e. `merge_func(op1, op2)` will be called only if both `op1` and `op2` satisfy `set(op.tags).isdisjoint(tags_to_ignore)`. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. Returns: @@ -235,9 +244,11 @@ def merge_operations( ValueError if the merged operation acts on new qubits outside the set of qubits corresponding to the original operations to be merged. """ + _circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit" + tags_to_ignore_set = set(tags_to_ignore) | {_circuit_op_tag} def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Operation]: - if not all(set(op.tags).isdisjoint(tags_to_ignore) for op in [op1, op2]): + if not all(tags_to_ignore_set.isdisjoint(op.tags) for op in [op1, op2]): return None new_op = merge_func(op1, op2) qubit_set = frozenset(op1.qubits + op2.qubits) @@ -252,6 +263,23 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope for current_moment in circuit: new_moment = circuits.Moment() for op in sorted(current_moment.operations, key=lambda op: op.qubits): + if ( + deep + and isinstance(op.untagged, circuits.CircuitOperation) + and tags_to_ignore_set.isdisjoint(op.tags) + ): + op_untagged = op.untagged + new_moment = new_moment.with_operation( + op_untagged.replace( + circuit=merge_operations( + op_untagged.circuit, + merge_func, + tags_to_ignore=tags_to_ignore, + deep=True, + ) + ).with_tags(*op.tags, _circuit_op_tag) + ) + continue op_qs = set(op.qubits) idx = ret_circuit.prev_moment_operating_on(tuple(op_qs)) if idx is not None and op_qs.issubset(ret_circuit[idx][op_qs].operations[0].qubits): @@ -279,6 +307,12 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope idx = ret_circuit.prev_moment_operating_on(tuple(op_qs)) new_moment = new_moment.with_operation(op) ret_circuit += new_moment + if deep: + ret_circuit = map_operations( + ret_circuit, + lambda o, _: o.untagged.with_tags(*(set(o.tags) - {_circuit_op_tag})), + deep=True, + ) return _to_target_circuit_type(ret_circuit, circuit) @@ -288,6 +322,7 @@ def merge_operations_to_circuit_op( *, tags_to_ignore: Sequence[Hashable] = (), merged_circuit_op_tag: str = "Merged connected component", + deep: bool = False, ) -> CIRCUIT_TYPE: """Merges connected components of operations and wraps each component into a circuit operation. @@ -307,6 +342,8 @@ def merge_operations_to_circuit_op( potential candidates for any connected component. merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected components. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. Returns: Copy of input circuit with valid connected components wrapped in tagged circuit operations. @@ -329,7 +366,7 @@ def get_ops(op: 'cirq.Operation'): merged_circuit_op_tag ) - return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore) + return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore, deep=deep) def merge_k_qubit_unitaries_to_circuit_op( @@ -338,6 +375,7 @@ def merge_k_qubit_unitaries_to_circuit_op( *, tags_to_ignore: Sequence[Hashable] = (), merged_circuit_op_tag: Optional[str] = None, + deep: bool = False, ) -> CIRCUIT_TYPE: """Merges connected components of operations, acting on <= k qubits, into circuit operations. @@ -353,6 +391,8 @@ def merge_k_qubit_unitaries_to_circuit_op( potential candidates for any connected component. merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected components. A default tag is applied if left None. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. Returns: Copy of input circuit with valid connected components wrapped in tagged circuit operations. @@ -370,12 +410,16 @@ def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation'] can_merge, tags_to_ignore=tags_to_ignore, merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.", + deep=deep, ) def merge_moments( circuit: CIRCUIT_TYPE, merge_func: Callable[[circuits.Moment, circuits.Moment], Optional[circuits.Moment]], + *, + tags_to_ignore: Sequence[Hashable] = (), + deep: bool = False, ) -> CIRCUIT_TYPE: """Merges adjacent moments, one by one from left to right, by calling `merge_func(m1, m2)`. @@ -384,12 +428,27 @@ def merge_moments( merge_func: Callable to determine whether two adjacent moments in the circuit should be merged. If the moments can be merged, the callable should return the merged moment, else None. + tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be + ignored when recursively applying the transformer primitive to sub-circuits, given + deep=True. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. Returns: Copy of input circuit with merged moments. """ if not circuit: return circuit + if deep: + circuit = map_operations( + circuit, + lambda op, _: op.untagged.replace( + circuit=merge_moments(op.untagged.circuit, merge_func, deep=deep) + ).with_tags(*op.tags) + if isinstance(op.untagged, circuits.CircuitOperation) + else op, + tags_to_ignore=tags_to_ignore, + ) merged_moments: List[circuits.Moment] = [circuit[0]] for current_moment in circuit[1:]: merged_moment = merge_func(merged_moments[-1], current_moment) diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index d00fba04c98..9ac2a87b8a1 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -399,6 +399,35 @@ def test_map_moments_drop_empty_moments(): cirq.testing.assert_same_circuits(c_mapped, cirq.Circuit(c[0], c[0])) +def test_map_moments_drop_empty_moments_deep(): + op = cirq.X(cirq.NamedQubit("q")) + c_nested = cirq.FrozenCircuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op)) + c_orig = cirq.Circuit( + c_nested, + cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"), + c_nested, + cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"), + ) + c_expected = cirq.Circuit( + [op, op], + cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"), + [op, op], + cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(5).with_tags("preserve_tag"), + ) + c_mapped = cirq.map_moments( + c_orig, lambda m, i: [] if len(m) == 0 else [m], deep=True, tags_to_ignore=("ignore",) + ) + cirq.testing.assert_same_circuits(c_mapped, c_expected) + + +def _merge_z_moments_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]: + if any(op.gate != cirq.Z for m in [m1, m2] for op in m): + return None + return cirq.Moment( + cirq.Z(q) for q in (m1.qubits | m2.qubits) if m1.operates_on([q]) ^ m2.operates_on([q]) + ) + + def test_merge_moments(): q = cirq.LineQubit.range(3) c_orig = cirq.Circuit( @@ -419,21 +448,8 @@ def test_merge_moments(): ''', ) - def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]: - def is_z_moment(m): - return all(op.gate == cirq.Z for op in m) - - if not (is_z_moment(m1) and is_z_moment(m2)): - return None - qubits = m1.qubits | m2.qubits - - def mul(op1, op2): - return (op1 or op2) if not (op1 and op2) else cirq.decompose_once(op1 * op2) - - return cirq.Moment(mul(m1.operation_at(q), m2.operation_at(q)) for q in qubits) - cirq.testing.assert_has_diagram( - cirq.merge_moments(c_orig, merge_func), + cirq.merge_moments(c_orig, _merge_z_moments_func), ''' 0: ───────@─────── │ @@ -444,6 +460,35 @@ def mul(op1, op2): ) +def test_merge_moments_deep(): + q = cirq.LineQubit.range(3) + c_z_moments = cirq.Circuit( + [cirq.Z.on_each(q[0], q[1]), cirq.Z.on_each(q[1], q[2]), cirq.Z.on_each(q[1], q[0])], + strategy=cirq.InsertStrategy.NEW_THEN_INLINE, + ) + merged_z_moment = cirq.Moment(cirq.Z.on_each(*q[1:])) + c_nested_circuit = cirq.FrozenCircuit(c_z_moments, cirq.CCX(*q), c_z_moments) + c_merged_circuit = cirq.FrozenCircuit(merged_z_moment, cirq.CCX(*q), merged_z_moment) + c_orig = cirq.Circuit( + cirq.CircuitOperation(c_nested_circuit).repeat(5).with_tags("ignore"), + c_nested_circuit, + cirq.CircuitOperation(c_nested_circuit).repeat(6).with_tags("preserve_tag"), + c_nested_circuit, + cirq.CircuitOperation(c_nested_circuit).repeat(7), + ) + c_expected = cirq.Circuit( + cirq.CircuitOperation(c_nested_circuit).repeat(5).with_tags("ignore"), + c_merged_circuit, + cirq.CircuitOperation(c_merged_circuit).repeat(6).with_tags("preserve_tag"), + c_merged_circuit, + cirq.CircuitOperation(c_merged_circuit).repeat(7), + ) + cirq.testing.assert_same_circuits( + cirq.merge_moments(c_orig, _merge_z_moments_func, tags_to_ignore=("ignore",), deep=True), + c_expected, + ) + + def test_merge_moments_empty_moment_as_intermediate_step(): q = cirq.NamedQubit("q") c_orig = cirq.Circuit([cirq.X(q), cirq.Y(q), cirq.Z(q)] * 2, cirq.X(q) ** 0.5) @@ -543,7 +588,45 @@ def merge_func(op1, op2): ) +def test_merge_operations_deep(): + q = cirq.LineQubit.range(2) + h_cz_y = [cirq.H(q[0]), cirq.CZ(*q), cirq.Y(q[1])] + m_cz_m = [cirq.Moment(), cirq.Moment(cirq.CZ(*q)), cirq.Moment()] + c_orig = cirq.Circuit( + h_cz_y, + cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])), + cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"), + [cirq.CNOT(*q), cirq.CNOT(*q)], + cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(4), + [cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)], + cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(5).with_tags("preserve_tag"), + ) + c_expected = cirq.Circuit( + m_cz_m, + cirq.Moment(cirq.X(q[0]).with_tags("ignore")), + cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"), + [cirq.CNOT(*q), cirq.CNOT(*q)], + cirq.CircuitOperation(cirq.FrozenCircuit(m_cz_m)).repeat(4), + [cirq.CZ(*q), cirq.Moment(), cirq.Moment()], + cirq.CircuitOperation(cirq.FrozenCircuit(m_cz_m)).repeat(5).with_tags("preserve_tag"), + strategy=cirq.InsertStrategy.NEW, + ) + + def merge_func(op1, op2): + """Artificial example where a CZ will absorb any merge-able operation.""" + for op in [op1, op2]: + if op.gate == cirq.CZ: + return op + return None + + cirq.testing.assert_same_circuits( + cirq.merge_operations(c_orig, merge_func, tags_to_ignore=["ignore"], deep=True), c_expected + ) + + # pylint: disable=line-too-long + + def test_merge_operations_to_circuit_op_merges_connected_component(): c_orig = _create_circuit_to_merge() cirq.testing.assert_has_diagram(