diff --git a/cirq-core/cirq/ops/common_channels.py b/cirq-core/cirq/ops/common_channels.py index 183973218e0..22f8f8ff3f3 100644 --- a/cirq-core/cirq/ops/common_channels.py +++ b/cirq-core/cirq/ops/common_channels.py @@ -27,7 +27,7 @@ import cirq -@value.value_equality +@value.value_equality(approximate=True) class AsymmetricDepolarizingChannel(raw_types.Gate): r"""A channel that depolarizes asymmetrically along different directions. @@ -196,11 +196,6 @@ def error_probabilities(self) -> Dict[str, float]: def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['error_probabilities']) - def _approx_eq_(self, other: Any, atol: float) -> bool: - self_keys, self_values = zip(*sorted(self.error_probabilities.items())) - other_keys, other_values = zip(*sorted(other.error_probabilities.items())) - return self_keys == other_keys and protocols.approx_eq(self_values, other_values, atol=atol) - def asymmetric_depolarize( p_x: Optional[float] = None, @@ -246,7 +241,7 @@ def asymmetric_depolarize( return AsymmetricDepolarizingChannel(p_x, p_y, p_z, error_probabilities, tol) -@value.value_equality +@value.value_equality(approximate=True) class DepolarizingChannel(raw_types.Gate): r"""A channel that depolarizes one or several qubits. @@ -306,7 +301,7 @@ def _has_mixture_(self) -> bool: return True def _value_equality_values_(self): - return self._p + return self._p, self._n_qubits def __repr__(self) -> str: if self._n_qubits == 1: @@ -347,9 +342,6 @@ def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['p']) return protocols.obj_to_dict_helper(self, ['p', 'n_qubits']) - def _approx_eq_(self, other: Any, atol: float) -> bool: - return np.isclose(self.p, other.p, atol=atol).item() and self.n_qubits == other.n_qubits - def depolarize(p: float, n_qubits: int = 1) -> DepolarizingChannel: r"""Returns a DepolarizingChannel with given probability of error. @@ -381,7 +373,7 @@ def depolarize(p: float, n_qubits: int = 1) -> DepolarizingChannel: return DepolarizingChannel(p, n_qubits) -@value.value_equality +@value.value_equality(approximate=True) class GeneralizedAmplitudeDampingChannel(raw_types.Gate): r"""Dampen qubit amplitudes through non ideal dissipation. @@ -489,12 +481,6 @@ def gamma(self) -> float: def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['p', 'gamma']) - def _approx_eq_(self, other: Any, atol: float) -> bool: - return ( - np.isclose(self.gamma, other.gamma, atol=atol).item() - and np.isclose(self.p, other.p, atol=atol).item() - ) - def generalized_amplitude_damp(p: float, gamma: float) -> GeneralizedAmplitudeDampingChannel: r"""Returns a GeneralizedAmplitudeDampingChannel with probabilities gamma and p. @@ -542,7 +528,7 @@ def generalized_amplitude_damp(p: float, gamma: float) -> GeneralizedAmplitudeDa return GeneralizedAmplitudeDampingChannel(p, gamma) -@value.value_equality +@value.value_equality(approximate=True) class AmplitudeDampingChannel(raw_types.Gate): r"""Dampen qubit amplitudes through dissipation. @@ -619,9 +605,6 @@ def gamma(self) -> float: def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['gamma']) - def _approx_eq_(self, other: Any, atol: float) -> bool: - return np.isclose(self.gamma, other.gamma, atol=atol).item() - def amplitude_damp(gamma: float) -> AmplitudeDampingChannel: r"""Returns an AmplitudeDampingChannel with the given probability gamma. @@ -787,7 +770,7 @@ def reset_each(*qubits: 'cirq.Qid') -> List[raw_types.Operation]: return [ResetChannel(q.dimension).on(q) for q in qubits] -@value.value_equality +@value.value_equality(approximate=True) class PhaseDampingChannel(raw_types.Gate): r"""Dampen qubit phase. @@ -881,9 +864,6 @@ def gamma(self) -> float: def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['gamma']) - def _approx_eq_(self, other: Any, atol: float) -> bool: - return np.isclose(self._gamma, other._gamma, atol=atol).item() - def phase_damp(gamma: float) -> PhaseDampingChannel: r"""Creates a PhaseDampingChannel with damping constant gamma. @@ -919,7 +899,7 @@ def phase_damp(gamma: float) -> PhaseDampingChannel: return PhaseDampingChannel(gamma) -@value.value_equality +@value.value_equality(approximate=True) class PhaseFlipChannel(raw_types.Gate): r"""Probabilistically flip the sign of the phase of a qubit. @@ -991,9 +971,6 @@ def p(self) -> float: def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['p']) - def _approx_eq_(self, other: Any, atol: float) -> bool: - return np.isclose(self.p, other.p, atol=atol).item() - def _phase_flip_Z() -> common_gates.ZPowGate: """Returns a cirq.Z which corresponds to a guaranteed phase flip.""" @@ -1073,7 +1050,7 @@ def phase_flip(p: Optional[float] = None) -> Union[common_gates.ZPowGate, PhaseF return _phase_flip(p) -@value.value_equality +@value.value_equality(approximate=True) class BitFlipChannel(raw_types.Gate): r"""Probabilistically flip a qubit from 1 to 0 state or vice versa. @@ -1148,9 +1125,6 @@ def p(self) -> float: def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['p']) - def _approx_eq_(self, other: Any, atol: float) -> bool: - return np.isclose(self._p, other._p, atol=atol).item() - def _bit_flip(p: float) -> BitFlipChannel: r"""Construct a BitFlipChannel that flips a qubit state with probability of a flip given by p. diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index 45bb930921b..f1fdc43bf37 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -92,6 +92,7 @@ def test_asymmetric_depolarizing_channel_eq(): c = cirq.asymmetric_depolarize(0.0, 0.0, 0.0) assert cirq.approx_eq(a, b, atol=1e-2) + assert not cirq.approx_eq(a, cirq.X) et = cirq.testing.EqualsTester() et.make_equality_group(lambda: c) @@ -276,6 +277,7 @@ def test_depolarizing_channel_eq(): c = cirq.depolarize(0.0) assert cirq.approx_eq(a, b, atol=1e-2) + assert not cirq.approx_eq(a, cirq.X) et = cirq.testing.EqualsTester() @@ -283,6 +285,7 @@ def test_depolarizing_channel_eq(): et.add_equality_group(cirq.depolarize(0.1)) et.add_equality_group(cirq.depolarize(0.9)) et.add_equality_group(cirq.depolarize(1.0)) + et.add_equality_group(cirq.depolarize(1.0, n_qubits=2)) def test_depolarizing_channel_invalid_probability(): @@ -349,6 +352,7 @@ def test_generalized_amplitude_damping_channel_eq(): b = cirq.generalized_amplitude_damp(0.01, 0.0099999) assert cirq.approx_eq(a, b, atol=1e-2) + assert not cirq.approx_eq(a, cirq.X) et = cirq.testing.EqualsTester() c = cirq.generalized_amplitude_damp(0.0, 0.0) @@ -411,6 +415,7 @@ def test_amplitude_damping_channel_eq(): c = cirq.amplitude_damp(0.0) assert cirq.approx_eq(a, b, atol=1e-2) + assert not cirq.approx_eq(a, cirq.X) et = cirq.testing.EqualsTester() et.make_equality_group(lambda: c) @@ -562,6 +567,7 @@ def test_phase_damping_channel_eq(): c = cirq.phase_damp(0.0) assert cirq.approx_eq(a, b, atol=1e-2) + assert not cirq.approx_eq(a, cirq.X) et = cirq.testing.EqualsTester() et.make_equality_group(lambda: c) @@ -636,6 +642,7 @@ def test_phase_flip_channel_eq(): c = cirq.phase_flip(0.0) assert cirq.approx_eq(a, b, atol=1e-2) + assert not cirq.approx_eq(a, cirq.X) et = cirq.testing.EqualsTester() et.make_equality_group(lambda: c) @@ -701,6 +708,7 @@ def test_bit_flip_channel_eq(): c = cirq.bit_flip(0.0) assert cirq.approx_eq(a, b, atol=1e-2) + assert not cirq.approx_eq(a, cirq.X) et = cirq.testing.EqualsTester() et.make_equality_group(lambda: c) @@ -834,6 +842,8 @@ def test_multi_asymmetric_depolarizing_eq(): assert cirq.approx_eq(a, b, atol=1e-3) + assert not cirq.approx_eq(a, cirq.X) + def test_multi_asymmetric_depolarizing_channel_str(): assert str(cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})) == (