Skip to content

Add __cirq_debug__ flag and conditionally disable qid validations in gates and operations #6000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"environment_type": "virtualenv",
"show_commit_url": "https://github.com/quantumlib/Cirq/commit/",
"pythons": ["3.8"],
"matrix": {"env_nobuild": {"PYTHONOPTIMIZE": ["-O", ""]}},
"benchmark_dir": "benchmarks",
"env_dir": ".asv/env",
"results_dir": ".asv/results",
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from cirq import _import

from cirq._compat import __cirq_debug__

# A module can only depend on modules imported earlier in this list of modules
# at import time. Pytest will fail otherwise (enforced by
# dev_tools/import_test.py).
Expand Down
11 changes: 11 additions & 0 deletions cirq-core/cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""Workarounds for compatibility issues between versions and libraries."""
import contextlib
import dataclasses

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove blank line and sort imports since contextvars is part of the stdlib.

import contextvars
import functools
import importlib
import inspect
Expand All @@ -31,8 +33,17 @@
import sympy
import sympy.printing.repr

from cirq._doc import document

ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST'

__cirq_debug__ = contextvars.ContextVar('__cirq_debug__', default=__debug__)
document(
__cirq_debug__,
"A cirq specific flag which can be used to conditionally turn off all validations across Cirq "
"to boost performance in production mode. Defaults to python's built-in constant __debug__. "
"The flag is implemented as a `ContextVar` and is thread safe.",
)

try:
from functools import cached_property # pylint: disable=unused-import
Expand Down
44 changes: 31 additions & 13 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
import functools
from typing import (
cast,
AbstractSet,
Any,
Callable,
Expand All @@ -40,6 +41,7 @@

from cirq import protocols, value
from cirq._import import LazyLoader
from cirq._compat import __cirq_debug__
from cirq.type_workarounds import NotImplementedType
from cirq.ops import control_values as cv

Expand Down Expand Up @@ -215,7 +217,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']) -> None:
Raises:
ValueError: The gate can't be applied to the qubits.
"""
_validate_qid_shape(self, qubits)
if __cirq_debug__.get():
_validate_qid_shape(self, qubits)

def on(self, *qubits: Qid) -> 'Operation':
"""Returns an application of this gate to the given qubits.
Expand Down Expand Up @@ -254,19 +257,33 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation']
raise TypeError(f'{targets[0]} object is not iterable.')
t0 = list(targets[0])
iterator = [t0] if t0 and isinstance(t0[0], Qid) else t0
for target in iterator:
if not isinstance(target, Sequence):
raise ValueError(
f'Inputs to multi-qubit gates must be Sequence[Qid].'
f' Type: {type(target)}'
)
if not all(isinstance(x, Qid) for x in target):
raise ValueError(f'All values in sequence should be Qids, but got {target}')
if len(target) != self._num_qubits_():
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
operations.append(self.on(*target))
if __cirq_debug__.get():
for target in iterator:
if not isinstance(target, Sequence):
raise ValueError(
f'Inputs to multi-qubit gates must be Sequence[Qid].'
f' Type: {type(target)}'
)
if not all(isinstance(x, Qid) for x in target):
raise ValueError(f'All values in sequence should be Qids, but got {target}')
if len(target) != self._num_qubits_():
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
operations.append(self.on(*target))
else:
operations = [self.on(*target) for target in iterator]
return operations

if __cirq_debug__.get() is False:
return [
op
for q in targets
for op in (
self.on_each(*q)
if isinstance(q, Iterable) and not isinstance(q, str)
else [self.on(cast('cirq.Qid', q))]
)
]

for target in targets:
if isinstance(target, Qid):
operations.append(self.on(target))
Expand Down Expand Up @@ -617,7 +634,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']):
Raises:
ValueError: The operation had qids that don't match it's qid shape.
"""
_validate_qid_shape(self, qubits)
if __cirq_debug__.get():
_validate_qid_shape(self, qubits)

def _commutes_(
self, other: Any, *, atol: float = 1e-8
Expand Down
34 changes: 34 additions & 0 deletions cirq-core/cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,30 @@ def test_op_validate():
op2.validate_args([cirq.LineQid(1, 2), cirq.LineQid(1, 2)])


def test_disable_op_validation():
q0, q1 = cirq.LineQubit.range(2)
h_op = cirq.H(q0)

# Fails normally.
with pytest.raises(ValueError, match='Wrong number'):
_ = cirq.H(q0, q1)
with pytest.raises(ValueError, match='Wrong number'):
h_op.validate_args([q0, q1])

# Passes, skipping validation.
cirq.__cirq_debug__.set(False)
op = cirq.H(q0, q1)
assert op.qubits == (q0, q1)
h_op.validate_args([q0, q1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be nice to add a context manager for changing the __cirq_debug__ flag and making that the default way of using it. This ensures that state is restored even if an exception happens, which makes it much easier to know the state of the flag when looking at code.

@contextlib.contextmanager
def with_debug(value: bool) -> Iterator[None]:
    token = __cirq_debug__.set(value)
    try:
        yield
    finally:
        token.reset()

Then can use that in these tests:

with cirq.with_debug(False):
    op = cirq.H(q0, q1)
    assert op.qubits == (q0, q1)
    h_op.valudate_args([q0, q1])

I think we should encourage using a mechanism like this and discourage setting cirq.__cirq_debug__ manually (other than maybe for interactive use).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think adding the context manager is fine, but I'd still like to set the default value to __debug__ so that users have a way to disable validations in existing codebase without major modifications in the code.

What do you think?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a context manager, PTAL


# Fails again when validation is re-enabled.
cirq.__cirq_debug__.set(True)
with pytest.raises(ValueError, match='Wrong number'):
_ = cirq.H(q0, q1)
with pytest.raises(ValueError, match='Wrong number'):
h_op.validate_args([q0, q1])


def test_default_validation_and_inverse():
class TestGate(cirq.Gate):
def _num_qubits_(self):
Expand Down Expand Up @@ -787,6 +811,11 @@ def matrix(self):
test_non_qubits = [str(i) for i in range(3)]
with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)

cirq.__cirq_debug__.set(False)
assert g.on_each(*test_non_qubits)[0].qubits == ('0',)

cirq.__cirq_debug__.set(True)
with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)

Expand Down Expand Up @@ -853,6 +882,11 @@ def test_on_each_two_qubits():
g.on_each([(a,)])
with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each([(a, b, a)])

cirq.__cirq_debug__.set(False)
assert g.on_each([(a, b, a)])[0].qubits == (a, b, a)
cirq.__cirq_debug__.set(True)

with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each(zip([a, a]))
with pytest.raises(ValueError, match='Expected 2 qubits'):
Expand Down
1 change: 1 addition & 0 deletions cirq-core/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# functools.cached_property was introduced in python 3.8
backports.cached_property~=1.0.1; python_version < '3.8'

contextvars
duet~=0.2.7
matplotlib~=3.0
networkx~=2.4
Expand Down