Skip to content

Commit 435f3e1

Browse files
authored
Support numpy qid indexes (#6987)
* Support numpy qid indexes * remove if check on _BaseLineQid.__hash__
1 parent 79d562a commit 435f3e1

File tree

4 files changed

+53
-6
lines changed

4 files changed

+53
-6
lines changed

cirq-core/cirq/devices/grid_qubit.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,9 @@ class _BaseGridQid(ops.Qid):
3434
_col: int
3535
_dimension: int
3636
_comp_key: Optional[Tuple[int, int]] = None
37-
_hash: Optional[int] = None
37+
_hash: int
3838

3939
def __hash__(self) -> int:
40-
if self._hash is None:
41-
self._hash = ((self._dimension - 2) * 1_000_003 + self._col) * 1_000_003 + self._row
4240
return self._hash
4341

4442
def __eq__(self, other) -> bool:
@@ -215,6 +213,9 @@ def __new__(cls, row: int, col: int, *, dimension: int) -> 'cirq.GridQid':
215213
dimension: The dimension of the qid's Hilbert space, i.e.
216214
the number of quantum levels.
217215
"""
216+
row = int(row)
217+
col = int(col)
218+
dimension = int(dimension)
218219
key = (row, col, dimension)
219220
inst = cls._cache.get(key)
220221
if inst is None:
@@ -223,6 +224,7 @@ def __new__(cls, row: int, col: int, *, dimension: int) -> 'cirq.GridQid':
223224
inst._row = row
224225
inst._col = col
225226
inst._dimension = dimension
227+
inst._hash = ((dimension - 2) * 1_000_003 + col) * 1_000_003 + row
226228
cls._cache[key] = inst
227229
return inst
228230

@@ -378,12 +380,15 @@ def __new__(cls, row: int, col: int) -> 'cirq.GridQubit':
378380
row: the row coordinate
379381
col: the column coordinate
380382
"""
383+
row = int(row)
384+
col = int(col)
381385
key = (row, col)
382386
inst = cls._cache.get(key)
383387
if inst is None:
384388
inst = super().__new__(cls)
385389
inst._row = row
386390
inst._col = col
391+
inst._hash = col * 1_000_003 + row
387392
cls._cache[key] = inst
388393
return inst
389394

cirq-core/cirq/devices/grid_qubit_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,24 @@ def test_immutable():
391391
def test_complex():
392392
assert complex(cirq.GridQubit(row=1, col=2)) == 2 + 1j
393393
assert isinstance(complex(cirq.GridQubit(row=1, col=2)), complex)
394+
395+
396+
def test_numpy_index():
397+
np5, np6, np3 = [np.int64(i) for i in [5, 6, 3]]
398+
q = cirq.GridQubit(np5, np6)
399+
hash(q) # doesn't throw
400+
assert q.row == 5
401+
assert q.col == 6
402+
assert q.dimension == 2
403+
assert isinstance(q.row, int)
404+
assert isinstance(q.col, int)
405+
assert isinstance(q.dimension, int)
406+
407+
q = cirq.GridQid(np5, np6, dimension=np3)
408+
hash(q) # doesn't throw
409+
assert q.row == 5
410+
assert q.col == 6
411+
assert q.dimension == 3
412+
assert isinstance(q.row, int)
413+
assert isinstance(q.col, int)
414+
assert isinstance(q.dimension, int)

cirq-core/cirq/devices/line_qubit.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@ class _BaseLineQid(ops.Qid):
3030

3131
_x: int
3232
_dimension: int
33-
_hash: Optional[int] = None
33+
_hash: int
3434

3535
def __hash__(self) -> int:
36-
if self._hash is None:
37-
self._hash = (self._dimension - 2) * 1_000_003 + self._x
3836
return self._hash
3937

4038
def __eq__(self, other) -> bool:
@@ -193,13 +191,16 @@ def __new__(cls, x: int, dimension: int) -> 'cirq.LineQid':
193191
dimension: The dimension of the qid's Hilbert space, i.e.
194192
the number of quantum levels.
195193
"""
194+
x = int(x)
195+
dimension = int(dimension)
196196
key = (x, dimension)
197197
inst = cls._cache.get(key)
198198
if inst is None:
199199
cls.validate_dimension(dimension)
200200
inst = super().__new__(cls)
201201
inst._x = x
202202
inst._dimension = dimension
203+
inst._hash = (dimension - 2) * 1_000_003 + x
203204
cls._cache[key] = inst
204205
return inst
205206

@@ -301,10 +302,12 @@ def __new__(cls, x: int) -> 'cirq.LineQubit':
301302
Args:
302303
x: The x coordinate.
303304
"""
305+
x = int(x)
304306
inst = cls._cache.get(x)
305307
if inst is None:
306308
inst = super().__new__(cls)
307309
inst._x = x
310+
inst._hash = x
308311
cls._cache[x] = inst
309312
return inst
310313

cirq-core/cirq/devices/line_qubit_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numpy as np
1516
import pytest
1617

1718
import cirq
@@ -284,3 +285,20 @@ def test_numeric():
284285
assert isinstance(int(cirq.LineQubit(x=5)), int)
285286
assert isinstance(float(cirq.LineQubit(x=5)), float)
286287
assert isinstance(complex(cirq.LineQubit(x=5)), complex)
288+
289+
290+
def test_numpy_index():
291+
np5 = np.int64(5)
292+
q = cirq.LineQubit(np5)
293+
assert hash(q) == 5
294+
assert q.x == 5
295+
assert q.dimension == 2
296+
assert isinstance(q.x, int)
297+
assert isinstance(q.dimension, int)
298+
299+
q = cirq.LineQid(np5, np.int64(3))
300+
hash(q) # doesn't throw
301+
assert q.x == 5
302+
assert q.dimension == 3
303+
assert isinstance(q.x, int)
304+
assert isinstance(q.dimension, int)

0 commit comments

Comments
 (0)