diff --git a/cirq-core/cirq/transformers/dynamical_decoupling.py b/cirq-core/cirq/transformers/dynamical_decoupling.py index 0c3ffb912d5..31e7704f257 100644 --- a/cirq-core/cirq/transformers/dynamical_decoupling.py +++ b/cirq-core/cirq/transformers/dynamical_decoupling.py @@ -15,41 +15,47 @@ """Transformer pass that adds dynamical decoupling operations to a circuit.""" from functools import reduce -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union, TYPE_CHECKING from itertools import cycle +from cirq import ops, circuits, protocols from cirq.transformers import transformer_api from cirq.transformers.analytical_decompositions import single_qubit_decompositions -from cirq.transformers.analytical_decompositions import unitary_to_pauli_string -import cirq +from cirq.protocols import unitary_protocol +from cirq.protocols.has_unitary_protocol import has_unitary +from cirq.protocols.has_stabilizer_effect_protocol import has_stabilizer_effect + import numpy as np +if TYPE_CHECKING: + import cirq + -def _get_dd_sequence_from_schema_name(schema: str) -> Tuple['cirq.Gate', ...]: +def _get_dd_sequence_from_schema_name(schema: str) -> Tuple[ops.Gate, ...]: """Gets dynamical decoupling sequence from a schema name.""" match schema: case 'DEFAULT': - return (cirq.X, cirq.Y, cirq.X, cirq.Y) + return (ops.X, ops.Y, ops.X, ops.Y) case 'XX_PAIR': - return (cirq.X, cirq.X) + return (ops.X, ops.X) case 'X_XINV': - return (cirq.X, cirq.X**-1) + return (ops.X, ops.X**-1) case 'YY_PAIR': - return (cirq.Y, cirq.Y) + return (ops.Y, ops.Y) case 'Y_YINV': - return (cirq.Y, cirq.Y**-1) + return (ops.Y, ops.Y**-1) case _: raise ValueError('Invalid schema name.') -def _pauli_up_to_global_phase(gate: 'cirq.Gate') -> Union['cirq.Pauli', None]: - for pauli_gate in [cirq.X, cirq.Y, cirq.Z]: - if cirq.equal_up_to_global_phase(gate, pauli_gate): +def _pauli_up_to_global_phase(gate: ops.Gate) -> Union[ops.Pauli, None]: + for pauli_gate in [ops.X, ops.Y, ops.Z]: + if protocols.equal_up_to_global_phase(gate, pauli_gate): return pauli_gate return None -def _validate_dd_sequence(dd_sequence: Tuple['cirq.Gate', ...]) -> None: +def _validate_dd_sequence(dd_sequence: Tuple[ops.Gate, ...]) -> None: """Validates a given dynamical decoupling sequence. Args: @@ -66,144 +72,132 @@ def _validate_dd_sequence(dd_sequence: Tuple['cirq.Gate', ...]) -> None: 'Dynamical decoupling sequence should only contain gates that are essentially' ' Pauli gates.' ) - matrices = [cirq.unitary(gate) for gate in dd_sequence] + matrices = [unitary_protocol.unitary(gate) for gate in dd_sequence] product = reduce(np.matmul, matrices) - if not cirq.equal_up_to_global_phase(product, np.eye(2)): + if not protocols.equal_up_to_global_phase(product, np.eye(2)): raise ValueError( 'Invalid dynamical decoupling sequence. Expect sequence production equals' f' identity up to a global phase, got {product}.'.replace('\n', ' ') ) -def _parse_dd_sequence(schema: Union[str, Tuple['cirq.Gate', ...]]) -> Tuple['cirq.Gate', ...]: - """Parses and returns dynamical decoupling sequence from schema.""" +def _parse_dd_sequence( + schema: Union[str, Tuple[ops.Gate, ...]] +) -> Tuple[Tuple[ops.Gate, ...], Dict[ops.Gate, ops.Pauli]]: + """Parses and returns dynamical decoupling sequence and its associated pauli map from schema.""" + dd_sequence = None if isinstance(schema, str): - return _get_dd_sequence_from_schema_name(schema) + dd_sequence = _get_dd_sequence_from_schema_name(schema) else: _validate_dd_sequence(schema) - return schema + dd_sequence = schema + # Map Gate to Puali gate. This is necessary as dd sequence might contain gates like X^-1. + pauli_map: Dict[ops.Gate, ops.Pauli] = {} + for gate in dd_sequence: + pauli_gate = _pauli_up_to_global_phase(gate) + if pauli_gate is not None: + pauli_map[gate] = pauli_gate + for gate in [ops.X, ops.Y, ops.Z]: + pauli_map[gate] = gate -def _is_single_qubit_operation(operation: 'cirq.Operation') -> bool: - if len(operation.qubits) != 1: - return False - return True + return (dd_sequence, pauli_map) + + +def _is_single_qubit_operation(operation: ops.Operation) -> bool: + return len(operation.qubits) == 1 -def _is_single_qubit_gate_moment(moment: 'cirq.Moment') -> bool: - for operation in moment: - if not _is_single_qubit_operation(operation): - return False - return True +def _is_single_qubit_gate_moment(moment: circuits.Moment) -> bool: + return all(_is_single_qubit_operation(op) for op in moment) -def _is_clifford_moment(moment: 'cirq.Moment') -> bool: - for op in moment.operations: - if op.gate is not None and isinstance(op.gate, cirq.MeasurementGate): - return False - if not cirq.has_stabilizer_effect(op): - return False - return True +def _is_clifford_op(op: ops.Operation) -> bool: + return has_unitary(op) and has_stabilizer_effect(op) -def _get_clifford_pieces(circuit: 'cirq.AbstractCircuit') -> list[Tuple[int, int]]: - clifford_pieces: list[Tuple[int, int]] = [] - left = 0 +def _calc_busy_moment_range_of_each_qubit( + circuit: circuits.FrozenCircuit, +) -> Dict[ops.Qid, list[int]]: + busy_moment_range_by_qubit: Dict[ops.Qid, list[int]] = { + q: [len(circuit), -1] for q in circuit.all_qubits() + } for moment_id, moment in enumerate(circuit): - if not _is_clifford_moment(moment): - clifford_pieces.append((left, moment_id)) - left = moment_id + 1 - if left < len(circuit): - clifford_pieces.append((left, len(circuit))) - return clifford_pieces + for q in moment.qubits: + busy_moment_range_by_qubit[q][0] = min(busy_moment_range_by_qubit[q][0], moment_id) + busy_moment_range_by_qubit[q][1] = max(busy_moment_range_by_qubit[q][1], moment_id) + return busy_moment_range_by_qubit -def _is_insertable_moment(moment: 'cirq.Moment', single_qubit_gate_moments_only: bool) -> bool: - return _is_single_qubit_gate_moment(moment) or not single_qubit_gate_moments_only +def _is_insertable_moment(moment: circuits.Moment, single_qubit_gate_moments_only: bool) -> bool: + return not single_qubit_gate_moments_only or _is_single_qubit_gate_moment(moment) -def _calc_pulled_through( - moment: 'cirq.Moment', input_pauli_ops: 'cirq.PauliString' -) -> 'cirq.PauliString': - """Calculates the pulled_through after pulling through moment with the input. +def _merge_single_qubit_ops_to_phxz( + q: ops.Qid, operations: Tuple[ops.Operation, ...] +) -> ops.Operation: + """Merges [op1, op2, ...] and returns an equivalent op""" + if len(operations) == 1: + return operations[0] + matrices = [unitary_protocol.unitary(op) for op in reversed(operations)] + product = reduce(np.matmul, matrices) + gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(product) or ops.I + return gate.on(q) + + +def _try_merge_single_qubit_ops_of_two_moments( + m1: circuits.Moment, m2: circuits.Moment +) -> Tuple[circuits.Moment, ...]: + """Merge single qubit ops of 2 moments if possible, returns 2 moments otherwise.""" + for q in m1.qubits & m2.qubits: + op1 = m1.operation_at(q) + op2 = m2.operation_at(q) + if any( + not (_is_single_qubit_operation(op) and has_unitary(op)) + for op in [op1, op2] + if op is not None + ): + return (m1, m2) + merged_ops: set[ops.Operation] = set() + # Merge all operators on q to a single op. + for q in m1.qubits | m2.qubits: + # ops_on_q may contain 1 op or 2 ops. + ops_on_q = [op for op in [m.operation_at(q) for m in [m1, m2]] if op is not None] + merged_ops.add(_merge_single_qubit_ops_to_phxz(q, tuple(ops_on_q))) + return (circuits.Moment(merged_ops),) - We assume that the moment is Clifford here. Then, pulling through is essentially - decomposing a matrix into Pauli operations on each qubit. - """ - pulled_through: 'cirq.PauliString' = cirq.PauliString() - for affected_q, combined_op_in_pauli in input_pauli_ops.items(): - op_at_moment = moment.operation_at(affected_q) - if op_at_moment is None: - pulled_through *= combined_op_in_pauli.on(affected_q) - continue - prev_circuit = cirq.Circuit(cirq.Moment(op_at_moment)) - new_circuit = cirq.Circuit( - cirq.Moment(combined_op_in_pauli.on(affected_q)), cirq.Moment(op_at_moment) - ) - qubit_order = op_at_moment.qubits - pulled_through_pauli_ops = unitary_to_pauli_string( - prev_circuit.unitary(qubit_order=qubit_order) - @ new_circuit.unitary(qubit_order=qubit_order).conj().T - ) - if pulled_through_pauli_ops is not None: - for qid, gate in enumerate(pulled_through_pauli_ops): - pulled_through *= gate.on(qubit_order[qid]) - return pulled_through - - -def _merge_pulled_through( - mutable_circuit: 'cirq.Circuit', - pulled_through: 'cirq.PauliString', - clifford_piece_range: Tuple[int, int], - single_qubit_gate_moments_only: bool, -) -> 'cirq.PauliString': - """Merges pulled through Pauli gates into the last single-qubit gate operation or the insert it - into the first idle moment if idle moments exist. - Args: - mutable_circuit: Mutable circuit to transform. - pulled_through: Pauli gates to be merged. - clifford_piece_range: Specifies the [l, r) moments within which pulled-through gate merging - is to be performed. - single_qubit_gate_moments_only: If set True, dynamical decoupling operation will only be - added in single-qubit gate moments. - Returns: - The remaining pulled through operations after merging. +def _calc_pulled_through( + moment: circuits.Moment, input_pauli_ops: ops.PauliString +) -> ops.PauliString: + """Calculates the pulled_through such that circuit(input_puali_ops, moment.clifford_ops) is + equivalent to circuit(moment.clifford_ops, pulled_through). """ - insert_intos: list[Tuple[int, 'cirq.Operation']] = [] - batch_replaces: list[Tuple[int, 'cirq.Operation', 'cirq.Operation']] = [] - remaining_pulled_through = pulled_through - for affected_q, combined_op_in_pauli in pulled_through.items(): - moment_id = mutable_circuit.prev_moment_operating_on([affected_q], clifford_piece_range[1]) - if moment_id is not None: - op = mutable_circuit.operation_at(affected_q, moment_id) - # Try to merge op into an existing single-qubit gate operation. - if op is not None and _is_single_qubit_operation(op): - updated_gate_mat = cirq.unitary(combined_op_in_pauli) @ cirq.unitary(op) - updated_gate: Optional['cirq.Gate'] = ( - single_qubit_decompositions.single_qubit_matrix_to_phxz(updated_gate_mat) - ) - if updated_gate is None: - # updated_gate is close to Identity. - updated_gate = cirq.I - batch_replaces.append((moment_id, op, updated_gate.on(affected_q))) - remaining_pulled_through *= combined_op_in_pauli.on(affected_q) - continue - # Insert into the first empty moment for the qubit if such moment exists. - while moment_id < clifford_piece_range[1]: - if affected_q not in mutable_circuit.moments[ - moment_id - ].qubits and _is_insertable_moment( - mutable_circuit.moments[moment_id], single_qubit_gate_moments_only - ): - insert_intos.append((moment_id, combined_op_in_pauli.on(affected_q))) - remaining_pulled_through *= combined_op_in_pauli.on(affected_q) - break - moment_id += 1 - mutable_circuit.batch_insert_into(insert_intos) - mutable_circuit.batch_replace(batch_replaces) - return remaining_pulled_through + clifford_ops_in_moment: list[ops.Operation] = [ + op for op in moment.operations if _is_clifford_op(op) + ] + return input_pauli_ops.after(clifford_ops_in_moment) + + +def _get_stop_qubits(moment: circuits.Moment) -> set[ops.Qid]: + stop_pulling_through_qubits: set[ops.Qid] = set() + for op in moment: + if (not _is_clifford_op(op) and not _is_single_qubit_operation(op)) or not has_unitary( + op + ): # multi-qubit clifford op or non-mergable op. + stop_pulling_through_qubits.update(op.qubits) + return stop_pulling_through_qubits + + +def _need_merge_pulled_through(op_at_q: ops.Operation, is_at_last_busy_moment: bool) -> bool: + """With a pulling through puali gate before op_at_q, need to merge with the + pauli in the conditions below.""" + # The op must be mergable and single-qubit + if not (_is_single_qubit_operation(op_at_q) and has_unitary(op_at_q)): + return False + # Either non-Clifford or at the last busy moment + return is_at_last_busy_moment or not _is_clifford_op(op_at_q) @transformer_api.transformer @@ -211,12 +205,11 @@ def add_dynamical_decoupling( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None, - schema: Union[str, Tuple['cirq.Gate', ...]] = 'DEFAULT', + schema: Union[str, Tuple[ops.Gate, ...]] = 'DEFAULT', single_qubit_gate_moments_only: bool = True, ) -> 'cirq.Circuit': """Adds dynamical decoupling gate operations to a given circuit. - This transformer might add a new moment after each piece of Clifford moments, so the original - moment structure could change. + This transformer might add new moments thus change structure of the original circuit. Args: circuit: Input circuit to transform. @@ -230,64 +223,110 @@ def add_dynamical_decoupling( Returns: A copy of the input circuit with dynamical decoupling operations. """ - base_dd_sequence: Tuple['cirq.Gate', ...] = _parse_dd_sequence(schema) - mutable_circuit = circuit.unfreeze(copy=True) + base_dd_sequence, pauli_map = _parse_dd_sequence(schema) + orig_circuit = circuit.freeze() + + busy_moment_range_by_qubit = _calc_busy_moment_range_of_each_qubit(orig_circuit) + + # Stores all the moments of the output circuit chronically + transformed_moments: list[circuits.Moment] = [] + # A PauliString stores the result of 'pulling' Pauli gates past each operations + # right before the current moment. + pulled_through: ops.PauliString = ops.PauliString() + # Iterator of gate to be used in dd sequence for each qubit. + dd_iter_by_qubits = {q: cycle(base_dd_sequence) for q in circuit.all_qubits()} + + def _update_pulled_through(q: ops.Qid, insert_gate: ops.Gate) -> ops.Operation: + nonlocal pulled_through, pauli_map + pulled_through *= pauli_map[insert_gate].on(q) + return insert_gate.on(q) + + # Insert and pull remaining Puali ops through the whole circuit. + # General ideas are + # * Pull through Clifford gates. + # * Stop at multi-qubit non-Clifford ops (and other non-mergable ops). + # * Merge to single-qubit non-Clifford ops. + # * Insert a new moment if necessary. + # After pulling through pulled_through at `moment`, we expect a transformation of + # (pulled_through, moment) -> (updated_moment, updated_pulled_through) or + # (pulled_through, moment) -> (new_moment, updated_moment, updated_pulled_through) + # Moments structure changes are split into 3 steps: + # 1, (..., last_moment, pulled_through1, moment, ...) + # -> (..., try_merge(last_moment, new_moment or None), pulled_through2, moment, ...) + # 2, (..., pulled_through2, moment, ...) -> (..., pulled_through3, updated_moment, ...) + # 3, (..., pulled_through3, updated_moment, ...) + # -> (..., updated_moment, pulled_through4, ...) + for moment_id, moment in enumerate(orig_circuit.moments): + # Step 1, insert new_moment if necessary. + # In detail: stop pulling through for multi-qubit non-Clifford ops or gates without + # unitary representation (e.g., measure gates). If there are remaining pulled through ops, + # insert into a new moment before current moment. + stop_pulling_through_qubits: set[ops.Qid] = _get_stop_qubits(moment) + new_moment_ops = [] + for q in stop_pulling_through_qubits: + # Insert the remaining pulled_through + remaining_pulled_through_gate = pulled_through.get(q) + if remaining_pulled_through_gate is not None: + new_moment_ops.append(_update_pulled_through(q, remaining_pulled_through_gate)) + # Reset dd sequence + dd_iter_by_qubits[q] = cycle(base_dd_sequence) + # Need to insert a new moment before current moment + if new_moment_ops: + # Fill insertable idle moments in the new moment using dd sequence + for q in orig_circuit.all_qubits() - stop_pulling_through_qubits: + if busy_moment_range_by_qubit[q][0] < moment_id <= busy_moment_range_by_qubit[q][1]: + new_moment_ops.append(_update_pulled_through(q, next(dd_iter_by_qubits[q]))) + moments_to_be_appended = _try_merge_single_qubit_ops_of_two_moments( + transformed_moments.pop(), circuits.Moment(new_moment_ops) + ) + transformed_moments.extend(moments_to_be_appended) + + # Step 2, calc updated_moment with insertions / merges. + updated_moment_ops: set['cirq.Operation'] = set() + for q in orig_circuit.all_qubits(): + op_at_q = moment.operation_at(q) + remaining_pulled_through_gate = pulled_through.get(q) + updated_op = op_at_q + if op_at_q is None: # insert into idle op + if not _is_insertable_moment(moment, single_qubit_gate_moments_only): + continue + if ( + busy_moment_range_by_qubit[q][0] < moment_id < busy_moment_range_by_qubit[q][1] + ): # insert next pauli gate in the dd sequence + updated_op = _update_pulled_through(q, next(dd_iter_by_qubits[q])) + elif ( # insert the remaining pulled through if beyond the ending busy moment + moment_id > busy_moment_range_by_qubit[q][1] + and remaining_pulled_through_gate is not None + ): + updated_op = _update_pulled_through(q, remaining_pulled_through_gate) + elif ( + remaining_pulled_through_gate is not None + ): # merge pulled-through of q to op_at_q if needed + if _need_merge_pulled_through( + op_at_q, moment_id == busy_moment_range_by_qubit[q][1] + ): + remaining_op = _update_pulled_through(q, remaining_pulled_through_gate) + updated_op = _merge_single_qubit_ops_to_phxz(q, (remaining_op, op_at_q)) + if updated_op is not None: + updated_moment_ops.add(updated_op) - pauli_map: Dict['cirq.Gate', 'cirq.Pauli'] = {} - for gate in base_dd_sequence: - pauli_gate = _pauli_up_to_global_phase(gate) - if pauli_gate is not None: - pauli_map[gate] = pauli_gate + if updated_moment_ops: + updated_moment = circuits.Moment(updated_moment_ops) + transformed_moments.append(updated_moment) - busy_moment_range_by_qubit: Dict['cirq.Qid', list[int]] = { - q: [len(circuit), -1] for q in circuit.all_qubits() - } - for moment_id, moment in enumerate(circuit): - for q in moment.qubits: - busy_moment_range_by_qubit[q][0] = min(busy_moment_range_by_qubit[q][0], moment_id) - busy_moment_range_by_qubit[q][1] = max(busy_moment_range_by_qubit[q][1], moment_id) - clifford_pieces = _get_clifford_pieces(circuit) - - insert_intos: list[Tuple[int, 'cirq.Operation']] = [] - insert_moments: list[Tuple[int, 'cirq.Moment']] = [] - for l, r in clifford_pieces: # [l, r) - # A PauliString stores the result of 'pulling' Pauli gates past each operations - # right before the current moment. - pulled_through: 'cirq.PauliString' = cirq.PauliString() - iter_by_qubits = {q: cycle(base_dd_sequence) for q in circuit.all_qubits()} - - # Iterate over the Clifford piece. - for moment_id in range(l, r): - moment = circuit.moments[moment_id] - - # Insert - if _is_insertable_moment(moment, single_qubit_gate_moments_only): - for q in circuit.all_qubits() - moment.qubits: - if ( - busy_moment_range_by_qubit[q][0] - < moment_id - < busy_moment_range_by_qubit[q][1] - ): - insert_gate = next(iter_by_qubits[q]) - insert_intos.append((moment_id, insert_gate.on(q))) - pulled_through *= pauli_map[insert_gate].on(q) - - # Pull through - pulled_through = _calc_pulled_through(moment, pulled_through) - - mutable_circuit.batch_insert_into(insert_intos) - insert_intos.clear() - - pulled_through = _merge_pulled_through( - mutable_circuit, pulled_through, (l, r), single_qubit_gate_moments_only - ) + # Step 3, update pulled through. + # In detail: pulling current `pulled_through` through updated_moment. + pulled_through = _calc_pulled_through(updated_moment, pulled_through) - # Insert a new moment if there are remaining pulled through operations. - new_moment_ops = [] - for affected_q, combined_op_in_pauli in pulled_through.items(): - new_moment_ops.append(combined_op_in_pauli.on(affected_q)) - if len(new_moment_ops) != 0: - insert_moments.append((r, cirq.Moment(new_moment_ops))) + # Insert a new moment if there are remaining pulled-through operations. + ending_moment_ops = [] + for affected_q, combined_op_in_pauli in pulled_through.items(): + ending_moment_ops.append(combined_op_in_pauli.on(affected_q)) + if ending_moment_ops: + transformed_moments.extend( + _try_merge_single_qubit_ops_of_two_moments( + transformed_moments.pop(), circuits.Moment(ending_moment_ops) + ) + ) - mutable_circuit.batch_insert(insert_moments) - return mutable_circuit + return circuits.Circuit(transformed_moments) diff --git a/cirq-core/cirq/transformers/dynamical_decoupling_test.py b/cirq-core/cirq/transformers/dynamical_decoupling_test.py index 79a2bdd5317..31bdcd3607c 100644 --- a/cirq-core/cirq/transformers/dynamical_decoupling_test.py +++ b/cirq-core/cirq/transformers/dynamical_decoupling_test.py @@ -30,14 +30,17 @@ def assert_sim_eq(circuit1: 'cirq.AbstractCircuit', circuit2: 'cirq.AbstractCirc def assert_dd( input_circuit: 'cirq.AbstractCircuit', - expected_circuit: 'cirq.AbstractCircuit', + expected_circuit: Union[str, 'cirq.AbstractCircuit'], schema: Union[str, Tuple['cirq.Gate', ...]] = 'DEFAULT', single_qubit_gate_moments_only: bool = True, ): transformed_circuit = add_dynamical_decoupling( input_circuit, schema=schema, single_qubit_gate_moments_only=single_qubit_gate_moments_only ).freeze() - cirq.testing.assert_same_circuits(transformed_circuit, expected_circuit) + if isinstance(expected_circuit, str): + cirq.testing.assert_has_diagram(transformed_circuit, expected_circuit) + else: + cirq.testing.assert_same_circuits(transformed_circuit, expected_circuit) cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( cirq.drop_terminal_measurements(input_circuit), cirq.drop_terminal_measurements(transformed_circuit), @@ -533,7 +536,7 @@ def test_pull_through_chain(): ) -def test_multiple_clifford_pieces(): +def test_multiple_clifford_pieces_case1(): """Test case diagrams. Input: a: ───H───────H───────@───────────H───────H─── @@ -579,6 +582,83 @@ def test_multiple_clifford_pieces(): ) +def test_multiple_clifford_pieces_case2(): + """Test case diagrams. + Input: + a: ───@───PhXZ(a=0.3,x=0.2,z=0)───PhXZ(a=0.3,x=0.2,z=0)───PhXZ(a=0.3,x=0.2,z=0)───@─── + │ │ + b: ───@───────────────────────────────────────────────────────────────────────────@─── + Output: + a: ───@───PhXZ(a=0.3,x=0.2,z=0)───PhXZ(a=0.3,x=0.2,z=0)───PhXZ(a=0.3,x=0.2,z=0)───@───Z─── + │ │ + b: ───@───X───────────────────────X───────────────────────X───────────────────────@───X─── + """ + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + phased_xz_gate = cirq.PhasedXZGate(axis_phase_exponent=0.3, x_exponent=0.2, z_exponent=0) + + assert_dd( + input_circuit=cirq.Circuit( + cirq.Moment(cirq.CZ(a, b)), + cirq.Moment(phased_xz_gate.on(a)), + cirq.Moment(phased_xz_gate.on(a)), + cirq.Moment(phased_xz_gate.on(a)), + cirq.Moment(cirq.CZ(a, b)), + ), + expected_circuit=cirq.Circuit( + cirq.Moment(cirq.CZ(a, b)), + cirq.Moment(phased_xz_gate.on(a), cirq.X(b)), + cirq.Moment(phased_xz_gate.on(a), cirq.X(b)), + cirq.Moment(phased_xz_gate.on(a), cirq.X(b)), + cirq.Moment(cirq.CZ(a, b)), + cirq.Moment(cirq.Z(a), cirq.X(b)), + ), + schema='XX_PAIR', + single_qubit_gate_moments_only=False, + ) + + +def test_insert_new_moment(): + """Test case diagrams. + Input: + a: ───H───────H───@───@─────── + │ │ + b: ───H───H───H───X───@^0.5─── + + c: ───H───────────────H─────── + Output: + a: ───H───X───H───@───Z───@──────────────────────── + │ │ + b: ───H───H───H───X───────@^0.5──────────────────── + + c: ───H───X───X───────X───PhXZ(a=-0.5,x=0.5,z=0)─── + """ + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + c = cirq.NamedQubit('c') + assert_dd( + input_circuit=cirq.Circuit( + cirq.Moment(cirq.H(a), cirq.H(b), cirq.H(c)), + cirq.Moment(cirq.H(b)), + cirq.Moment(cirq.H(b), cirq.H(a)), + cirq.Moment(cirq.CNOT(a, b)), + cirq.Moment(cirq.CZPowGate(exponent=0.5).on(a, b), cirq.H(c)), + ), + expected_circuit=cirq.Circuit( + cirq.Moment(cirq.H(a), cirq.H(b), cirq.H(c)), + cirq.Moment(cirq.H(b), cirq.X(a), cirq.X(c)), + cirq.Moment(cirq.H(a), cirq.H(b), cirq.X(c)), + cirq.Moment(cirq.CNOT(a, b)), + cirq.Moment(cirq.Z(a), cirq.X(c)), + cirq.Moment( + cirq.CZPowGate(exponent=0.5).on(a, b), + cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0).on(c), + ), + ), + schema="XX_PAIR", + ) + + def test_with_non_clifford_measurements(): """Test case diagrams. Input: @@ -626,3 +706,71 @@ def test_with_non_clifford_measurements(): schema="XX_PAIR", single_qubit_gate_moments_only=True, ) + + +def test_cross_clifford_pieces_filling_merge(): + # pylint: disable=line-too-long + """Test case diagrams. + Input: + 0: ─────────────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)───H─── + │ │ + 1: ─────────────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)───H─── + + 2: ───PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────────H─── + │ │ │ + 3: ─────────────────────────────┼───PhXZ(a=0.2,x=0.2,z=0.1)───@───────────────────────────────────────────────────@─────────────────────────────H─── + │ + 4: ─────────────────────────────┼─────────────────────────────@─────────────────────────────────────────────────────────────────────────────────H─── + │ │ + 5: ───PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)───H─── + │ + 6: ───────────────────────────────────────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)─────────────────────────────@───PhXZ(a=0.2,x=0.2,z=0.1)───H─── + Output: + + 0: ─────────────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)─────H──────────────────────── + │ │ + 1: ─────────────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)─────H──────────────────────── + + 2: ───PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───X───────────────────────────PhXZ(a=0.5,x=0.5,z=-1)─── + │ │ │ + 3: ─────────────────────────────┼───PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────X─────────────────────────@───Y───────────────────────────PhXZ(a=0.5,x=0.5,z=0)──── + │ + 4: ─────────────────────────────┼─────────────────────────────@─────────────────────────X─────────────────────────────Y───────────────────────────PhXZ(a=0.5,x=0.5,z=0)──── + │ │ + 5: ───PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=-0.8,x=0.2,z=-0.9)───H──────────────────────── + │ + 6: ───────────────────────────────────────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───X─────────────────────────@───PhXZ(a=0.8,x=0.8,z=0.5)─────H──────────────────────── + """ + # pylint: enable + qubits = cirq.LineQubit.range(7) + phased_xz_gate = cirq.PhasedXZGate(axis_phase_exponent=0.2, x_exponent=0.2, z_exponent=0.1) + assert_dd( + input_circuit=cirq.Circuit( + cirq.Moment([phased_xz_gate.on(qubits[i]) for i in [2, 5]]), + cirq.Moment(cirq.CZ(qubits[2], qubits[5])), + cirq.Moment([phased_xz_gate.on(qubits[i]) for i in [0, 1, 2, 3, 5]]), + cirq.Moment( + [cirq.CZ(qubits[i0], qubits[i1]) for i0, i1 in [(0, 1), (2, 3), (4, 5)]] + + [phased_xz_gate.on(qubits[6])] + ), + cirq.Moment([phased_xz_gate.on(qubits[i]) for i in [0, 1, 2, 5]]), + cirq.Moment([cirq.CZ(qubits[i0], qubits[i1]) for i0, i1 in [(0, 1), (2, 3), (5, 6)]]), + cirq.Moment([phased_xz_gate.on(qubits[i]) for i in [0, 1, 5, 6]]), + cirq.Moment([cirq.H.on(q) for q in qubits]), + ), + expected_circuit=""" +0: ─────────────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)─────H──────────────────────── + │ │ +1: ─────────────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)─────H──────────────────────── + +2: ───PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───X───────────────────────────PhXZ(a=0.5,x=0.5,z=-1)─── + │ │ │ +3: ─────────────────────────────┼───PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────X─────────────────────────@───Y───────────────────────────PhXZ(a=0.5,x=0.5,z=0)──── + │ +4: ─────────────────────────────┼─────────────────────────────@─────────────────────────X─────────────────────────────Y───────────────────────────PhXZ(a=0.5,x=0.5,z=0)──── + │ │ +5: ───PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=0.2,x=0.2,z=0.1)───@─────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───@───PhXZ(a=-0.8,x=0.2,z=-0.9)───H──────────────────────── + │ +6: ───────────────────────────────────────────────────────────PhXZ(a=0.2,x=0.2,z=0.1)───X─────────────────────────@───PhXZ(a=0.8,x=0.8,z=0.5)─────H──────────────────────── +""", + )