Skip to content

Commit 7d3bf5e

Browse files
authored
Add MatrixGate.with_name method. (#5941)
Fixes #5938
1 parent 6fba8be commit 7d3bf5e

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

cirq-core/cirq/ops/matrix_gates.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
*,
5656
name: str = None,
5757
qid_shape: Optional[Iterable[int]] = None,
58+
unitary_check: bool = True,
5859
unitary_check_rtol: float = 1e-5,
5960
unitary_check_atol: float = 1e-8,
6061
) -> None:
@@ -66,6 +67,10 @@ def __init__(
6667
qid_shape: The shape of state tensor that the matrix applies to.
6768
If not specified, this value is inferred by assuming that the
6869
matrix is supposed to apply to qubits.
70+
unitary_check: If True, check that the supplied matrix is unitary up to the
71+
given tolerances. This should only be disabled if the matrix has already been
72+
checked for unitarity, in which case we get a slight performance improvement by
73+
not checking again.
6974
unitary_check_rtol: The relative tolerance for checking whether the supplied matrix
7075
is unitary. See `cirq.is_unitary`.
7176
unitary_check_atol: The absolute tolerance for checking whether the supplied matrix
@@ -99,8 +104,14 @@ def __init__(
99104
f'qid_shape: {self._qid_shape}\n'
100105
)
101106

102-
if not linalg.is_unitary(matrix, rtol=unitary_check_rtol, atol=unitary_check_atol):
103-
raise ValueError(f'Not a unitary matrix: {self._matrix}')
107+
if unitary_check and not linalg.is_unitary(
108+
matrix, rtol=unitary_check_rtol, atol=unitary_check_atol
109+
):
110+
raise ValueError(f'Not a unitary matrix: {matrix}')
111+
112+
def with_name(self, name: str) -> 'MatrixGate':
113+
"""Creates a new MatrixGate with the same matrix and a new name."""
114+
return MatrixGate(self._matrix, name=name, qid_shape=self._qid_shape, unitary_check=False)
104115

105116
def _json_dict_(self) -> Dict[str, Any]:
106117
return {'matrix': self._matrix.tolist(), 'qid_shape': self._qid_shape}

cirq-core/cirq/ops/matrix_gates_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,16 @@ def test_named_two_qubit_diagram():
271271
assert expected_vertical == c.to_text_diagram(transpose=True).strip()
272272

273273

274+
def test_with_name():
275+
gate = cirq.MatrixGate(cirq.unitary(cirq.Z**0.25))
276+
T = gate.with_name('T')
277+
S = (T**2).with_name('S')
278+
assert T._name == 'T'
279+
np.testing.assert_allclose(cirq.unitary(T), cirq.unitary(gate))
280+
assert S._name == 'S'
281+
np.testing.assert_allclose(cirq.unitary(S), cirq.unitary(T**2))
282+
283+
274284
def test_str_executes():
275285
assert '1' in str(cirq.MatrixGate(np.eye(2)))
276286
assert '0' in str(cirq.MatrixGate(np.eye(4)))

0 commit comments

Comments
 (0)