Skip to content

Add support for deep=True flag to remaining transformer primitives #5106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,17 @@ def map_moments(
circuit: CIRCUIT_TYPE,
map_func: Callable[[circuits.Moment, int], Union[circuits.Moment, Sequence[circuits.Moment]]],
*,
tags_to_ignore: Sequence[Hashable] = (),
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Applies local transformation on moments, by calling `map_func(moment)` for each moment.

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
map_func: Mapping function from (cirq.Moment, moment_index) to a sequence of moments.
tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
ignored when recursively applying the transformer primitive to sub-circuits, given
deep=True.
deep: If true, `map_func` will be recursively applied to circuits wrapped inside
any circuit operations contained within `circuit`.

Expand All @@ -79,6 +83,8 @@ def map_moments(
for i, op in circuit.findall_operations(
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
):
if set(op.tags).intersection(tags_to_ignore):
continue
op_untagged = cast(circuits.CircuitOperation, op.untagged)
mapped_op = op_untagged.replace(
circuit=map_moments(op_untagged.circuit, map_func, deep=deep)
Expand Down Expand Up @@ -190,6 +196,7 @@ def merge_operations(
merge_func: Callable[[ops.Operation, ops.Operation], Optional[ops.Operation]],
*,
tags_to_ignore: Sequence[Hashable] = (),
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Merges operations in a circuit by calling `merge_func` iteratively on operations.

Expand Down Expand Up @@ -226,6 +233,8 @@ def merge_operations(
tags_to_ignore: Sequence of tags which should be ignored while applying `merge_func` on
tagged operations -- i.e. `merge_func(op1, op2)` will be called only if both `op1` and
`op2` satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.


Returns:
Expand All @@ -235,9 +244,11 @@ def merge_operations(
ValueError if the merged operation acts on new qubits outside the set of qubits
corresponding to the original operations to be merged.
"""
_circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit"
tags_to_ignore_set = set(tags_to_ignore) | {_circuit_op_tag}

def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Operation]:
if not all(set(op.tags).isdisjoint(tags_to_ignore) for op in [op1, op2]):
if not all(tags_to_ignore_set.isdisjoint(op.tags) for op in [op1, op2]):
return None
new_op = merge_func(op1, op2)
qubit_set = frozenset(op1.qubits + op2.qubits)
Expand All @@ -252,6 +263,23 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
for current_moment in circuit:
new_moment = circuits.Moment()
for op in sorted(current_moment.operations, key=lambda op: op.qubits):
if (
deep
and isinstance(op.untagged, circuits.CircuitOperation)
and tags_to_ignore_set.isdisjoint(op.tags)
):
op_untagged = op.untagged
new_moment = new_moment.with_operation(
op_untagged.replace(
circuit=merge_operations(
op_untagged.circuit,
merge_func,
tags_to_ignore=tags_to_ignore,
deep=True,
)
).with_tags(*op.tags, _circuit_op_tag)
)
continue
op_qs = set(op.qubits)
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
if idx is not None and op_qs.issubset(ret_circuit[idx][op_qs].operations[0].qubits):
Expand Down Expand Up @@ -279,6 +307,12 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
new_moment = new_moment.with_operation(op)
ret_circuit += new_moment
if deep:
ret_circuit = map_operations(
ret_circuit,
lambda o, _: o.untagged.with_tags(*(set(o.tags) - {_circuit_op_tag})),
deep=True,
)
return _to_target_circuit_type(ret_circuit, circuit)


Expand All @@ -288,6 +322,7 @@ def merge_operations_to_circuit_op(
*,
tags_to_ignore: Sequence[Hashable] = (),
merged_circuit_op_tag: str = "Merged connected component",
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Merges connected components of operations and wraps each component into a circuit operation.

Expand All @@ -307,6 +342,8 @@ def merge_operations_to_circuit_op(
potential candidates for any connected component.
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
components.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.

Returns:
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
Expand All @@ -329,7 +366,7 @@ def get_ops(op: 'cirq.Operation'):
merged_circuit_op_tag
)

return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore)
return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore, deep=deep)


def merge_k_qubit_unitaries_to_circuit_op(
Expand All @@ -338,6 +375,7 @@ def merge_k_qubit_unitaries_to_circuit_op(
*,
tags_to_ignore: Sequence[Hashable] = (),
merged_circuit_op_tag: Optional[str] = None,
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Merges connected components of operations, acting on <= k qubits, into circuit operations.

Expand All @@ -353,6 +391,8 @@ def merge_k_qubit_unitaries_to_circuit_op(
potential candidates for any connected component.
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
components. A default tag is applied if left None.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.

Returns:
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
Expand All @@ -370,12 +410,16 @@ def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation']
can_merge,
tags_to_ignore=tags_to_ignore,
merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.",
deep=deep,
)


def merge_moments(
circuit: CIRCUIT_TYPE,
merge_func: Callable[[circuits.Moment, circuits.Moment], Optional[circuits.Moment]],
*,
tags_to_ignore: Sequence[Hashable] = (),
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Merges adjacent moments, one by one from left to right, by calling `merge_func(m1, m2)`.

Expand All @@ -384,12 +428,27 @@ def merge_moments(
merge_func: Callable to determine whether two adjacent moments in the circuit should be
merged. If the moments can be merged, the callable should return the merged moment,
else None.
tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
ignored when recursively applying the transformer primitive to sub-circuits, given
deep=True.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.

Returns:
Copy of input circuit with merged moments.
"""
if not circuit:
return circuit
if deep:
circuit = map_operations(
circuit,
lambda op, _: op.untagged.replace(
circuit=merge_moments(op.untagged.circuit, merge_func, deep=deep)
).with_tags(*op.tags)
if isinstance(op.untagged, circuits.CircuitOperation)
else op,
tags_to_ignore=tags_to_ignore,
)
merged_moments: List[circuits.Moment] = [circuit[0]]
for current_moment in circuit[1:]:
merged_moment = merge_func(merged_moments[-1], current_moment)
Expand Down
111 changes: 97 additions & 14 deletions cirq-core/cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,35 @@ def test_map_moments_drop_empty_moments():
cirq.testing.assert_same_circuits(c_mapped, cirq.Circuit(c[0], c[0]))


def test_map_moments_drop_empty_moments_deep():
op = cirq.X(cirq.NamedQubit("q"))
c_nested = cirq.FrozenCircuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op))
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
)
c_expected = cirq.Circuit(
[op, op],
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
[op, op],
cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(5).with_tags("preserve_tag"),
)
c_mapped = cirq.map_moments(
c_orig, lambda m, i: [] if len(m) == 0 else [m], deep=True, tags_to_ignore=("ignore",)
)
cirq.testing.assert_same_circuits(c_mapped, c_expected)


def _merge_z_moments_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]:
if any(op.gate != cirq.Z for m in [m1, m2] for op in m):
return None
return cirq.Moment(
cirq.Z(q) for q in (m1.qubits | m2.qubits) if m1.operates_on([q]) ^ m2.operates_on([q])
)


def test_merge_moments():
q = cirq.LineQubit.range(3)
c_orig = cirq.Circuit(
Expand All @@ -419,21 +448,8 @@ def test_merge_moments():
''',
)

def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]:
def is_z_moment(m):
return all(op.gate == cirq.Z for op in m)

if not (is_z_moment(m1) and is_z_moment(m2)):
return None
qubits = m1.qubits | m2.qubits

def mul(op1, op2):
return (op1 or op2) if not (op1 and op2) else cirq.decompose_once(op1 * op2)

return cirq.Moment(mul(m1.operation_at(q), m2.operation_at(q)) for q in qubits)

cirq.testing.assert_has_diagram(
cirq.merge_moments(c_orig, merge_func),
cirq.merge_moments(c_orig, _merge_z_moments_func),
'''
0: ───────@───────
Expand All @@ -444,6 +460,35 @@ def mul(op1, op2):
)


def test_merge_moments_deep():
q = cirq.LineQubit.range(3)
c_z_moments = cirq.Circuit(
[cirq.Z.on_each(q[0], q[1]), cirq.Z.on_each(q[1], q[2]), cirq.Z.on_each(q[1], q[0])],
strategy=cirq.InsertStrategy.NEW_THEN_INLINE,
)
merged_z_moment = cirq.Moment(cirq.Z.on_each(*q[1:]))
c_nested_circuit = cirq.FrozenCircuit(c_z_moments, cirq.CCX(*q), c_z_moments)
c_merged_circuit = cirq.FrozenCircuit(merged_z_moment, cirq.CCX(*q), merged_z_moment)
c_orig = cirq.Circuit(
cirq.CircuitOperation(c_nested_circuit).repeat(5).with_tags("ignore"),
c_nested_circuit,
cirq.CircuitOperation(c_nested_circuit).repeat(6).with_tags("preserve_tag"),
c_nested_circuit,
cirq.CircuitOperation(c_nested_circuit).repeat(7),
)
c_expected = cirq.Circuit(
cirq.CircuitOperation(c_nested_circuit).repeat(5).with_tags("ignore"),
c_merged_circuit,
cirq.CircuitOperation(c_merged_circuit).repeat(6).with_tags("preserve_tag"),
c_merged_circuit,
cirq.CircuitOperation(c_merged_circuit).repeat(7),
)
cirq.testing.assert_same_circuits(
cirq.merge_moments(c_orig, _merge_z_moments_func, tags_to_ignore=("ignore",), deep=True),
c_expected,
)


def test_merge_moments_empty_moment_as_intermediate_step():
q = cirq.NamedQubit("q")
c_orig = cirq.Circuit([cirq.X(q), cirq.Y(q), cirq.Z(q)] * 2, cirq.X(q) ** 0.5)
Expand Down Expand Up @@ -543,7 +588,45 @@ def merge_func(op1, op2):
)


