diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 917b422ad7f..43e2b31427c 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -918,26 +918,30 @@ def qid_shape( qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits()) return protocols.qid_shape(qids) - def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: - return {key for op in self.all_operations() for key in protocols.measurement_key_objs(op)} + def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']: + return frozenset( + key for op in self.all_operations() for key in protocols.measurement_key_objs(op) + ) - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: """Returns the set of all measurement keys in this circuit. - Returns: AbstractSet of `cirq.MeasurementKey` objects that are + Returns: FrozenSet of `cirq.MeasurementKey` objects that are in this circuit. """ return self.all_measurement_key_objs() - def all_measurement_key_names(self) -> AbstractSet[str]: + def all_measurement_key_names(self) -> FrozenSet[str]: """Returns the set of all measurement key names in this circuit. - Returns: AbstractSet of strings that are the measurement key + Returns: FrozenSet of strings that are the measurement key names in this circuit. """ - return {key for op in self.all_operations() for key in protocols.measurement_key_names(op)} + return frozenset( + key for op in self.all_operations() for key in protocols.measurement_key_names(op) + ) - def _measurement_key_names_(self) -> AbstractSet[str]: + def _measurement_key_names_(self) -> FrozenSet[str]: return self.all_measurement_key_names() def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index bd1c104d998..46036798d9b 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -19,7 +19,6 @@ """ import math from typing import ( - AbstractSet, Callable, Mapping, Sequence, @@ -309,30 +308,32 @@ def _ensure_deterministic_loop_count(self): raise ValueError('Cannot unroll circuit due to nondeterministic repetitions') @cached_property - def _measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']: circuit_keys = protocols.measurement_key_objs(self.circuit) if circuit_keys and self.use_repetition_ids: self._ensure_deterministic_loop_count() if self.repetition_ids is not None: - circuit_keys = { + circuit_keys = frozenset( key.with_key_path_prefix(repetition_id) for repetition_id in self.repetition_ids for key in circuit_keys - } - circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys} - return { + ) + circuit_keys = frozenset( + key.with_key_path_prefix(*self.parent_path) for key in circuit_keys + ) + return frozenset( protocols.with_measurement_key_mapping(key, dict(self.measurement_key_map)) for key in circuit_keys - } + ) - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: return self._measurement_key_objs - def _measurement_key_names_(self) -> AbstractSet[str]: - return {str(key) for key in self._measurement_key_objs_()} + def _measurement_key_names_(self) -> FrozenSet[str]: + return frozenset(str(key) for key in self._measurement_key_objs_()) @cached_property - def _control_keys(self) -> AbstractSet['cirq.MeasurementKey']: + def _control_keys(self) -> FrozenSet['cirq.MeasurementKey']: keys = ( frozenset() if not protocols.control_keys(self.circuit) @@ -342,13 +343,13 @@ def _control_keys(self) -> AbstractSet['cirq.MeasurementKey']: keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_() return keys - def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: + def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: return self._control_keys def _is_parameterized_(self) -> bool: return any(self._parameter_names_generator()) - def _parameter_names_(self) -> AbstractSet[str]: + def _parameter_names_(self) -> FrozenSet[str]: return frozenset(self._parameter_names_generator()) def _parameter_names_generator(self) -> Iterator[str]: @@ -463,7 +464,7 @@ def __str__(self): ) args = [] - def dict_str(d: Dict) -> str: + def dict_str(d: Mapping) -> str: pairs = [f'{k}: {v}' for k, v in sorted(d.items())] return '{' + ', '.join(pairs) + '}' diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index b6cb08bada0..75d56a9c3e9 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -12,27 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. """An immutable version of the Circuit data structure.""" -from typing import ( - TYPE_CHECKING, - AbstractSet, - FrozenSet, - Iterable, - Iterator, - Optional, - Sequence, - Tuple, - Union, -) +from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Optional, Sequence, Tuple, Union +import numpy as np + +from cirq import ops, protocols from cirq.circuits import AbstractCircuit, Alignment, Circuit from cirq.circuits.insert_strategy import InsertStrategy from cirq.type_workarounds import NotImplementedType - -import numpy as np - from cirq import ops, protocols, _compat - if TYPE_CHECKING: import cirq @@ -70,7 +59,7 @@ def __init__( self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None self._all_operations: Optional[Tuple[ops.Operation, ...]] = None self._has_measurements: Optional[bool] = None - self._all_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = None + self._all_measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None self._are_all_measurements_terminal: Optional[bool] = None self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None @@ -118,12 +107,12 @@ def has_measurements(self) -> bool: self._has_measurements = super().has_measurements() return self._has_measurements - def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: + def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']: if self._all_measurement_key_objs is None: self._all_measurement_key_objs = super().all_measurement_key_objs() return self._all_measurement_key_objs - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: return self.all_measurement_key_objs() def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: @@ -138,10 +127,10 @@ def are_all_measurements_terminal(self) -> bool: # End of memoized methods. - def all_measurement_key_names(self) -> AbstractSet[str]: - return {str(key) for key in self.all_measurement_key_objs()} + def all_measurement_key_names(self) -> FrozenSet[str]: + return frozenset(str(key) for key in self.all_measurement_key_objs()) - def _measurement_key_names_(self) -> AbstractSet[str]: + def _measurement_key_names_(self) -> FrozenSet[str]: return self.all_measurement_key_names() def __add__(self, other) -> 'cirq.FrozenCircuit': diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index f2179a7c635..8fe7cf35c91 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -16,7 +16,6 @@ import itertools from typing import ( - AbstractSet, Any, Callable, Dict, @@ -238,8 +237,8 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): for op in self.operations ) - def _measurement_key_names_(self) -> AbstractSet[str]: - return {str(key) for key in self._measurement_key_objs_()} + def _measurement_key_names_(self) -> FrozenSet[str]: + return frozenset(str(key) for key in self._measurement_key_objs_()) def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: if self._measurement_key_objs is None: diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 95e69be6049..f17612b2c49 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -235,7 +235,7 @@ def _measurement_key_name_(self) -> Optional[str]: return getter() return NotImplemented - def _measurement_key_names_(self) -> Optional[AbstractSet[str]]: + def _measurement_key_names_(self) -> Union[FrozenSet[str], NotImplementedType, None]: getter = getattr(self.gate, '_measurement_key_names_', None) if getter is not None: return getter() @@ -247,7 +247,9 @@ def _measurement_key_obj_(self) -> Optional['cirq.MeasurementKey']: return getter() return NotImplemented - def _measurement_key_objs_(self) -> Optional[AbstractSet['cirq.MeasurementKey']]: + def _measurement_key_objs_( + self, + ) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]: getter = getattr(self.gate, '_measurement_key_objs_', None) if getter is not None: return getter() diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 63955a36b86..22b1c6b4839 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -820,10 +820,10 @@ def _has_kraus_(self) -> bool: def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]: return protocols.kraus(self.sub_operation, NotImplemented) - def _measurement_key_names_(self) -> AbstractSet[str]: + def _measurement_key_names_(self) -> FrozenSet[str]: return protocols.measurement_key_names(self.sub_operation) - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: return protocols.measurement_key_objs(self.sub_operation) def _is_measurement_(self) -> bool: @@ -905,7 +905,7 @@ def with_classical_controls( return self return self.sub_operation.with_classical_controls(*conditions) - def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: + def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: return protocols.control_keys(self.sub_operation) diff --git a/cirq-core/cirq/protocols/control_key_protocol.py b/cirq-core/cirq/protocols/control_key_protocol.py index ef39362eed7..f8897734918 100644 --- a/cirq-core/cirq/protocols/control_key_protocol.py +++ b/cirq-core/cirq/protocols/control_key_protocol.py @@ -13,12 +13,14 @@ # limitations under the License. """Protocol for object that have control keys.""" -from typing import AbstractSet, Any, Iterable, TYPE_CHECKING +from typing import Any, FrozenSet, TYPE_CHECKING, Union from typing_extensions import Protocol +from cirq import _compat from cirq._doc import doc_private from cirq.protocols import measurement_key_protocol +from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: import cirq @@ -34,7 +36,7 @@ class SupportsControlKey(Protocol): """ @doc_private - def _control_keys_(self) -> Iterable['cirq.MeasurementKey']: + def _control_keys_(self) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]: """Return the keys for controls referenced by the receiving object. Returns: @@ -43,7 +45,7 @@ def _control_keys_(self) -> Iterable['cirq.MeasurementKey']: """ -def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']: +def control_keys(val: Any) -> FrozenSet['cirq.MeasurementKey']: """Gets the keys that the value is classically controlled by. Args: @@ -56,12 +58,18 @@ def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']: getter = getattr(val, '_control_keys_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return set(result) + if not isinstance(result, FrozenSet): + _compat._warn_or_error( + f'The _control_keys_ implementation of {type(val)} must return a' + f' frozenset instead of {type(result)} by v0.16.' + ) + return frozenset(result) + return result - return set() + return frozenset() -def measurement_keys_touched(val: Any) -> AbstractSet['cirq.MeasurementKey']: +def measurement_keys_touched(val: Any) -> FrozenSet['cirq.MeasurementKey']: """Returns all the measurement keys used by the value. This would be the case if the value is or contains a measurement gate, or diff --git a/cirq-core/cirq/protocols/control_key_protocol_test.py b/cirq-core/cirq/protocols/control_key_protocol_test.py index 7abee9cea42..4b72aecf277 100644 --- a/cirq-core/cirq/protocols/control_key_protocol_test.py +++ b/cirq-core/cirq/protocols/control_key_protocol_test.py @@ -18,7 +18,7 @@ def test_control_key(): class Named: def _control_keys_(self): - return [cirq.MeasurementKey('key')] + return frozenset([cirq.MeasurementKey('key')]) class NoImpl: def _control_keys_(self): @@ -27,3 +27,12 @@ def _control_keys_(self): assert cirq.control_keys(Named()) == {cirq.MeasurementKey('key')} assert not cirq.control_keys(NoImpl()) assert not cirq.control_keys(5) + + +def test_control_key_enumerable_deprecated(): + class Deprecated: + def _control_keys_(self): + return [cirq.MeasurementKey('key')] + + with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'): + assert cirq.control_keys(Deprecated()) == {cirq.MeasurementKey('key')} diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index ac7fb637a13..e1164298361 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -13,12 +13,13 @@ # limitations under the License. """Protocol for object that have measurement keys.""" -from typing import AbstractSet, Any, Dict, FrozenSet, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, FrozenSet, Optional, Tuple, TYPE_CHECKING, Union from typing_extensions import Protocol -from cirq import value +from cirq import value, _compat from cirq._doc import doc_private +from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: import cirq @@ -68,7 +69,9 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': """ @doc_private - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_( + self, + ) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]: """Return the key objects for measurements performed by the receiving object. When a measurement occurs, either on hardware, or in a simulation, @@ -86,7 +89,7 @@ def _measurement_key_name_(self) -> str: """ @doc_private - def _measurement_key_names_(self) -> AbstractSet[str]: + def _measurement_key_names_(self) -> Union[FrozenSet[str], NotImplementedType, None]: """Return the string keys for measurements performed by the receiving object. When a measurement occurs, either on hardware, or in a simulation, @@ -172,39 +175,53 @@ def measurement_key_name(val: Any, default: Any = RaiseTypeErrorIfNotProvided): def _measurement_key_objs_from_magic_methods( val: Any, -) -> Optional[AbstractSet['cirq.MeasurementKey']]: +) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]: """Uses the measurement key related magic methods to get the `MeasurementKey`s for this object.""" getter = getattr(val, '_measurement_key_objs_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return set(result) + if not isinstance(result, FrozenSet): + _compat._warn_or_error( + f'The _measurement_key_objs_ implementation of {type(val)} must return a' + f' frozenset instead of {type(result)} by v0.16.' + ) + return frozenset(result) + return result getter = getattr(val, '_measurement_key_obj_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return {result} + return frozenset([result]) return result -def _measurement_key_names_from_magic_methods(val: Any) -> Optional[AbstractSet[str]]: +def _measurement_key_names_from_magic_methods( + val: Any, +) -> Union[FrozenSet[str], NotImplementedType, None]: """Uses the measurement key related magic methods to get the key strings for this object.""" getter = getattr(val, '_measurement_key_names_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return set(result) + if not isinstance(result, FrozenSet): + _compat._warn_or_error( + f'The _measurement_key_names_ implementation of {type(val)} must return a' + f' frozenset instead of {type(result)} by v0.16.' + ) + return frozenset(result) + return result getter = getattr(val, '_measurement_key_name_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return {result} + return frozenset([result]) return result -def measurement_key_objs(val: Any) -> AbstractSet['cirq.MeasurementKey']: +def measurement_key_objs(val: Any) -> FrozenSet['cirq.MeasurementKey']: """Gets the measurement key objects of measurements within the given value. Args: @@ -219,11 +236,11 @@ def measurement_key_objs(val: Any) -> AbstractSet['cirq.MeasurementKey']: return result key_strings = _measurement_key_names_from_magic_methods(val) if key_strings is not NotImplemented and key_strings is not None: - return {value.MeasurementKey.parse_serialized(key_str) for key_str in key_strings} - return set() + return frozenset(value.MeasurementKey.parse_serialized(key_str) for key_str in key_strings) + return frozenset() -def measurement_key_names(val: Any) -> AbstractSet[str]: +def measurement_key_names(val: Any) -> FrozenSet[str]: """Gets the measurement key strings of measurements within the given value. Args: @@ -244,8 +261,8 @@ def measurement_key_names(val: Any) -> AbstractSet[str]: return result key_objs = _measurement_key_objs_from_magic_methods(val) if key_objs is not NotImplemented and key_objs is not None: - return {str(key_obj) for key_obj in key_objs} - return set() + return frozenset(str(key_obj) for key_obj in key_objs) + return frozenset() def _is_measurement_from_magic_method(val: Any) -> Optional[bool]: diff --git a/cirq-core/cirq/protocols/measurement_key_protocol_test.py b/cirq-core/cirq/protocols/measurement_key_protocol_test.py index 38398e9b880..9b24681ba26 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol_test.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol_test.py @@ -158,10 +158,10 @@ def num_qubits(self) -> int: def test_measurement_keys(key_method, keys): class MeasurementKeysGate(cirq.Gate): def _measurement_key_names_(self): - return ['a', 'b'] + return frozenset(['a', 'b']) def _measurement_key_objs_(self): - return [cirq.MeasurementKey('c'), cirq.MeasurementKey('d')] + return frozenset([cirq.MeasurementKey('c'), cirq.MeasurementKey('d')]) def num_qubits(self) -> int: return 1 @@ -183,7 +183,7 @@ def num_qubits(self) -> int: def test_measurement_key_mapping(): class MultiKeyGate: def __init__(self, keys): - self._keys = set(keys) + self._keys = frozenset(keys) def _measurement_key_names_(self): return self._keys @@ -220,10 +220,10 @@ def _with_measurement_key_mapping_(self, key_map): def test_measurement_key_path(): class MultiKeyGate: def __init__(self, keys): - self._keys = set([cirq.MeasurementKey.parse_serialized(key) for key in keys]) + self._keys = frozenset(cirq.MeasurementKey.parse_serialized(key) for key in keys) def _measurement_key_names_(self): - return {str(key) for key in self._keys} + return frozenset(str(key) for key in self._keys) def _with_key_path_(self, path): return MultiKeyGate([str(key._with_key_path_(path)) for key in self._keys]) @@ -238,3 +238,18 @@ def _with_key_path_(self, path): assert cirq.measurement_key_names(mkg_cd) == {'c:d:a', 'c:d:b'} assert cirq.with_key_path(cirq.X, ('c', 'd')) is NotImplemented + + +def test_measurement_key_enumerable_deprecated(): + class Deprecated: + def _measurement_key_objs_(self): + return [cirq.MeasurementKey('key')] + + def _measurement_key_names_(self): + return ['key'] + + with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'): + assert cirq.measurement_key_objs(Deprecated()) == {cirq.MeasurementKey('key')} + + with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'): + assert cirq.measurement_key_names(Deprecated()) == {'key'}