diff --git a/cirq-core/cirq/ops/identity.py b/cirq-core/cirq/ops/identity.py index c13d8208620..361e70c4ea7 100644 --- a/cirq-core/cirq/ops/identity.py +++ b/cirq-core/cirq/ops/identity.py @@ -13,13 +13,14 @@ # limitations under the License. """IdentityGate.""" -from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Sequence +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Sequence, Union import numpy as np import sympy from cirq import protocols, value from cirq._doc import document +from cirq.type_workarounds import NotImplementedType from cirq.ops import raw_types if TYPE_CHECKING: @@ -75,6 +76,12 @@ def __pow__(self, power: Any) -> Any: return self return NotImplemented + def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]: + """The identity gate commutes with all other gates.""" + if not isinstance(other, raw_types.Gate): + return NotImplemented + return True + def _has_unitary_(self) -> bool: return True diff --git a/cirq-core/cirq/ops/identity_test.py b/cirq-core/cirq/ops/identity_test.py index a3ce014b293..11a2a21a024 100644 --- a/cirq-core/cirq/ops/identity_test.py +++ b/cirq-core/cirq/ops/identity_test.py @@ -208,3 +208,9 @@ def test_identity_short_circuits_act_on(): args = mock.Mock(cirq.SimulationState) args._act_on_fallback_.side_effect = mock.Mock(side_effect=Exception('No!')) cirq.act_on(cirq.IdentityGate(1)(cirq.LineQubit(0)), args) + + +def test_identity_commutes(): + assert cirq.commutes(cirq.I, cirq.X) + with pytest.raises(TypeError): + cirq.commutes(cirq.I, "Gate")