From 0a963f9778d731dca17dc6e1562cbe724a807f87 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Fri, 4 Nov 2022 10:06:50 -0700 Subject: [PATCH 1/4] Add MatrixGate.with_name method. Fixes #5938 --- cirq-core/cirq/ops/matrix_gates.py | 28 +++++++++++++++++-------- cirq-core/cirq/ops/matrix_gates_test.py | 10 +++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/cirq-core/cirq/ops/matrix_gates.py b/cirq-core/cirq/ops/matrix_gates.py index c5687bbaed7..17a092efc0c 100644 --- a/cirq-core/cirq/ops/matrix_gates.py +++ b/cirq-core/cirq/ops/matrix_gates.py @@ -55,6 +55,7 @@ def __init__( *, name: str = None, qid_shape: Optional[Iterable[int]] = None, + unitary_check: bool = True, unitary_check_rtol: float = 1e-5, unitary_check_atol: float = 1e-8, ) -> None: @@ -66,6 +67,8 @@ def __init__( qid_shape: The shape of state tensor that the matrix applies to. If not specified, this value is inferred by assuming that the matrix is supposed to apply to qubits. + unitary_check: If True, check that the supplied matrix is unitary up to the + given tolerances. unitary_check_rtol: The relative tolerance for checking whether the supplied matrix is unitary. See `cirq.is_unitary`. unitary_check_atol: The absolute tolerance for checking whether the supplied matrix @@ -88,19 +91,26 @@ def __init__( ) qid_shape = (2,) * n - self._matrix = matrix - self._qid_shape = tuple(qid_shape) - self._name = name - m = int(np.prod(self._qid_shape, dtype=np.int64)) - if self._matrix.shape != (m, m): + m = int(np.prod(qid_shape, dtype=np.int64)) + if matrix.shape != (m, m): raise ValueError( 'Wrong matrix shape for qid_shape.\n' - f'Matrix shape: {self._matrix.shape}\n' - f'qid_shape: {self._qid_shape}\n' + f'Matrix shape: {matrix.shape}\n' + f'qid_shape: {qid_shape}\n' ) - if not linalg.is_unitary(matrix, rtol=unitary_check_rtol, atol=unitary_check_atol): - raise ValueError(f'Not a unitary matrix: {self._matrix}') + if unitary_check and not linalg.is_unitary( + matrix, rtol=unitary_check_rtol, atol=unitary_check_atol + ): + raise ValueError(f'Not a unitary matrix: {matrix}') + + self._matrix = matrix + self._qid_shape = tuple(qid_shape) + self._name = name + + def with_name(self, name: str) -> 'MatrixGate': + """Creates a new MatrixGate with the same matrix and a new name.""" + return MatrixGate(self._matrix, name=name, qid_shape=self._qid_shape, unitary_check=False) def _json_dict_(self) -> Dict[str, Any]: return {'matrix': self._matrix.tolist(), 'qid_shape': self._qid_shape} diff --git a/cirq-core/cirq/ops/matrix_gates_test.py b/cirq-core/cirq/ops/matrix_gates_test.py index 0e66df1f4d0..02123d61ea6 100644 --- a/cirq-core/cirq/ops/matrix_gates_test.py +++ b/cirq-core/cirq/ops/matrix_gates_test.py @@ -271,6 +271,16 @@ def test_named_two_qubit_diagram(): assert expected_vertical == c.to_text_diagram(transpose=True).strip() +def test_with_name(): + gate = cirq.MatrixGate(cirq.unitary(cirq.Z ** 0.25)) + T = gate.with_name('T') + S = (T ** 2).with_name('S') + assert T._name == 'T' + np.testing.assert_allclose(cirq.unitary(T), cirq.unitary(gate)) + assert S._name == 'S' + np.testing.assert_allclose(cirq.unitary(S), cirq.unitary(T ** 2)) + + def test_str_executes(): assert '1' in str(cirq.MatrixGate(np.eye(2))) assert '0' in str(cirq.MatrixGate(np.eye(4))) From 6e898d9d84c28a97646f481b5004b58b36cf09f1 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Sun, 6 Nov 2022 08:09:22 -0800 Subject: [PATCH 2/4] Format --- cirq-core/cirq/ops/matrix_gates_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/ops/matrix_gates_test.py b/cirq-core/cirq/ops/matrix_gates_test.py index 02123d61ea6..8993088ce5a 100644 --- a/cirq-core/cirq/ops/matrix_gates_test.py +++ b/cirq-core/cirq/ops/matrix_gates_test.py @@ -272,13 +272,13 @@ def test_named_two_qubit_diagram(): def test_with_name(): - gate = cirq.MatrixGate(cirq.unitary(cirq.Z ** 0.25)) + gate = cirq.MatrixGate(cirq.unitary(cirq.Z**0.25)) T = gate.with_name('T') - S = (T ** 2).with_name('S') + S = (T**2).with_name('S') assert T._name == 'T' np.testing.assert_allclose(cirq.unitary(T), cirq.unitary(gate)) assert S._name == 'S' - np.testing.assert_allclose(cirq.unitary(S), cirq.unitary(T ** 2)) + np.testing.assert_allclose(cirq.unitary(S), cirq.unitary(T**2)) def test_str_executes(): From ecbdcc0c91250d292fe03f339f29a72f35fdf220 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Sun, 6 Nov 2022 08:12:50 -0800 Subject: [PATCH 3/4] Reorder init method to fix types --- cirq-core/cirq/ops/matrix_gates.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/ops/matrix_gates.py b/cirq-core/cirq/ops/matrix_gates.py index 17a092efc0c..5ae735d3e22 100644 --- a/cirq-core/cirq/ops/matrix_gates.py +++ b/cirq-core/cirq/ops/matrix_gates.py @@ -91,12 +91,15 @@ def __init__( ) qid_shape = (2,) * n - m = int(np.prod(qid_shape, dtype=np.int64)) - if matrix.shape != (m, m): + self._matrix = matrix + self._qid_shape = tuple(qid_shape) + self._name = name + m = int(np.prod(self._qid_shape, dtype=np.int64)) + if self._matrix.shape != (m, m): raise ValueError( 'Wrong matrix shape for qid_shape.\n' - f'Matrix shape: {matrix.shape}\n' - f'qid_shape: {qid_shape}\n' + f'Matrix shape: {self._matrix.shape}\n' + f'qid_shape: {self._qid_shape}\n' ) if unitary_check and not linalg.is_unitary( @@ -104,10 +107,6 @@ def __init__( ): raise ValueError(f'Not a unitary matrix: {matrix}') - self._matrix = matrix - self._qid_shape = tuple(qid_shape) - self._name = name - def with_name(self, name: str) -> 'MatrixGate': """Creates a new MatrixGate with the same matrix and a new name.""" return MatrixGate(self._matrix, name=name, qid_shape=self._qid_shape, unitary_check=False) From dbe432a7f4b1f15b7a5dff99c21155b4504f7706 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 8 Nov 2022 14:42:53 -0800 Subject: [PATCH 4/4] Add note about when appropriate to disable unitary_check --- cirq-core/cirq/ops/matrix_gates.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/ops/matrix_gates.py b/cirq-core/cirq/ops/matrix_gates.py index 5ae735d3e22..7ee9cb13cc6 100644 --- a/cirq-core/cirq/ops/matrix_gates.py +++ b/cirq-core/cirq/ops/matrix_gates.py @@ -68,7 +68,9 @@ def __init__( If not specified, this value is inferred by assuming that the matrix is supposed to apply to qubits. unitary_check: If True, check that the supplied matrix is unitary up to the - given tolerances. + given tolerances. This should only be disabled if the matrix has already been + checked for unitarity, in which case we get a slight performance improvement by + not checking again. unitary_check_rtol: The relative tolerance for checking whether the supplied matrix is unitary. See `cirq.is_unitary`. unitary_check_atol: The absolute tolerance for checking whether the supplied matrix