Skip to content

Commit 59ab27a

Browse files
daxfohlmhucka
andauthored
Infer QASM for classical KeyConditions in circuit (#6871)
* Infer qasm for classical keyconditions * add qasm3.0 option * comments --------- Co-authored-by: Michael Hucka <[email protected]>
1 parent 32bf399 commit 59ab27a

File tree

6 files changed

+82
-6
lines changed

6 files changed

+82
-6
lines changed

cirq-core/cirq/circuits/qasm_output.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,14 @@ def __init__(
197197
meas_key_id_map, meas_comments = self._generate_measurement_ids()
198198
self.meas_comments = meas_comments
199199
qubit_id_map = self._generate_qubit_ids()
200+
self.cregs = self._generate_cregs(meas_key_id_map)
200201
self.args = protocols.QasmArgs(
201202
precision=precision,
202203
version=version,
203204
qubit_id_map=qubit_id_map,
204205
meas_key_id_map=meas_key_id_map,
206+
meas_key_bitcount={k: v[0] for k, v in self.cregs.items()},
205207
)
206-
self.cregs = self._generate_cregs()
207208

208209
def _generate_measurement_ids(self) -> Tuple[Dict[str, str], Dict[str, Optional[str]]]:
209210
# Pick an id for the creg that will store each measurement
@@ -227,7 +228,7 @@ def _generate_measurement_ids(self) -> Tuple[Dict[str, str], Dict[str, Optional[
227228
def _generate_qubit_ids(self) -> Dict['cirq.Qid', str]:
228229
return {qubit: f'q[{i}]' for i, qubit in enumerate(self.qubits)}
229230

230-
def _generate_cregs(self) -> Dict[str, tuple[int, str]]:
231+
def _generate_cregs(self, meas_key_id_map: Dict[str, str]) -> Dict[str, tuple[int, str]]:
231232
"""Pick an id for the creg that will store each measurement
232233
233234
This function finds the largest measurement using each key.
@@ -239,7 +240,7 @@ def _generate_cregs(self) -> Dict[str, tuple[int, str]]:
239240
cregs: Dict[str, tuple[int, str]] = {}
240241
for meas in self.measurements:
241242
key = protocols.measurement_key_name(meas)
242-
meas_id = self.args.meas_key_id_map[key]
243+
meas_id = meas_key_id_map[key]
243244

244245
if self.meas_comments[key] is not None:
245246
comment = f' // Measurement: {self.meas_comments[key]}'

cirq-core/cirq/ops/classically_controlled_operation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,4 +219,4 @@ def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
219219
subop_qasm = protocols.qasm(self._sub_operation, args=args)
220220
if not self._conditions:
221221
return subop_qasm
222-
return f'if ({self._conditions[0].qasm}) {subop_qasm}'
222+
return f'if ({protocols.qasm(self._conditions[0], args=args)}) {subop_qasm}'

cirq-core/cirq/ops/classically_controlled_operation_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def test_diagram_subcircuit_layered():
196196
)
197197

198198

199-
def test_qasm():
199+
def test_qasm_sympy_condition():
200200
q0, q1 = cirq.LineQubit.range(2)
201201
circuit = cirq.Circuit(
202202
cirq.measure(q0, key='a'),
@@ -222,6 +222,29 @@ def test_qasm():
222222
)
223223

224224

225+
def test_qasm_key_condition():
226+
q0, q1 = cirq.LineQubit.range(2)
227+
circuit = cirq.Circuit(cirq.measure(q0, key='a'), cirq.X(q1).with_classical_controls('a'))
228+
qasm = cirq.qasm(circuit)
229+
assert (
230+
qasm
231+
== f"""// Generated from Cirq v{cirq.__version__}
232+
233+
OPENQASM 2.0;
234+
include "qelib1.inc";
235+
236+
237+
// Qubits: [q(0), q(1)]
238+
qreg q[2];
239+
creg m_a[1];
240+
241+
242+
measure q[0] -> m_a[0];
243+
if (m_a==1) x q[1];
244+
"""
245+
)
246+
247+
225248
def test_qasm_no_conditions():
226249
q0, q1 = cirq.LineQubit.range(2)
227250
circuit = cirq.Circuit(

cirq-core/cirq/protocols/qasm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
version: str = '2.0',
3939
qubit_id_map: Optional[Dict['cirq.Qid', str]] = None,
4040
meas_key_id_map: Optional[Dict[str, str]] = None,
41+
meas_key_bitcount: Optional[Dict[str, int]] = None,
4142
) -> None:
4243
"""Inits QasmArgs.
4344
@@ -49,11 +50,14 @@ def __init__(
4950
qubit_id_map: A dictionary mapping qubits to qreg QASM identifiers.
5051
meas_key_id_map: A dictionary mapping measurement keys to creg QASM
5152
identifiers.
53+
meas_key_bitcount: A dictionary with of bits for each measurement
54+
key.
5255
"""
5356
self.precision = precision
5457
self.version = version
5558
self.qubit_id_map = {} if qubit_id_map is None else qubit_id_map
5659
self.meas_key_id_map = {} if meas_key_id_map is None else meas_key_id_map
60+
self.meas_key_bitcount = {} if meas_key_bitcount is None else meas_key_bitcount
5761

5862
def _format_number(self, value) -> str:
5963
"""OpenQASM 2.0 does not support '1e-5' and wants '1.0e-5'"""

cirq-core/cirq/value/condition.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import abc
1616
import dataclasses
17-
from typing import Mapping, Tuple, TYPE_CHECKING, FrozenSet
17+
from typing import Mapping, Tuple, TYPE_CHECKING, FrozenSet, Optional
1818

1919
import sympy
2020

@@ -47,6 +47,9 @@ def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool:
4747
def qasm(self):
4848
"""Returns the qasm of this condition."""
4949

50+
def _qasm_(self, args: 'cirq.QasmArgs', **kwargs) -> Optional[str]:
51+
return self.qasm
52+
5053
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]) -> 'cirq.Condition':
5154
condition = self
5255
for k in self.keys:
@@ -115,6 +118,22 @@ def _from_json_dict_(cls, key, **kwargs):
115118
def qasm(self):
116119
raise ValueError('QASM is defined only for SympyConditions of type key == constant.')
117120

121+
def _qasm_(self, args: 'cirq.QasmArgs', **kwargs) -> Optional[str]:
122+
args.validate_version('2.0', '3.0')
123+
key_str = str(self.key)
124+
if key_str not in args.meas_key_id_map:
125+
raise ValueError(f'Key "{key_str}" not in QasmArgs.meas_key_id_map.')
126+
key = args.meas_key_id_map[key_str]
127+
# QASM 3.0 supports !=, so we return it directly.
128+
if args.version == '3.0':
129+
return f'{key}!=0'
130+
# QASM 2.0 only has == operator, so we must limit to single-bit measurement keys == 1.
131+
if key not in args.meas_key_bitcount:
132+
raise ValueError(f'Key "{key}" not in QasmArgs.meas_key_bitcount.')
133+
if args.meas_key_bitcount[str(key)] != 1:
134+
raise ValueError('QASM is defined only for single-bit classical conditions.')
135+
return f'{key}==1'
136+
118137

119138
@dataclasses.dataclass(frozen=True)
120139
class SympyCondition(Condition):

cirq-core/cirq/value/condition_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,35 @@ def test_key_condition_qasm():
7171
_ = cirq.KeyCondition(cirq.MeasurementKey('a')).qasm
7272

7373

74+
def test_key_condition_qasm_protocol():
75+
cond = cirq.KeyCondition(cirq.MeasurementKey('a'))
76+
args = cirq.QasmArgs(meas_key_id_map={'a': 'm_a'}, meas_key_bitcount={'m_a': 1})
77+
qasm = cirq.qasm(cond, args=args)
78+
assert qasm == 'm_a==1'
79+
80+
81+
def test_key_condition_qasm_protocol_v3():
82+
cond = cirq.KeyCondition(cirq.MeasurementKey('a'))
83+
args = cirq.QasmArgs(meas_key_id_map={'a': 'm_a'}, version='3.0')
84+
qasm = cirq.qasm(cond, args=args)
85+
assert qasm == 'm_a!=0'
86+
87+
88+
def test_key_condition_qasm_protocol_invalid_args():
89+
cond = cirq.KeyCondition(cirq.MeasurementKey('a'))
90+
args = cirq.QasmArgs()
91+
with pytest.raises(ValueError, match='Key "a" not in QasmArgs.meas_key_id_map.'):
92+
_ = cirq.qasm(cond, args=args)
93+
args = cirq.QasmArgs(meas_key_id_map={'a': 'm_a'})
94+
with pytest.raises(ValueError, match='Key "m_a" not in QasmArgs.meas_key_bitcount.'):
95+
_ = cirq.qasm(cond, args=args)
96+
args = cirq.QasmArgs(meas_key_id_map={'a': 'm_a'}, meas_key_bitcount={'m_a': 2})
97+
with pytest.raises(
98+
ValueError, match='QASM is defined only for single-bit classical conditions.'
99+
):
100+
_ = cirq.qasm(cond, args=args)
101+
102+
74103
def test_sympy_condition_with_keys():
75104
c = init_sympy_condition.replace_key(key_a, key_b)
76105
assert c.keys == (key_b,)

0 commit comments

Comments
 (0)