diff --git a/cirq-core/cirq/contrib/qasm_import/_lexer.py b/cirq-core/cirq/contrib/qasm_import/_lexer.py index 1a830ffc17d..206d9e88d74 100644 --- a/cirq-core/cirq/contrib/qasm_import/_lexer.py +++ b/cirq-core/cirq/contrib/qasm_import/_lexer.py @@ -32,7 +32,7 @@ def __init__(self): 'measure': 'MEASURE', 'if': 'IF', '->': 'ARROW', - '!=': 'NE', + '==': 'EQ', } tokens = ['FORMAT_SPEC', 'NUMBER', 'NATURAL_NUMBER', 'QELIBINC', 'ID', 'PI'] + list( @@ -103,8 +103,8 @@ def t_ARROW(self, t): """->""" return t - def t_NE(self, t): - """!=""" + def t_EQ(self, t): + """==""" return t def t_ID(self, t): diff --git a/cirq-core/cirq/contrib/qasm_import/_parser.py b/cirq-core/cirq/contrib/qasm_import/_parser.py index 7c779432edc..0b79ff346d4 100644 --- a/cirq-core/cirq/contrib/qasm_import/_parser.py +++ b/cirq-core/cirq/contrib/qasm_import/_parser.py @@ -16,6 +16,7 @@ from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Sequence, Union import numpy as np +import sympy from ply import yacc from cirq import ops, Circuit, NamedQubit, CX @@ -496,11 +497,19 @@ def p_measurement(self, p): ] # if operations - # if : IF '(' carg NE NATURAL_NUMBER ')' ID qargs + # if : IF '(' carg EQ NATURAL_NUMBER ')' ID qargs def p_if(self, p): - """if : IF '(' carg NE NATURAL_NUMBER ')' gate_op""" - p[0] = [ops.ClassicallyControlledOperation(conditions=p[3], sub_operation=tuple(p[7])[0])] + """if : IF '(' carg EQ NATURAL_NUMBER ')' gate_op""" + # We have to split the register into bits (since that's what measurement does above), + # and create one condition per bit, checking against that part of the binary value. + conditions = [] + for i, key in enumerate(p[3]): + v = (p[5] >> i) & 1 + conditions.append(sympy.Eq(sympy.Symbol(key), v)) + p[0] = [ + ops.ClassicallyControlledOperation(conditions=conditions, sub_operation=tuple(p[7])[0]) + ] def p_error(self, p): if p is None: diff --git a/cirq-core/cirq/contrib/qasm_import/_parser_test.py b/cirq-core/cirq/contrib/qasm_import/_parser_test.py index 19cb839dadb..0818e43ddac 100644 --- a/cirq-core/cirq/contrib/qasm_import/_parser_test.py +++ b/cirq-core/cirq/contrib/qasm_import/_parser_test.py @@ -215,16 +215,17 @@ def test_CX_gate(): def test_classical_control(): qasm = """OPENQASM 2.0; qreg q[2]; - creg m_a[1]; - measure q[0] -> m_a[0]; - if (m_a!=0) CX q[0], q[1]; + creg a[1]; + measure q[0] -> a[0]; + if (a==1) CX q[0],q[1]; """ parser = QasmParser() q_0 = cirq.NamedQubit('q_0') q_1 = cirq.NamedQubit('q_1') expected_circuit = cirq.Circuit( - cirq.measure(q_0, key='m_a_0'), cirq.CNOT(q_0, q_1).with_classical_controls('m_a_0') + cirq.measure(q_0, key='a_0'), + cirq.CNOT(q_0, q_1).with_classical_controls(sympy.Eq(sympy.Symbol('a_0'), 1)), ) parsed_qasm = parser.parse(qasm) @@ -235,6 +236,62 @@ def test_classical_control(): ct.assert_same_circuits(parsed_qasm.circuit, expected_circuit) assert parsed_qasm.qregs == {'q': 2} + # Note this cannot *exactly* round-trip because the way QASM and Cirq handle measurements + # into classical registers is different. Cirq parses QASM classical registers into m_a_i for i + # in 0..bit_count. Thus the generated key has an extra "_0" at the end. + expected_generated_qasm = f"""// Generated from Cirq v{cirq.__version__} + +OPENQASM 2.0; +include "qelib1.inc"; + + +// Qubits: [q_0, q_1] +qreg q[2]; +creg m_a_0[1]; + + +measure q[0] -> m_a_0[0]; +if (m_a_0==1) cx q[0],q[1]; +""" + assert cirq.qasm(parsed_qasm.circuit) == expected_generated_qasm + + +def test_classical_control_multi_bit(): + qasm = """OPENQASM 2.0; + qreg q[2]; + creg a[2]; + measure q[0] -> a[0]; + measure q[0] -> a[1]; + if (a==1) CX q[0],q[1]; + """ + parser = QasmParser() + + q_0 = cirq.NamedQubit('q_0') + q_1 = cirq.NamedQubit('q_1') + + # Since we split the measurement into two, we also need two conditions. + # m_a==1 corresponds to m_a[0]==1, m_a[1]==0 + expected_circuit = cirq.Circuit( + cirq.measure(q_0, key='a_0'), + cirq.measure(q_0, key='a_1'), + cirq.CNOT(q_0, q_1).with_classical_controls( + sympy.Eq(sympy.Symbol('a_0'), 1), sympy.Eq(sympy.Symbol('a_1'), 0) + ), + ) + + parsed_qasm = parser.parse(qasm) + + assert parsed_qasm.supportedFormat + assert not parsed_qasm.qelib1Include + + ct.assert_same_circuits(parsed_qasm.circuit, expected_circuit) + assert parsed_qasm.qregs == {'q': 2} + + # Note that this will *not* round-trip, but there's no good way around that due to the + # difference in how Cirq and QASM do multi-bit measurements. + with pytest.raises(ValueError, match='QASM does not support multiple conditions'): + _ = cirq.qasm(parsed_qasm.circuit) + def test_CX_gate_not_enough_args(): qasm = """OPENQASM 2.0; diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 0c1edf2fa3b..6b0f701f64d 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -207,5 +207,9 @@ def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: args.validate_version('2.0') - all_keys = " && ".join(c.qasm for c in self._conditions) - return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args)) + if len(self._conditions) > 1: + raise ValueError('QASM does not support multiple conditions.') + subop_qasm = protocols.qasm(self._sub_operation, args=args) + if not self._conditions: + return subop_qasm + return f'if ({self._conditions[0].qasm}) {subop_qasm}' diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index abaf963ff37..d23398a0012 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -198,11 +198,14 @@ def test_diagram_subcircuit_layered(): def test_qasm(): q0, q1 = cirq.LineQubit.range(2) - circuit = cirq.Circuit(cirq.measure(q0, key='a'), cirq.X(q1).with_classical_controls('a')) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls(sympy.Eq(sympy.Symbol('a'), 0)), + ) qasm = cirq.qasm(circuit) assert ( qasm - == """// Generated from Cirq v0.15.0.dev + == f"""// Generated from Cirq v{cirq.__version__} OPENQASM 2.0; include "qelib1.inc"; @@ -214,11 +217,49 @@ def test_qasm(): measure q[0] -> m_a[0]; -if (m_a!=0) x q[1]; +if (m_a==0) x q[1]; """ ) +def test_qasm_no_conditions(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), cirq.ClassicallyControlledOperation(cirq.X(q1), []) + ) + qasm = cirq.qasm(circuit) + assert ( + qasm + == f"""// Generated from Cirq v{cirq.__version__} + +OPENQASM 2.0; +include "qelib1.inc"; + + +// Qubits: [q(0), q(1)] +qreg q[2]; +creg m_a[1]; + + +measure q[0] -> m_a[0]; +x q[1]; +""" + ) + + +def test_qasm_multiple_conditions(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q0, key='b'), + cirq.X(q1).with_classical_controls( + sympy.Eq(sympy.Symbol('a'), 0), sympy.Eq(sympy.Symbol('b'), 0) + ), + ) + with pytest.raises(ValueError, match='QASM does not support multiple conditions'): + _ = cirq.qasm(circuit) + + @pytest.mark.parametrize('sim', ALL_SIMULATORS) def test_key_unset(sim): q0, q1 = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index 75a2bb62196..df398be2d6e 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -113,9 +113,7 @@ def _from_json_dict_(cls, key, **kwargs): @property def qasm(self): - if self.index != -1: - raise NotImplementedError('Only most recent measurement at key can be used for QASM.') - return f'm_{self.key}!=0' + raise ValueError('QASM is defined only for SympyConditions of type key == constant.') @dataclasses.dataclass(frozen=True) @@ -162,4 +160,8 @@ def _from_json_dict_(cls, expr, **kwargs): @property def qasm(self): - raise NotImplementedError() + if isinstance(self.expr, sympy.Equality): + if isinstance(self.expr.lhs, sympy.Symbol) and isinstance(self.expr.rhs, sympy.Integer): + # Measurements get prepended with "m_", so the condition needs to be too. + return f'm_{self.expr.lhs}=={self.expr.rhs}' + raise ValueError('QASM is defined only for SympyConditions of type key == constant.') diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index 29148853994..aae218bf993 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -67,7 +67,8 @@ def resolve(records): def test_key_condition_qasm(): - assert cirq.KeyCondition(cirq.MeasurementKey('a')).qasm == 'm_a!=0' + with pytest.raises(ValueError, match='QASM is defined only for SympyConditions'): + _ = cirq.KeyCondition(cirq.MeasurementKey('a')).qasm def test_sympy_condition_with_keys(): @@ -111,5 +112,9 @@ def resolve(records): def test_sympy_condition_qasm(): - with pytest.raises(NotImplementedError): - _ = init_sympy_condition.qasm + # Measurements get prepended with "m_", so the condition needs to be too. + assert cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), 2)).qasm == 'm_a==2' + with pytest.raises( + ValueError, match='QASM is defined only for SympyConditions of type key == constant' + ): + _ = cirq.SympyCondition(sympy.Symbol('a') != 2).qasm