Skip to content

correct noise channel approximate equality #6632

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 18, 2025
42 changes: 8 additions & 34 deletions cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import cirq


@value.value_equality
@value.value_equality(approximate=True)
class AsymmetricDepolarizingChannel(raw_types.Gate):
r"""A channel that depolarizes asymmetrically along different directions.

Expand Down Expand Up @@ -196,11 +196,6 @@ def error_probabilities(self) -> Dict[str, float]:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['error_probabilities'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
self_keys, self_values = zip(*sorted(self.error_probabilities.items()))
other_keys, other_values = zip(*sorted(other.error_probabilities.items()))
return self_keys == other_keys and protocols.approx_eq(self_values, other_values, atol=atol)


def asymmetric_depolarize(
p_x: Optional[float] = None,
Expand Down Expand Up @@ -246,7 +241,7 @@ def asymmetric_depolarize(
return AsymmetricDepolarizingChannel(p_x, p_y, p_z, error_probabilities, tol)


@value.value_equality
@value.value_equality(approximate=True)
class DepolarizingChannel(raw_types.Gate):
r"""A channel that depolarizes one or several qubits.

Expand Down Expand Up @@ -306,7 +301,7 @@ def _has_mixture_(self) -> bool:
return True

def _value_equality_values_(self):
return self._p
return self._p, self._n_qubits

