Skip to content

Fix qasm generation/parsing for classical controls #5434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 27, 2022
6 changes: 3 additions & 3 deletions cirq-core/cirq/contrib/qasm_import/_lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self):
'measure': 'MEASURE',
'if': 'IF',
'->': 'ARROW',
'!=': 'NE',
'==': 'EQ',
}

tokens = ['FORMAT_SPEC', 'NUMBER', 'NATURAL_NUMBER', 'QELIBINC', 'ID', 'PI'] + list(
Expand Down Expand Up @@ -103,8 +103,8 @@ def t_ARROW(self, t):
"""->"""
return t

def t_NE(self, t):
"""!="""
def t_EQ(self, t):
"""=="""
return t

def t_ID(self, t):
Expand Down
15 changes: 12 additions & 3 deletions cirq-core/cirq/contrib/qasm_import/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Sequence, Union

import numpy as np
import sympy
from ply import yacc

from cirq import ops, Circuit, NamedQubit, CX
Expand Down Expand Up @@ -496,11 +497,19 @@ def p_measurement(self, p):
]

# if operations
# if : IF '(' carg NE NATURAL_NUMBER ')' ID qargs
# if : IF '(' carg EQ NATURAL_NUMBER ')' ID qargs

def p_if(self, p):
"""if : IF '(' carg NE NATURAL_NUMBER ')' gate_op"""
p[0] = [ops.ClassicallyControlledOperation(conditions=p[3], sub_operation=tuple(p[7])[0])]
"""if : IF '(' carg EQ NATURAL_NUMBER ')' gate_op"""
# We have to split the register into bits (since that's what measurement does above),
# and create one condition per bit, checking against that part of the binary value.
conditions = []
for i, key in enumerate(p[3]):
v = (p[5] >> i) & 1
conditions.append(sympy.Eq(sympy.Symbol(key), v))
p[0] = [
ops.ClassicallyControlledOperation(conditions=conditions, sub_operation=tuple(p[7])[0])
]

def p_error(self, p):
if p is None:
Expand Down
65 changes: 61 additions & 4 deletions cirq-core/cirq/contrib/qasm_import/_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,17 @@ def test_CX_gate():
def test_classical_control():
qasm = """OPENQASM 2.0;
qreg q[2];
creg m_a[1];
measure q[0] -> m_a[0];
if (m_a!=0) CX q[0], q[1];
creg a[1];
measure q[0] -> a[0];
if (a==1) CX q[0],q[1];
"""
parser = QasmParser()

q_0 = cirq.NamedQubit('q_0')
q_1 = cirq.NamedQubit('q_1')
expected_circuit = cirq.Circuit(
cirq.measure(q_0, key='m_a_0'), cirq.CNOT(q_0, q_1).with_classical_controls('m_a_0')
cirq.measure(q_0, key='a_0'),
cirq.CNOT(q_0, q_1).with_classical_controls(sympy.Eq(sympy.Symbol('a_0'), 1)),
)

parsed_qasm = parser.parse(qasm)
Expand All @@ -233,6 +234,62 @@ def test_classical_control():
ct.assert_same_circuits(parsed_qasm.circuit, expected_circuit)
assert parsed_qasm.qregs == {'q': 2}

# Note this cannot *exactly* round-trip because the way QASM and Cirq handle measurements
# into classical registers is different. Cirq parses QASM classical registers into m_a_i for i
# in 0..bit_count. Thus the generated key has an extra "_0" at the end.
expected_generated_qasm = f"""// Generated from Cirq v{cirq.__version__}

OPENQASM 2.0;
include "qelib1.inc";


// Qubits: [q_0, q_1]
qreg q[2];
creg m_a_0[1];


measure q[0] -> m_a_0[0];
if (m_a_0==1) cx q[0],q[1];
"""
assert cirq.qasm(parsed_qasm.circuit) == expected_generated_qasm


def test_classical_control_multi_bit():
qasm = """OPENQASM 2.0;
qreg q[2];
creg a[2];
measure q[0] -> a[0];
measure q[0] -> a[1];
if (a==1) CX q[0],q[1];
"""
parser = QasmParser()

q_0 = cirq.NamedQubit('q_0')
q_1 = cirq.NamedQubit('q_1')

# Since we split the measurement into two, we also need two conditions.
# m_a==1 corresponds to m_a[0]==1, m_a[1]==0
expected_circuit = cirq.Circuit(
cirq.measure(q_0, key='a_0'),
cirq.measure(q_0, key='a_1'),
cirq.CNOT(q_0, q_1).with_classical_controls(
sympy.Eq(sympy.Symbol('a_0'), 1), sympy.Eq(sympy.Symbol('a_1'), 0)
),
)

