Skip to content

Use frozensets for key protocols #5560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,26 +918,30 @@ def qid_shape(
qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
return protocols.qid_shape(qids)

def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
return {key for op in self.all_operations() for key in protocols.measurement_key_objs(op)}
def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']:
return frozenset(
key for op in self.all_operations() for key in protocols.measurement_key_objs(op)
)

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
"""Returns the set of all measurement keys in this circuit.

Returns: AbstractSet of `cirq.MeasurementKey` objects that are
Returns: FrozenSet of `cirq.MeasurementKey` objects that are
in this circuit.
"""
return self.all_measurement_key_objs()

def all_measurement_key_names(self) -> AbstractSet[str]:
def all_measurement_key_names(self) -> FrozenSet[str]:
"""Returns the set of all measurement key names in this circuit.

Returns: AbstractSet of strings that are the measurement key
Returns: FrozenSet of strings that are the measurement key
names in this circuit.
"""
return {key for op in self.all_operations() for key in protocols.measurement_key_names(op)}
return frozenset(
key for op in self.all_operations() for key in protocols.measurement_key_names(op)
)

def _measurement_key_names_(self) -> AbstractSet[str]:
def _measurement_key_names_(self) -> FrozenSet[str]:
return self.all_measurement_key_names()

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
Expand Down
29 changes: 15 additions & 14 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""
import math
from typing import (
AbstractSet,
Callable,
Mapping,
Sequence,
Expand Down Expand Up @@ -309,30 +308,32 @@ def _ensure_deterministic_loop_count(self):
raise ValueError('Cannot unroll circuit due to nondeterministic repetitions')

@cached_property
def _measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']:
circuit_keys = protocols.measurement_key_objs(self.circuit)
if circuit_keys and self.use_repetition_ids:
self._ensure_deterministic_loop_count()
if self.repetition_ids is not None:
circuit_keys = {
circuit_keys = frozenset(
key.with_key_path_prefix(repetition_id)
for repetition_id in self.repetition_ids
for key in circuit_keys
}
circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys}
return {
)
circuit_keys = frozenset(
key.with_key_path_prefix(*self.parent_path) for key in circuit_keys
)
return frozenset(
protocols.with_measurement_key_mapping(key, dict(self.measurement_key_map))
for key in circuit_keys
}
)

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
return self._measurement_key_objs

def _measurement_key_names_(self) -> AbstractSet[str]:
return {str(key) for key in self._measurement_key_objs_()}
def _measurement_key_names_(self) -> FrozenSet[str]:
return frozenset(str(key) for key in self._measurement_key_objs_())

@cached_property
def _control_keys(self) -> AbstractSet['cirq.MeasurementKey']:
def _control_keys(self) -> FrozenSet['cirq.MeasurementKey']:
keys = (
frozenset()
if not protocols.control_keys(self.circuit)
Expand All @@ -342,13 +343,13 @@ def _control_keys(self) -> AbstractSet['cirq.MeasurementKey']:
keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_()
return keys

def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
return self._control_keys

def _is_parameterized_(self) -> bool:
return any(self._parameter_names_generator())

def _parameter_names_(self) -> AbstractSet[str]:
def _parameter_names_(self) -> FrozenSet[str]:
return frozenset(self._parameter_names_generator())

def _parameter_names_generator(self) -> Iterator[str]:
Expand Down Expand Up @@ -463,7 +464,7 @@ def __str__(self):
)
args = []

def dict_str(d: Dict) -> str:
def dict_str(d: Mapping) -> str:
pairs = [f'{k}: {v}' for k, v in sorted(d.items())]
return '{' + ', '.join(pairs) + '}'

