From 7a4e7598fcbd67ea27f026ed4ef46707cbb9ff90 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Thu, 7 Apr 2022 06:02:11 -0700 Subject: [PATCH 1/2] Make _commutes_ consistent - Requires atol to be a named parameter. - Also changes atol to be uniformly float around the codebase. (not sure why it would be int, are people using an atol=1?) - Technically a breaking change, but it's unlikely people are using this widely as most commutes do not even use atol. Fixes: #3695 --- cirq-core/cirq/circuits/moment.py | 4 +--- cirq-core/cirq/contrib/acquaintance/permutation.py | 4 +--- cirq-core/cirq/ops/clifford_gate.py | 6 ++++-- cirq-core/cirq/ops/common_gates.py | 2 +- cirq-core/cirq/ops/dense_pauli_string.py | 4 +++- cirq-core/cirq/ops/gate_operation.py | 2 +- cirq-core/cirq/ops/pauli_gates.py | 4 +++- cirq-core/cirq/ops/pauli_string.py | 2 +- cirq-core/cirq/ops/raw_types.py | 8 ++++---- cirq-core/cirq/protocols/commutes_protocol.py | 2 +- 10 files changed, 20 insertions(+), 18 deletions(-) diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index c4907c2ab6f..8b0f4b56715 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -574,9 +574,7 @@ def cleanup_key(key: Any) -> Any: return diagram.render() - def _commutes_( - self, other: Any, *, atol: Union[int, float] = 1e-8 - ) -> Union[bool, NotImplementedType]: + def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]: """Determines whether Moment commutes with the Operation. Args: diff --git a/cirq-core/cirq/contrib/acquaintance/permutation.py b/cirq-core/cirq/contrib/acquaintance/permutation.py index e0a87a8350f..6e18956a116 100644 --- a/cirq-core/cirq/contrib/acquaintance/permutation.py +++ b/cirq-core/cirq/contrib/acquaintance/permutation.py @@ -164,9 +164,7 @@ def __repr__(self) -> str: def _value_equality_values_(self) -> Any: return (self.swap_gate,) - def _commutes_( - self, other: Any, atol: Union[int, float] = 1e-8 - ) -> Union[bool, NotImplementedType]: + def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]: if ( isinstance(other, ops.Gate) and isinstance(other, ops.InterchangeableQubitsGate) diff --git a/cirq-core/cirq/ops/clifford_gate.py b/cirq-core/cirq/ops/clifford_gate.py index 4035770113e..3b001154782 100644 --- a/cirq-core/cirq/ops/clifford_gate.py +++ b/cirq-core/cirq/ops/clifford_gate.py @@ -383,7 +383,7 @@ def __pow__(self, exponent) -> 'SingleQubitCliffordGate': return SingleQubitCliffordGate.from_clifford_tableau(self.clifford_tableau.inverse()) - def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType]: + def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]: if isinstance(other, SingleQubitCliffordGate): return self.commutes_with_single_qubit_gate(other) if isinstance(other, Pauli): @@ -838,7 +838,9 @@ def __pow__(self, exponent) -> 'CliffordGate': def __repr__(self) -> str: return f"Clifford Gate with Tableau:\n {self.clifford_tableau._str_full_()}" - def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]: + def _commutes_( + self, other: Any, *, atol: float = 1e-8 + ) -> Union[bool, NotImplementedType, None]: # Note even if we assume two gates define the tabluea based on the same qubit order, # the following approach cannot judge it: # self.clifford_tableau.then(other.clifford_tableau) == other.clifford_tableau.then( diff --git a/cirq-core/cirq/ops/common_gates.py b/cirq-core/cirq/ops/common_gates.py index 9ceebe1424c..7f836e196d5 100644 --- a/cirq-core/cirq/ops/common_gates.py +++ b/cirq-core/cirq/ops/common_gates.py @@ -662,7 +662,7 @@ def __repr__(self) -> str: ) def _commutes_on_qids_( - self, qids: 'Sequence[cirq.Qid]', other: Any, atol: float + self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float ) -> Union[bool, NotImplementedType, None]: from cirq.ops.parity_gates import ZZPowGate diff --git a/cirq-core/cirq/ops/dense_pauli_string.py b/cirq-core/cirq/ops/dense_pauli_string.py index 143b9e13f61..12fbd7b2ba5 100644 --- a/cirq-core/cirq/ops/dense_pauli_string.py +++ b/cirq-core/cirq/ops/dense_pauli_string.py @@ -358,7 +358,9 @@ def __repr__(self) -> str: f'coefficient={proper_repr(self.coefficient)})' ) - def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]: + def _commutes_( + self, other: Any, *, atol: float = 1e-8 + ) -> Union[bool, NotImplementedType, None]: if isinstance(other, BaseDensePauliString): n = min(len(self.pauli_mask), len(other.pauli_mask)) phase = _vectorized_pauli_mul_phase(self.pauli_mask[:n], other.pauli_mask[:n]) diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index f36e5b8f54d..6f80710d42d 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -199,7 +199,7 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]: return NotImplemented def _commutes_( - self, other: Any, atol: Union[int, float] = 1e-8 + self, other: Any, *, atol: float = 1e-8 ) -> Union[bool, NotImplementedType, None]: commutes = self.gate._commutes_on_qids_(self.qubits, other, atol=atol) if commutes is not NotImplemented: diff --git a/cirq-core/cirq/ops/pauli_gates.py b/cirq-core/cirq/ops/pauli_gates.py index f5cb6c1f4bd..7fa27f5fae3 100644 --- a/cirq-core/cirq/ops/pauli_gates.py +++ b/cirq-core/cirq/ops/pauli_gates.py @@ -53,7 +53,9 @@ def __init__(self, index: int, name: str) -> None: def num_qubits(self): return 1 - def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]: + def _commutes_( + self, other: Any, *, atol: float = 1e-8 + ) -> Union[bool, NotImplementedType, None]: if not isinstance(other, Pauli): return NotImplemented return self is other diff --git a/cirq-core/cirq/ops/pauli_string.py b/cirq-core/cirq/ops/pauli_string.py index 9d146639143..01b922b8207 100644 --- a/cirq-core/cirq/ops/pauli_string.py +++ b/cirq-core/cirq/ops/pauli_string.py @@ -678,7 +678,7 @@ def zip_paulis( return (paulis for qubit, paulis in self.zip_items(other)) def _commutes_( - self, other: Any, *, atol: Union[int, float] = 1e-8 + self, other: Any, *, atol: float = 1e-8 ) -> Union[bool, NotImplementedType, None]: if not isinstance(other, PauliString): return NotImplemented diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 8369646be8a..fb25359d5f0 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -399,11 +399,11 @@ def _qid_shape_(self) -> Tuple[int, ...]: """ def _commutes_on_qids_( - self, qids: 'Sequence[cirq.Qid]', other: Any, atol: float + self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float ) -> Union[bool, NotImplementedType, None]: return NotImplemented - def _commutes_(self, other: Any, atol: float) -> Union[None, NotImplementedType, bool]: + def _commutes_(self, other: Any, *, atol: float) -> Union[None, NotImplementedType, bool]: if not isinstance(other, Gate): return NotImplemented if protocols.qid_shape(self) != protocols.qid_shape(other): @@ -567,7 +567,7 @@ def validate_args(self, qubits: Sequence['cirq.Qid']): _validate_qid_shape(self, qubits) def _commutes_( - self, other: Any, *, atol: Union[int, float] = 1e-8 + self, other: Any, *, atol: float = 1e-8 ) -> Union[bool, NotImplementedType, None]: """Determine if this Operation commutes with the object""" if not isinstance(other, Operation): @@ -771,7 +771,7 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]: return protocols.unitary(self.sub_operation, NotImplemented) def _commutes_( - self, other: Any, *, atol: Union[int, float] = 1e-8 + self, other: Any, *, atol: float = 1e-8 ) -> Union[bool, NotImplementedType, None]: return protocols.commutes(self.sub_operation, other, atol=atol) diff --git a/cirq-core/cirq/protocols/commutes_protocol.py b/cirq-core/cirq/protocols/commutes_protocol.py index 9e295229d79..73ea8d2a40f 100644 --- a/cirq-core/cirq/protocols/commutes_protocol.py +++ b/cirq-core/cirq/protocols/commutes_protocol.py @@ -35,7 +35,7 @@ class SupportsCommutes(Protocol): """An object that can determine commutation relationships vs others.""" @doc_private - def _commutes_(self, other: Any, atol: float) -> Union[None, bool, NotImplementedType]: + def _commutes_(self, other: Any, *, atol: float) -> Union[None, bool, NotImplementedType]: r"""Determines if this object commutes with the other object. Can return None to indicate the commutation relationship is From c630865a75b147037bad72729c2088e89fd6df15 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Thu, 7 Apr 2022 11:35:51 -0700 Subject: [PATCH 2/2] Add default atol in raw_types. --- cirq-core/cirq/ops/common_gates.py | 2 +- cirq-core/cirq/ops/raw_types.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/ops/common_gates.py b/cirq-core/cirq/ops/common_gates.py index 7f836e196d5..9844a7b1f50 100644 --- a/cirq-core/cirq/ops/common_gates.py +++ b/cirq-core/cirq/ops/common_gates.py @@ -662,7 +662,7 @@ def __repr__(self) -> str: ) def _commutes_on_qids_( - self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float + self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float = 1e-8 ) -> Union[bool, NotImplementedType, None]: from cirq.ops.parity_gates import ZZPowGate diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index fb25359d5f0..c4353170b74 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -399,11 +399,13 @@ def _qid_shape_(self) -> Tuple[int, ...]: """ def _commutes_on_qids_( - self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float + self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float = 1e-8 ) -> Union[bool, NotImplementedType, None]: return NotImplemented - def _commutes_(self, other: Any, *, atol: float) -> Union[None, NotImplementedType, bool]: + def _commutes_( + self, other: Any, *, atol: float = 1e-8 + ) -> Union[None, NotImplementedType, bool]: if not isinstance(other, Gate): return NotImplemented if protocols.qid_shape(self) != protocols.qid_shape(other):