def __repr__(self) -> str:
if self._n_qubits == 1:
Expand Down Expand Up @@ -347,9 +342,6 @@ def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['p'])
return protocols.obj_to_dict_helper(self, ['p', 'n_qubits'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self.p, other.p, atol=atol).item() and self.n_qubits == other.n_qubits


def depolarize(p: float, n_qubits: int = 1) -> DepolarizingChannel:
r"""Returns a DepolarizingChannel with given probability of error.
Expand Down Expand Up @@ -381,7 +373,7 @@ def depolarize(p: float, n_qubits: int = 1) -> DepolarizingChannel:
return DepolarizingChannel(p, n_qubits)


@value.value_equality
@value.value_equality(approximate=True)
class GeneralizedAmplitudeDampingChannel(raw_types.Gate):
r"""Dampen qubit amplitudes through non ideal dissipation.

Expand Down Expand Up @@ -489,12 +481,6 @@ def gamma(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['p', 'gamma'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return (
np.isclose(self.gamma, other.gamma, atol=atol).item()
and np.isclose(self.p, other.p, atol=atol).item()
)


def generalized_amplitude_damp(p: float, gamma: float) -> GeneralizedAmplitudeDampingChannel:
r"""Returns a GeneralizedAmplitudeDampingChannel with probabilities gamma and p.
Expand Down Expand Up @@ -542,7 +528,7 @@ def generalized_amplitude_damp(p: float, gamma: float) -> GeneralizedAmplitudeDa
return GeneralizedAmplitudeDampingChannel(p, gamma)


@value.value_equality
@value.value_equality(approximate=True)
class AmplitudeDampingChannel(raw_types.Gate):
r"""Dampen qubit amplitudes through dissipation.

Expand Down Expand Up @@ -619,9 +605,6 @@ def gamma(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['gamma'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self.gamma, other.gamma, atol=atol).item()


def amplitude_damp(gamma: float) -> AmplitudeDampingChannel:
r"""Returns an AmplitudeDampingChannel with the given probability gamma.
Expand Down Expand Up @@ -787,7 +770,7 @@ def reset_each(*qubits: 'cirq.Qid') -> List[raw_types.Operation]:
return [ResetChannel(q.dimension).on(q) for q in qubits]


@value.value_equality
@value.value_equality(approximate=True)
class PhaseDampingChannel(raw_types.Gate):
r"""Dampen qubit phase.

Expand Down Expand Up @@ -881,9 +864,6 @@ def gamma(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['gamma'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self._gamma, other._gamma, atol=atol).item()


def phase_damp(gamma: float) -> PhaseDampingChannel:
r"""Creates a PhaseDampingChannel with damping constant gamma.
Expand Down Expand Up @@ -919,7 +899,7 @@ def phase_damp(gamma: float) -> PhaseDampingChannel:
return PhaseDampingChannel(gamma)


@value.value_equality
@value.value_equality(approximate=True)
class PhaseFlipChannel(raw_types.Gate):
r"""Probabilistically flip the sign of the phase of a qubit.

Expand Down Expand Up @@ -991,9 +971,6 @@ def p(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['p'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self.p, other.p, atol=atol).item()


def _phase_flip_Z() -> common_gates.ZPowGate:
"""Returns a cirq.Z which corresponds to a guaranteed phase flip."""
Expand Down Expand Up @@ -1073,7 +1050,7 @@ def phase_flip(p: Optional[float] = None) -> Union[common_gates.ZPowGate, PhaseF
return _phase_flip(p)


@value.value_equality
@value.value_equality(approximate=True)
class BitFlipChannel(raw_types.Gate):
r"""Probabilistically flip a qubit from 1 to 0 state or vice versa.

Expand Down Expand Up @@ -1148,9 +1125,6 @@ def p(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['p'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self._p, other._p, atol=atol).item()


def _bit_flip(p: float) -> BitFlipChannel:
r"""Construct a BitFlipChannel that flips a qubit state with probability of a flip given by p.
Expand Down
10 changes: 10 additions & 0 deletions cirq-core/cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def test_asymmetric_depolarizing_channel_eq():
c = cirq.asymmetric_depolarize(0.0, 0.0, 0.0)

assert cirq.approx_eq(a, b, atol=1e-2)
assert not cirq.approx_eq(a, cirq.X)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
Expand Down Expand Up @@ -276,13 +277,15 @@ def test_depolarizing_channel_eq():
c = cirq.depolarize(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)
assert not cirq.approx_eq(a, cirq.X)

et = cirq.testing.EqualsTester()

et.make_equality_group(lambda: c)
et.add_equality_group(cirq.depolarize(0.1))
et.add_equality_group(cirq.depolarize(0.9))
et.add_equality_group(cirq.depolarize(1.0))
et.add_equality_group(cirq.depolarize(1.0, n_qubits=2))


def test_depolarizing_channel_invalid_probability():
Expand Down Expand Up @@ -349,6 +352,7 @@ def test_generalized_amplitude_damping_channel_eq():
b = cirq.generalized_amplitude_damp(0.01, 0.0099999)

assert cirq.approx_eq(a, b, atol=1e-2)
assert not cirq.approx_eq(a, cirq.X)

et = cirq.testing.EqualsTester()
c = cirq.generalized_amplitude_damp(0.0, 0.0)
Expand Down Expand Up @@ -411,6 +415,7 @@ def test_amplitude_damping_channel_eq():
c = cirq.amplitude_damp(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)
assert not cirq.approx_eq(a, cirq.X)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
Expand Down Expand Up @@ -562,6 +567,7 @@ def test_phase_damping_channel_eq():
c = cirq.phase_damp(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)
assert not cirq.approx_eq(a, cirq.X)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
Expand Down Expand Up @@ -636,6 +642,7 @@ def test_phase_flip_channel_eq():
c = cirq.phase_flip(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)
assert not cirq.approx_eq(a, cirq.X)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
Expand Down Expand Up @@ -701,6 +708,7 @@ def test_bit_flip_channel_eq():
c = cirq.bit_flip(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)
assert not cirq.approx_eq(a, cirq.X)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
Expand Down Expand Up @@ -834,6 +842,8 @@ def test_multi_asymmetric_depolarizing_eq():

assert cirq.approx_eq(a, b, atol=1e-3)

assert not cirq.approx_eq(a, cirq.X)


def test_multi_asymmetric_depolarizing_channel_str():
assert str(cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})) == (
Expand Down