Expand Down
31 changes: 10 additions & 21 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""An immutable version of the Circuit data structure."""
from typing import (
TYPE_CHECKING,
AbstractSet,
FrozenSet,
Iterable,
Iterator,
Optional,
Sequence,
Tuple,
Union,
)
from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Optional, Sequence, Tuple, Union

import numpy as np

from cirq import ops, protocols
from cirq.circuits import AbstractCircuit, Alignment, Circuit
from cirq.circuits.insert_strategy import InsertStrategy
from cirq.type_workarounds import NotImplementedType

import numpy as np

from cirq import ops, protocols, _compat


if TYPE_CHECKING:
import cirq

Expand Down Expand Up @@ -70,7 +59,7 @@ def __init__(
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
self._all_operations: Optional[Tuple[ops.Operation, ...]] = None
self._has_measurements: Optional[bool] = None
self._all_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = None
self._all_measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None
self._are_all_measurements_terminal: Optional[bool] = None
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None

Expand Down Expand Up @@ -118,12 +107,12 @@ def has_measurements(self) -> bool:
self._has_measurements = super().has_measurements()
return self._has_measurements

def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']:
if self._all_measurement_key_objs is None:
self._all_measurement_key_objs = super().all_measurement_key_objs()
return self._all_measurement_key_objs

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
return self.all_measurement_key_objs()

def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
Expand All @@ -138,10 +127,10 @@ def are_all_measurements_terminal(self) -> bool:

# End of memoized methods.

def all_measurement_key_names(self) -> AbstractSet[str]:
return {str(key) for key in self.all_measurement_key_objs()}
def all_measurement_key_names(self) -> FrozenSet[str]:
return frozenset(str(key) for key in self.all_measurement_key_objs())

def _measurement_key_names_(self) -> AbstractSet[str]:
def _measurement_key_names_(self) -> FrozenSet[str]:
return self.all_measurement_key_names()

def __add__(self, other) -> 'cirq.FrozenCircuit':
Expand Down
5 changes: 2 additions & 3 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import itertools
from typing import (
AbstractSet,
Any,
Callable,
Dict,
Expand Down Expand Up @@ -238,8 +237,8 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
for op in self.operations
)

def _measurement_key_names_(self) -> AbstractSet[str]:
return {str(key) for key in self._measurement_key_objs_()}
def _measurement_key_names_(self) -> FrozenSet[str]:
return frozenset(str(key) for key in self._measurement_key_objs_())

def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
if self._measurement_key_objs is None:
Expand Down
6 changes: 4 additions & 2 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _measurement_key_name_(self) -> Optional[str]:
return getter()
return NotImplemented

def _measurement_key_names_(self) -> Optional[AbstractSet[str]]:
def _measurement_key_names_(self) -> Union[FrozenSet[str], NotImplementedType, None]:
getter = getattr(self.gate, '_measurement_key_names_', None)
if getter is not None:
return getter()
Expand All @@ -247,7 +247,9 @@ def _measurement_key_obj_(self) -> Optional['cirq.MeasurementKey']:
return getter()
return NotImplemented

def _measurement_key_objs_(self) -> Optional[AbstractSet['cirq.MeasurementKey']]:
def _measurement_key_objs_(
self,
) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]:
getter = getattr(self.gate, '_measurement_key_objs_', None)
if getter is not None:
return getter()
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,10 +820,10 @@ def _has_kraus_(self) -> bool:
def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]:
return protocols.kraus(self.sub_operation, NotImplemented)

def _measurement_key_names_(self) -> AbstractSet[str]:
def _measurement_key_names_(self) -> FrozenSet[str]:
return protocols.measurement_key_names(self.sub_operation)

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
return protocols.measurement_key_objs(self.sub_operation)

def _is_measurement_(self) -> bool:
Expand Down Expand Up @@ -905,7 +905,7 @@ def with_classical_controls(
return self
return self.sub_operation.with_classical_controls(*conditions)

def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
return protocols.control_keys(self.sub_operation)


Expand Down
20 changes: 14 additions & 6 deletions cirq-core/cirq/protocols/control_key_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.
"""Protocol for object that have control keys."""

from typing import AbstractSet, Any, Iterable, TYPE_CHECKING
from typing import Any, FrozenSet, TYPE_CHECKING, Union

from typing_extensions import Protocol

from cirq import _compat
from cirq._doc import doc_private
from cirq.protocols import measurement_key_protocol
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
import cirq
Expand All @@ -34,7 +36,7 @@ class SupportsControlKey(Protocol):
"""

@doc_private
def _control_keys_(self) -> Iterable['cirq.MeasurementKey']:
def _control_keys_(self) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]:
"""Return the keys for controls referenced by the receiving object.

Returns:
Expand All @@ -43,7 +45,7 @@ def _control_keys_(self) -> Iterable['cirq.MeasurementKey']:
"""


def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']:
def control_keys(val: Any) -> FrozenSet['cirq.MeasurementKey']:
"""Gets the keys that the value is classically controlled by.

Args:
Expand All @@ -56,12 +58,18 @@ def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']:
getter = getattr(val, '_control_keys_', None)
result = NotImplemented if getter is None else getter()
if result is not NotImplemented and result is not None:
return set(result)
if not isinstance(result, FrozenSet):
_compat._warn_or_error(
f'The _control_keys_ implementation of {type(val)} must return a'
f' frozenset instead of {type(result)} by v0.16.'
)
return frozenset(result)
return result

return set()
return frozenset()


def measurement_keys_touched(val: Any) -> AbstractSet['cirq.MeasurementKey']:
def measurement_keys_touched(val: Any) -> FrozenSet['cirq.MeasurementKey']:
"""Returns all the measurement keys used by the value.

This would be the case if the value is or contains a measurement gate, or
Expand Down
11 changes: 10 additions & 1 deletion cirq-core/cirq/protocols/control_key_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def test_control_key():
class Named:
def _control_keys_(self):
return [cirq.MeasurementKey('key')]
return frozenset([cirq.MeasurementKey('key')])

class NoImpl:
def _control_keys_(self):
Expand All @@ -27,3 +27,12 @@ def _control_keys_(self):
assert cirq.control_keys(Named()) == {cirq.MeasurementKey('key')}
assert not cirq.control_keys(NoImpl())
assert not cirq.control_keys(5)


def test_control_key_enumerable_deprecated():
class Deprecated:
def _control_keys_(self):
return [cirq.MeasurementKey('key')]

with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'):
assert cirq.control_keys(Deprecated()) == {cirq.MeasurementKey('key')}
Loading