def test_merge_operations_deep():
q = cirq.LineQubit.range(2)
h_cz_y = [cirq.H(q[0]), cirq.CZ(*q), cirq.Y(q[1])]
m_cz_m = [cirq.Moment(), cirq.Moment(cirq.CZ(*q)), cirq.Moment()]
c_orig = cirq.Circuit(
h_cz_y,
cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])),
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
[cirq.CNOT(*q), cirq.CNOT(*q)],
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(4),
[cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)],
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(5).with_tags("preserve_tag"),
)
c_expected = cirq.Circuit(
m_cz_m,
cirq.Moment(cirq.X(q[0]).with_tags("ignore")),
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
[cirq.CNOT(*q), cirq.CNOT(*q)],
cirq.CircuitOperation(cirq.FrozenCircuit(m_cz_m)).repeat(4),
[cirq.CZ(*q), cirq.Moment(), cirq.Moment()],
cirq.CircuitOperation(cirq.FrozenCircuit(m_cz_m)).repeat(5).with_tags("preserve_tag"),
strategy=cirq.InsertStrategy.NEW,
)

def merge_func(op1, op2):
"""Artificial example where a CZ will absorb any merge-able operation."""
for op in [op1, op2]:
if op.gate == cirq.CZ:
return op
return None

cirq.testing.assert_same_circuits(
cirq.merge_operations(c_orig, merge_func, tags_to_ignore=["ignore"], deep=True), c_expected
)


# pylint: disable=line-too-long


def test_merge_operations_to_circuit_op_merges_connected_component():
c_orig = _create_circuit_to_merge()
cirq.testing.assert_has_diagram(
Expand Down