Skip to content

Commit 80a63b9

Browse files
authored
Fix np.einsum type annotations (#7184)
* Clean redundant type-ignores in np.einsum calls Fixes #5757 * Clean up obsolete comments
1 parent c306eaf commit 80a63b9

File tree

3 files changed

+4
-10
lines changed

3 files changed

+4
-10
lines changed

cirq-core/cirq/experiments/readout_confusion_matrix.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,12 @@ def _get_vars(self, qubit_pattern: Sequence['cirq.Qid']) -> List[int]:
173173
return in_vars + out_vars
174174

175175
def _confusion_matrix(self, qubits: Sequence['cirq.Qid']) -> np.ndarray:
176-
ein_input = []
176+
ein_input: List[np.ndarray | List[int]] = []
177177
for qs, cm in zip(self.measure_qubits, self.confusion_matrices):
178178
ein_input.extend([cm.reshape((2, 2) * len(qs)), self._get_vars(qs)])
179179
ein_out = self._get_vars(qubits)
180180

181-
# TODO(#5757): remove type ignore when numpy has proper override signature.
182-
ret = np.einsum(*ein_input, ein_out).reshape((2 ** len(qubits),) * 2) # type: ignore
181+
ret = np.einsum(*ein_input, ein_out).reshape((2 ** len(qubits),) * 2)
183182
return ret / ret.sum(axis=1)
184183

185184
def confusion_matrix(self, qubits: Optional[Sequence['cirq.Qid']] = None) -> np.ndarray:

cirq-core/cirq/linalg/transformations.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ def targeted_left_multiply(
153153

154154
all_indices = set(input_indices + data_indices + tuple(output_indices))
155155

156-
# TODO(#5757): remove type ignore when numpy has proper override signature.
157156
return np.einsum(
158157
left_matrix,
159158
input_indices,
@@ -164,10 +163,8 @@ def targeted_left_multiply(
164163
# but this is a workaround for a bug in numpy:
165164
# https://github.com/numpy/numpy/issues/10926
166165
optimize=len(all_indices) >= 26,
167-
# And this is workaround for *another* bug!
168-
# Supposed to be able to just say 'old=old'.
169-
**({'out': out} if out is not None else {}),
170-
) # type: ignore
166+
out=out,
167+
)
171168

172169

173170
@dataclasses.dataclass
@@ -412,7 +409,6 @@ def partial_trace(tensor: np.ndarray, keep_indices: Sequence[int]) -> np.ndarray
412409
keep_map = dict(zip(keep_indices, sorted(keep_indices)))
413410
left_indices = [keep_map[i] if i in keep_set else i for i in range(ndim)]
414411
right_indices = [ndim + i if i in keep_set else i for i in left_indices]
415-
# TODO(#5757): remove type ignore when numpy has proper override signature.
416412
return np.einsum(tensor, left_indices + right_indices)
417413

418414

cirq-core/cirq/qis/states.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,6 @@ def density_matrix_from_state_vector(
677677
sum_inds = np.array(range(n_qubits))
678678
sum_inds[indices] += n_qubits
679679

680-
# TODO(#5757): remove type ignore when numpy has proper override signature.
681680
rho = np.einsum(
682681
state_vector,
683682
list(range(n_qubits)),

0 commit comments

Comments
 (0)