parsed_qasm = parser.parse(qasm)

assert parsed_qasm.supportedFormat
assert not parsed_qasm.qelib1Include

ct.assert_same_circuits(parsed_qasm.circuit, expected_circuit)
assert parsed_qasm.qregs == {'q': 2}

# Note that this will *not* round-trip, but there's no good way around that due to the
# difference in how Cirq and QASM do multi-bit measurements.
with pytest.raises(ValueError, match='QASM does not support multiple conditions'):
_ = cirq.qasm(parsed_qasm.circuit)


def test_CX_gate_not_enough_args():
qasm = """OPENQASM 2.0;
Expand Down
8 changes: 6 additions & 2 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,5 +206,9 @@ def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:

def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
args.validate_version('2.0')
all_keys = " && ".join(c.qasm for c in self._conditions)
return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args))
if len(self._conditions) > 1:
raise ValueError('QASM does not support multiple conditions.')
subop_qasm = protocols.qasm(self._sub_operation, args=args)
if not self._conditions:
return subop_qasm
return f'if ({self._conditions[0].qasm}) {subop_qasm}'
47 changes: 44 additions & 3 deletions cirq-core/cirq/ops/classically_controlled_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,14 @@ def test_diagram_subcircuit_layered():

def test_qasm():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.measure(q0, key='a'), cirq.X(q1).with_classical_controls('a'))
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.X(q1).with_classical_controls(sympy.Eq(sympy.Symbol('a'), 0)),
)
qasm = cirq.qasm(circuit)
assert (
qasm
== """// Generated from Cirq v0.15.0.dev
== f"""// Generated from Cirq v{cirq.__version__}

OPENQASM 2.0;
include "qelib1.inc";
Expand All @@ -214,11 +217,49 @@ def test_qasm():


measure q[0] -> m_a[0];
if (m_a!=0) x q[1];
if (m_a==0) x q[1];
"""
)


def test_qasm_no_conditions():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'), cirq.ClassicallyControlledOperation(cirq.X(q1), [])
)
qasm = cirq.qasm(circuit)
assert (
qasm
== f"""// Generated from Cirq v{cirq.__version__}

OPENQASM 2.0;
include "qelib1.inc";


// Qubits: [q(0), q(1)]
qreg q[2];
creg m_a[1];


measure q[0] -> m_a[0];
x q[1];
"""
)


def test_qasm_multiple_conditions():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.measure(q0, key='b'),
cirq.X(q1).with_classical_controls(
sympy.Eq(sympy.Symbol('a'), 0), sympy.Eq(sympy.Symbol('b'), 0)
),
)
with pytest.raises(ValueError, match='QASM does not support multiple conditions'):
_ = cirq.qasm(circuit)


@pytest.mark.parametrize('sim', ALL_SIMULATORS)
def test_key_unset(sim):
q0, q1 = cirq.LineQubit.range(2)
Expand Down
10 changes: 6 additions & 4 deletions cirq-core/cirq/value/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ def _from_json_dict_(cls, key, **kwargs):

@property
def qasm(self):
if self.index != -1:
raise NotImplementedError('Only most recent measurement at key can be used for QASM.')
return f'm_{self.key}!=0'
raise ValueError('QASM is defined only for SympyConditions of type key == constant.')


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -162,4 +160,8 @@ def _from_json_dict_(cls, expr, **kwargs):

@property
def qasm(self):
raise NotImplementedError()
if isinstance(self.expr, sympy.Equality):
if isinstance(self.expr.lhs, sympy.Symbol) and isinstance(self.expr.rhs, sympy.Integer):
# Measurements get prepended with "m_", so the condition needs to be too.
return f'm_{self.expr.lhs}=={self.expr.rhs}'
raise ValueError('QASM is defined only for SympyConditions of type key == constant.')
11 changes: 8 additions & 3 deletions cirq-core/cirq/value/condition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def resolve(records):


def test_key_condition_qasm():
assert cirq.KeyCondition(cirq.MeasurementKey('a')).qasm == 'm_a!=0'
with pytest.raises(ValueError, match='QASM is defined only for SympyConditions'):
_ = cirq.KeyCondition(cirq.MeasurementKey('a')).qasm


def test_sympy_condition_with_keys():
Expand Down Expand Up @@ -111,5 +112,9 @@ def resolve(records):


def test_sympy_condition_qasm():
with pytest.raises(NotImplementedError):
_ = init_sympy_condition.qasm
# Measurements get prepended with "m_", so the condition needs to be too.
assert cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), 2)).qasm == 'm_a==2'
with pytest.raises(
ValueError, match='QASM is defined only for SympyConditions of type key == constant'
):
_ = cirq.SympyCondition(sympy.Symbol('a') != 2).qasm