Skip to content

Make _commutes_ consistent #5217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 7, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,7 @@ def cleanup_key(key: Any) -> Any:

return diagram.render()

def _commutes_(
self, other: Any, *, atol: Union[int, float] = 1e-8
) -> Union[bool, NotImplementedType]:
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
"""Determines whether Moment commutes with the Operation.

Args:
Expand Down
4 changes: 1 addition & 3 deletions cirq-core/cirq/contrib/acquaintance/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
return (self.swap_gate,)

def _commutes_(
self, other: Any, atol: Union[int, float] = 1e-8
) -> Union[bool, NotImplementedType]:
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
if (
isinstance(other, ops.Gate)
and isinstance(other, ops.InterchangeableQubitsGate)
Expand Down
6 changes: 4 additions & 2 deletions cirq-core/cirq/ops/clifford_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def __pow__(self, exponent) -> 'SingleQubitCliffordGate':

return SingleQubitCliffordGate.from_clifford_tableau(self.clifford_tableau.inverse())

def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType]:
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
if isinstance(other, SingleQubitCliffordGate):
return self.commutes_with_single_qubit_gate(other)
if isinstance(other, Pauli):
Expand Down Expand Up @@ -838,7 +838,9 @@ def __pow__(self, exponent) -> 'CliffordGate':
def __repr__(self) -> str:
return f"Clifford Gate with Tableau:\n {self.clifford_tableau._str_full_()}"

def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]:
def _commutes_(
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
# Note even if we assume two gates define the tabluea based on the same qubit order,
# the following approach cannot judge it:
# self.clifford_tableau.then(other.clifford_tableau) == other.clifford_tableau.then(
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def __repr__(self) -> str:
)

def _commutes_on_qids_(
self, qids: 'Sequence[cirq.Qid]', other: Any, atol: float
self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float
) -> Union[bool, NotImplementedType, None]:
from cirq.ops.parity_gates import ZZPowGate

Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,9 @@ def __repr__(self) -> str:
f'coefficient={proper_repr(self.coefficient)})'
)

def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]:
def _commutes_(
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
if isinstance(other, BaseDensePauliString):
n = min(len(self.pauli_mask), len(other.pauli_mask))
phase = _vectorized_pauli_mul_phase(self.pauli_mask[:n], other.pauli_mask[:n])
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
return NotImplemented

def _commutes_(
self, other: Any, atol: Union[int, float] = 1e-8
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
commutes = self.gate._commutes_on_qids_(self.qubits, other, atol=atol)
if commutes is not NotImplemented:
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/ops/pauli_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def __init__(self, index: int, name: str) -> None:
def num_qubits(self):
return 1

def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]:
def _commutes_(
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
if not isinstance(other, Pauli):
return NotImplemented
return self is other
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def zip_paulis(
return (paulis for qubit, paulis in self.zip_items(other))

def _commutes_(
self, other: Any, *, atol: Union[int, float] = 1e-8
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
if not isinstance(other, PauliString):
return NotImplemented
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,11 +399,11 @@ def _qid_shape_(self) -> Tuple[int, ...]:
"""

def _commutes_on_qids_(
self, qids: 'Sequence[cirq.Qid]', other: Any, atol: float
self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason not to set a default value here (and one directly below) as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed. I thought it might be too aggressive to put it in raw-types, but I think it may be good to have a global default, so I added it in on your suggestion.

) -> Union[bool, NotImplementedType, None]:
return NotImplemented

def _commutes_(self, other: Any, atol: float) -> Union[None, NotImplementedType, bool]:
def _commutes_(self, other: Any, *, atol: float) -> Union[None, NotImplementedType, bool]:
if not isinstance(other, Gate):
return NotImplemented
if protocols.qid_shape(self) != protocols.qid_shape(other):
Expand Down Expand Up @@ -567,7 +567,7 @@ def validate_args(self, qubits: Sequence['cirq.Qid']):
_validate_qid_shape(self, qubits)

def _commutes_(
self, other: Any, *, atol: Union[int, float] = 1e-8
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
"""Determine if this Operation commutes with the object"""
if not isinstance(other, Operation):
Expand Down Expand Up @@ -771,7 +771,7 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
return protocols.unitary(self.sub_operation, NotImplemented)

def _commutes_(
self, other: Any, *, atol: Union[int, float] = 1e-8
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
return protocols.commutes(self.sub_operation, other, atol=atol)

Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/commutes_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SupportsCommutes(Protocol):
"""An object that can determine commutation relationships vs others."""

@doc_private
def _commutes_(self, other: Any, atol: float) -> Union[None, bool, NotImplementedType]:
def _commutes_(self, other: Any, *, atol: float) -> Union[None, bool, NotImplementedType]:
r"""Determines if this object commutes with the other object.

Can return None to indicate the commutation relationship is
Expand Down