Skip to content

Commit 5377fff

Browse files
authored
Test hash consistency of cirq objects loaded from a pickle (#6677)
Ensure that cirq objects passed as pickles in multiprocessing calls hash consistently, so that they behave in dictionaries just like objects created in-process. Partially resolves #6674.
1 parent 51e8c3d commit 5377fff

21 files changed

+261
-14
lines changed

cirq-core/cirq/circuits/circuit_operation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import math
2222
from functools import cached_property
2323
from typing import (
24+
Any,
2425
Callable,
2526
cast,
2627
Dict,
@@ -508,6 +509,16 @@ def _hash(self) -> int:
508509
def __hash__(self) -> int:
509510
return self._hash
510511

512+
def __getstate__(self) -> Dict[str, Any]:
513+
# clear cached hash value when pickling, see #6674
514+
state = self.__dict__
515+
# cached_property stores value in the property-named attribute
516+
hash_attr = "_hash"
517+
if hash_attr in state:
518+
state = state.copy()
519+
del state[hash_attr]
520+
return state
521+
511522
def _json_dict_(self):
512523
resp = {
513524
'circuit': self.circuit,

cirq-core/cirq/circuits/frozen_circuit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ def __eq__(self, other):
119119
def __getstate__(self):
120120
# Don't save hash when pickling; see #3777.
121121
state = self.__dict__
122-
hash_cache = _compat._method_cache_name(self.__hash__)
123-
if hash_cache in state:
122+
hash_attr = _compat._method_cache_name(self.__hash__)
123+
if hash_attr in state:
124124
state = state.copy()
125-
del state[hash_cache]
125+
del state[hash_attr]
126126
return state
127127

128128
@_compat.cached_method

cirq-core/cirq/circuits/moment.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,15 @@ def __ne__(self, other) -> bool:
365365
def __hash__(self):
366366
return hash((Moment, self._sorted_operations_()))
367367

368+
def __getstate__(self) -> Dict[str, Any]:
369+
# clear cached hash value when pickling, see #6674
370+
state = self.__dict__
371+
hash_attr = _compat._method_cache_name(self.__hash__)
372+
if hash_attr in state:
373+
state = state.copy()
374+
del state[hash_attr]
375+
return state
376+
368377
def __iter__(self) -> Iterator['cirq.Operation']:
369378
return iter(self.operations)
370379

cirq-core/cirq/devices/grid_qubit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,10 @@ def __getnewargs_ex__(self):
230230
"""Returns a tuple of (args, kwargs) to pass to __new__ when unpickling."""
231231
return (self._row, self._col), {"dimension": self._dimension}
232232

233+
# avoid pickling the _hash value, attributes are already stored with __getnewargs_ex__
234+
def __getstate__(self) -> Dict[str, Any]:
235+
return {}
236+
233237
def _with_row_col(self, row: int, col: int) -> 'GridQid':
234238
return GridQid(row, col, dimension=self._dimension)
235239

@@ -387,6 +391,10 @@ def __getnewargs__(self):
387391
"""Returns a tuple of args to pass to __new__ when unpickling."""
388392
return (self._row, self._col)
389393

394+
# avoid pickling the _hash value, attributes are already stored with __getnewargs__
395+
def __getstate__(self) -> Dict[str, Any]:
396+
return {}
397+
390398
def _with_row_col(self, row: int, col: int) -> 'GridQubit':
391399
return GridQubit(row, col)
392400

cirq-core/cirq/devices/line_qubit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ def __getnewargs__(self):
207207
"""Returns a tuple of args to pass to __new__ when unpickling."""
208208
return (self._x, self._dimension)
209209

210+
# avoid pickling the _hash value, attributes are already stored with __getnewargs__
211+
def __getstate__(self) -> Dict[str, Any]:
212+
return {}
213+
210214
def _with_x(self, x: int) -> 'LineQid':
211215
return LineQid(x, dimension=self._dimension)
212216

@@ -308,6 +312,10 @@ def __getnewargs__(self):
308312
"""Returns a tuple of args to pass to __new__ when unpickling."""
309313
return (self._x,)
310314

315+
# avoid pickling the _hash value, attributes are already stored with __getnewargs__
316+
def __getstate__(self) -> Dict[str, Any]:
317+
return {}
318+
311319
def _with_x(self, x: int) -> 'LineQubit':
312320
return LineQubit(x)
313321

cirq-core/cirq/ops/boolean_hamiltonian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _qid_shape_(self) -> Tuple[int, ...]:
8080
return (2,) * len(self._parameter_names)
8181

8282
def _value_equality_values_(self) -> Any:
83-
return self._parameter_names, self._boolean_strs, self._theta
83+
return tuple(self._parameter_names), tuple(self._boolean_strs), self._theta
8484

8585
def _json_dict_(self) -> Dict[str, Any]:
8686
return {

cirq-core/cirq/ops/common_channels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _has_mixture_(self) -> bool:
131131
return True
132132

133133
def _value_equality_values_(self):
134-
return self._num_qubits, hash(tuple(sorted(self._error_probabilities.items())))
134+
return self._num_qubits, tuple(sorted(self._error_probabilities.items()))
135135

136136
def __repr__(self) -> str:
137137
return 'cirq.asymmetric_depolarize(' + f"error_probabilities={self._error_probabilities})"

cirq-core/cirq/ops/named_qubit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ def __getnewargs__(self):
134134
"""Returns a tuple of args to pass to __new__ when unpickling."""
135135
return (self._name, self._dimension)
136136

137+
# avoid pickling the _hash value, attributes are already stored with __getnewargs__
138+
def __getstate__(self) -> Dict[str, Any]:
139+
return {}
140+
137141
def __repr__(self) -> str:
138142
return f'cirq.NamedQid({self._name!r}, dimension={self._dimension})'
139143

@@ -202,6 +206,10 @@ def __getnewargs__(self):
202206
"""Returns a tuple of args to pass to __new__ when unpickling."""
203207
return (self._name,)
204208

209+
# avoid pickling the _hash value, attributes are already stored with __getnewargs__
210+
def __getstate__(self) -> Dict[str, Any]:
211+
return {}
212+
205213
def __str__(self) -> str:
206214
return self._name
207215

cirq-core/cirq/ops/raw_types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
from cirq import protocols, value
4343
from cirq._import import LazyLoader
44-
from cirq._compat import __cirq_debug__, cached_method
44+
from cirq._compat import __cirq_debug__, _method_cache_name, cached_method
4545
from cirq.type_workarounds import NotImplementedType
4646
from cirq.ops import control_values as cv
4747

@@ -115,6 +115,15 @@ def _cmp_tuple(self):
115115
def __hash__(self) -> int:
116116
return hash((Qid, self._comparison_key()))
117117

118+
def __getstate__(self) -> Dict[str, Any]:
119+
# clear cached hash value when pickling, see #6674
120+
state = self.__dict__
121+
hash_attr = _method_cache_name(self.__hash__)
122+
if hash_attr in state:
123+
state = state.copy()
124+
del state[hash_attr]
125+
return state
126+
118127
def __eq__(self, other):
119128
if not isinstance(other, Qid):
120129
return NotImplemented
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import multiprocessing
18+
import os
19+
import pathlib
20+
import pickle
21+
from collections.abc import Iterator
22+
from typing import Any, Hashable
23+
24+
import pytest
25+
26+
import cirq
27+
from cirq.protocols.json_serialization_test import MODULE_TEST_SPECS
28+
29+
_EXCLUDE_JSON_FILES = (
30+
# sympy-related objects
31+
"cirq/protocols/json_test_data/sympy.Add.json",
32+
"cirq/protocols/json_test_data/sympy.E.json",
33+
"cirq/protocols/json_test_data/sympy.Equality.json",
34+
"cirq/protocols/json_test_data/sympy.EulerGamma.json",
35+
"cirq/protocols/json_test_data/sympy.Float.json",
36+
"cirq/protocols/json_test_data/sympy.GreaterThan.json",
37+
"cirq/protocols/json_test_data/sympy.Integer.json",
38+
"cirq/protocols/json_test_data/sympy.LessThan.json",
39+
"cirq/protocols/json_test_data/sympy.Mul.json",
40+
"cirq/protocols/json_test_data/sympy.Pow.json",
41+
"cirq/protocols/json_test_data/sympy.Rational.json",
42+
"cirq/protocols/json_test_data/sympy.StrictGreaterThan.json",
43+
"cirq/protocols/json_test_data/sympy.StrictLessThan.json",
44+
"cirq/protocols/json_test_data/sympy.Symbol.json",
45+
"cirq/protocols/json_test_data/sympy.Unequality.json",
46+
"cirq/protocols/json_test_data/sympy.pi.json",
47+
# RigettiQCSAspenDevice does not pickle
48+
"cirq_rigetti/json_test_data/RigettiQCSAspenDevice.json",
49+
# TODO(#6674,pavoljuhas) - fix pickling of ProjectorSum
50+
"cirq/protocols/json_test_data/ProjectorSum.json",
51+
)
52+
53+
54+
def _is_included(json_filename: str) -> bool:
55+
json_posix_path = pathlib.PurePath(json_filename).as_posix()
56+
if any(json_posix_path.endswith(t) for t in _EXCLUDE_JSON_FILES):
57+
return False
58+
if not os.path.isfile(json_filename):
59+
return False
60+
return True
61+
62+
63+
@pytest.fixture(scope='module')
64+
def pool() -> Iterator[multiprocessing.pool.Pool]:
65+
ctx = multiprocessing.get_context("spawn")
66+
with ctx.Pool(1) as pool:
67+
yield pool
68+
69+
70+
def _read_json(json_filename: str) -> Any:
71+
obj = cirq.read_json(json_filename)
72+
obj = obj[0] if isinstance(obj, list) else obj
73+
# trigger possible caching of the hash value
74+
if isinstance(obj, Hashable):
75+
_ = hash(obj)
76+
return obj
77+
78+
79+
def test_exclude_json_files_has_valid_entries() -> None:
80+
"""Verify _EXCLUDE_JSON_FILES has valid entries."""
81+
# do not check rigetti files if not installed
82+
skip_rigetti = all(m.name != "cirq_rigetti" for m in MODULE_TEST_SPECS)
83+
json_file_validates = lambda f: any(
84+
m.test_data_path.joinpath(os.path.basename(f)).is_file() for m in MODULE_TEST_SPECS
85+
) or (skip_rigetti and f.startswith("cirq_rigetti/"))
86+
invalid_json_paths = [f for f in _EXCLUDE_JSON_FILES if not json_file_validates(f)]
87+
assert invalid_json_paths == []
88+
89+
90+
@pytest.mark.parametrize(
91+
'json_filename',
92+
[
93+
f"{abs_path}.json"
94+
for m in MODULE_TEST_SPECS
95+
for abs_path in m.all_test_data_keys()
96+
if _is_included(f"{abs_path}.json")
97+
],
98+
)
99+
def test_hash_from_pickle(json_filename: str, pool: multiprocessing.pool.Pool):
100+
obj_local = _read_json(json_filename)
101+
if not isinstance(obj_local, Hashable):
102+
return
103+
# check if pickling works in the main process for the sake of debugging
104+
obj_copy = pickle.loads(pickle.dumps(obj_local))
105+
assert obj_copy == obj_local
106+
assert hash(obj_copy) == hash(obj_local)
107+
# Read and hash the object in a separate worker process and then
108+
# send it back which requires pickling and unpickling.
109+
obj_worker = pool.apply(_read_json, [json_filename])
110+
assert obj_worker == obj_local
111+
assert hash(obj_worker) == hash(obj_local)

cirq-core/cirq/qis/clifford_tableau.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818

1919
from cirq import protocols
20-
from cirq._compat import proper_repr, cached_method
20+
from cirq._compat import proper_repr, _method_cache_name, cached_method
2121
from cirq.qis import quantum_state_representation
2222
from cirq.value import big_endian_int_to_digits, linear_dict, random_state
2323

@@ -658,3 +658,12 @@ def measure(
658658
@cached_method
659659
def __hash__(self) -> int:
660660
return hash(self.matrix().tobytes() + self.rs.tobytes())
661+
662+
def __getstate__(self) -> Dict[str, Any]:
663+
# clear cached hash value when pickling, see #6674
664+
state = self.__dict__
665+
hash_attr = _method_cache_name(self.__hash__)
666+
if hash_attr in state:
667+
state = state.copy()
668+
del state[hash_attr]
669+
return state

cirq-core/cirq/sim/clifford/clifford_simulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def state(self):
188188
return self._clifford_state
189189

190190

191-
@value.value_equality
191+
@value.value_equality(unhashable=True)
192192
class CliffordState:
193193
"""A state of the Clifford simulation.
194194

cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from cirq.value import big_endian_int_to_digits, random_state
2121

2222

23-
@value.value_equality
23+
@value.value_equality(unhashable=True)
2424
class StabilizerStateChForm(qis.StabilizerState):
2525
r"""A representation of stabilizer states using the CH form,
2626

cirq-core/cirq/study/resolver.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ def __hash__(self) -> int:
244244
self._param_hash = hash(frozenset(self._param_dict.items()))
245245
return self._param_hash
246246

247+
def __getstate__(self) -> Dict[str, Any]:
248+
# clear cached hash value when pickling, see #6674
249+
state = self.__dict__
250+
if state["_param_hash"] is not None:
251+
state = state.copy()
252+
state["_param_hash"] = None
253+
return state
254+
247255
def __eq__(self, other):
248256
if not isinstance(other, ParamResolver):
249257
return NotImplemented

cirq-core/cirq/value/measurement_key.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import FrozenSet, Mapping, Optional, Tuple
15+
from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple
1616

1717
import dataclasses
1818

@@ -77,6 +77,14 @@ def __hash__(self):
7777
object.__setattr__(self, '_hash', hash(str(self)))
7878
return self._hash
7979

80+
def __getstate__(self) -> Dict[str, Any]:
81+
# clear cached hash value when pickling, see #6674
82+
state = self.__dict__
83+
if "_hash" in state:
84+
state = state.copy()
85+
del state["_hash"]
86+
return state
87+
8088
def __lt__(self, other):
8189
if isinstance(other, MeasurementKey):
8290
if self.path != other.path:

cirq-core/cirq/value/value_equality_attr.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Defines `@cirq.value_equality`, for easy __eq__/__hash__ methods."""
1515

16-
from typing import Any, Callable, Optional, overload, Union
16+
from typing import Any, Callable, Dict, Optional, overload, Union
1717

1818
from typing_extensions import Protocol
1919

@@ -110,6 +110,16 @@ def _value_equality_approx_eq(
110110
)
111111

112112

113+
def _value_equality_getstate(self: _SupportsValueEquality) -> Dict[str, Any]:
114+
# clear cached hash value when pickling, see #6674
115+
state = self.__dict__
116+
hash_attr = _compat._method_cache_name(self.__hash__)
117+
if hash_attr in state:
118+
state = state.copy()
119+
del state[hash_attr]
120+
return state
121+
122+
113123
# pylint: disable=function-redefined
114124
@overload
115125
def value_equality(
@@ -228,6 +238,8 @@ class return the existing class' type.
228238
cached_values_getter = values_getter if unhashable else _compat.cached_method(values_getter)
229239
setattr(cls, '_value_equality_values_', cached_values_getter)
230240
setattr(cls, '__hash__', None if unhashable else _compat.cached_method(_value_equality_hash))
241+
if not unhashable:
242+
setattr(cls, '__getstate__', _value_equality_getstate)
231243
setattr(cls, '__eq__', _value_equality_eq)
232244
setattr(cls, '__ne__', _value_equality_ne)
233245

cirq-core/cirq/work/observable_measurement_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ class ObservableMeasuredResult:
109109
repetitions: int
110110
circuit_params: Mapping[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]
111111

112+
# unhashable because of the mapping-type circuit_params attribute
113+
__hash__ = None # type: ignore
114+
112115
def __repr__(self):
113116
# I wish we could use the default dataclass __repr__ but
114117
# we need to prefix our class name with `cirq.work.`

0 commit comments

Comments
 (0)