Skip to content

Commit 00765f0

Browse files
committed
typecheck - support numpy scalar types in LinearDict
1 parent 787a6c2 commit 00765f0

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

cirq-core/cirq/value/linear_dict.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
)
3535
from typing_extensions import Self
3636

37-
Scalar = complex
37+
import numpy as np
38+
39+
Scalar = Union[complex, np.number]
3840
TVector = TypeVar('TVector')
3941

4042
TDefault = TypeVar('TDefault')
@@ -124,7 +126,7 @@ def _check_vector_valid(self, vector: TVector) -> None:
124126

125127
def clean(self, *, atol: float = 1e-9) -> Self:
126128
"""Remove terms with coefficients of absolute value atol or less."""
127-
negligible = [v for v, c in self._terms.items() if abs(c) <= atol]
129+
negligible = [v for v, c in self._terms.items() if abs(complex(c)) <= atol]
128130
for v in negligible:
129131
del self._terms[v]
130132
return self
@@ -245,7 +247,7 @@ def __mul__(self, a: Scalar) -> Self:
245247
result *= a
246248
return result
247249

248-
def __rmul__(self, a: Scalar) -> Self:
250+
def __rmul__(self, a: Scalar) -> Self: # type: ignore
249251
return self.__mul__(a)
250252

251253
def __truediv__(self, a: Scalar) -> Self:

0 commit comments

Comments
 (0)