Skip to content

Commit dbada0e

Browse files
authored
Add bitwise option to Sympy evaluator for classical controls (#6914)
1 parent f39aff1 commit dbada0e

21 files changed

+300
-19
lines changed

cirq-core/cirq/_compat.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,12 @@ def proper_repr(value: Any) -> str:
151151
'StrictLessThan',
152152
'Equality',
153153
'Unequality',
154+
'And',
155+
'Or',
156+
'Not',
157+
'Xor',
158+
'Indexed',
159+
'IndexedBase',
154160
]
155161

156162
class Printer(sympy.printing.repr.ReprPrinter):

cirq-core/cirq/json_resolver_cache.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,12 @@ def _symmetricalqidpair(qids):
273273
'sympy.StrictLessThan': lambda args: sympy.StrictLessThan(*args),
274274
'sympy.Equality': lambda args: sympy.Equality(*args),
275275
'sympy.Unequality': lambda args: sympy.Unequality(*args),
276+
'sympy.And': lambda args: sympy.And(*args),
277+
'sympy.Or': lambda args: sympy.Or(*args),
278+
'sympy.Not': lambda args: sympy.Not(*args),
279+
'sympy.Xor': lambda args: sympy.Xor(*args),
280+
'sympy.Indexed': lambda args: sympy.Indexed(*args),
281+
'sympy.IndexedBase': lambda args: sympy.IndexedBase(*args),
276282
'sympy.Float': lambda approx: sympy.Float(approx),
277283
'sympy.Integer': sympy.Integer,
278284
'sympy.Rational': sympy.Rational,

cirq-core/cirq/ops/classically_controlled_operation_test.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,3 +1127,65 @@ def test_diagram_exponents_multiple_keys():
11271127
└──┘
11281128
""",
11291129
)
1130+
1131+
1132+
def test_sympy_indexed_condition_circuit():
1133+
a = sympy.IndexedBase('a')
1134+
# XOR the 2nd and 3rd bits of the measurement (big-endian)
1135+
cond = cirq.SympyCondition(sympy.Xor(a[1], a[2]))
1136+
q0, q1, q2, q3 = cirq.LineQubit.range(4)
1137+
sim = cirq.Simulator()
1138+
circuit = cirq.Circuit(
1139+
cirq.measure(q0, q1, q2, key='a'),
1140+
cirq.X(q3).with_classical_controls(cond),
1141+
cirq.measure(q3, key='b'),
1142+
)
1143+
cirq.testing.assert_has_diagram(
1144+
circuit,
1145+
"""
1146+
0: ───M──────────────────────────────────────────
1147+
1148+
1: ───M──────────────────────────────────────────
1149+
1150+
2: ───M──────────────────────────────────────────
1151+
1152+
3: ───╫───X(conditions=[a[1] ^ a[2]])───M('b')───
1153+
║ ║
1154+
a: ═══@═══^══════════════════════════════════════
1155+
""",
1156+
)
1157+
result = sim.sample(circuit)
1158+
assert result['a'][0] == 0b000
1159+
assert result['b'][0] == 0
1160+
circuit.insert(0, cirq.X(q2))
1161+
result = sim.sample(circuit)
1162+
assert result['a'][0] == 0b001
1163+
assert result['b'][0] == 1
1164+
circuit.insert(0, cirq.X(q1))
1165+
circuit.insert(0, cirq.X(q2))
1166+
result = sim.sample(circuit)
1167+
assert result['a'][0] == 0b010
1168+
assert result['b'][0] == 1
1169+
circuit.insert(0, cirq.X(q2))
1170+
result = sim.sample(circuit)
1171+
assert result['a'][0] == 0b011
1172+
assert result['b'][0] == 0
1173+
circuit.insert(0, cirq.X(q0))
1174+
circuit.insert(0, cirq.X(q1))
1175+
circuit.insert(0, cirq.X(q2))
1176+
result = sim.sample(circuit)
1177+
assert result['a'][0] == 0b100
1178+
assert result['b'][0] == 0
1179+
circuit.insert(0, cirq.X(q2))
1180+
result = sim.sample(circuit)
1181+
assert result['a'][0] == 0b101
1182+
assert result['b'][0] == 1
1183+
circuit.insert(0, cirq.X(q1))
1184+
circuit.insert(0, cirq.X(q2))
1185+
result = sim.sample(circuit)
1186+
assert result['a'][0] == 0b110
1187+
assert result['b'][0] == 1
1188+
circuit.insert(0, cirq.X(q2))
1189+
result = sim.sample(circuit)
1190+
assert result['a'][0] == 0b111
1191+
assert result['b'][0] == 0

cirq-core/cirq/protocols/hash_from_pickle_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@
4343
"cirq/protocols/json_test_data/sympy.StrictLessThan.json",
4444
"cirq/protocols/json_test_data/sympy.Symbol.json",
4545
"cirq/protocols/json_test_data/sympy.Unequality.json",
46+
"cirq/protocols/json_test_data/sympy.And.json",
47+
"cirq/protocols/json_test_data/sympy.Not.json",
48+
"cirq/protocols/json_test_data/sympy.Or.json",
49+
"cirq/protocols/json_test_data/sympy.Xor.json",
50+
"cirq/protocols/json_test_data/sympy.Indexed.json",
51+
"cirq/protocols/json_test_data/sympy.IndexedBase.json",
4652
"cirq/protocols/json_test_data/sympy.pi.json",
4753
# RigettiQCSAspenDevice does not pickle
4854
"cirq_rigetti/json_test_data/RigettiQCSAspenDevice.json",

cirq-core/cirq/protocols/json_serialization.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,12 @@ def default(self, o):
255255
sympy.StrictLessThan,
256256
sympy.Equality,
257257
sympy.Unequality,
258+
sympy.And,
259+
sympy.Or,
260+
sympy.Not,
261+
sympy.Xor,
262+
sympy.Indexed,
263+
sympy.IndexedBase,
258264
),
259265
):
260266
return {'cirq_type': f'sympy.{o.__class__.__name__}', 'args': o.args}
Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,62 @@
1-
{
2-
"cirq_type": "SympyCondition",
3-
"expr":
1+
[
42
{
5-
"cirq_type": "sympy.GreaterThan",
6-
"args": [
7-
{
8-
"cirq_type": "sympy.Symbol",
9-
"name": "a"
10-
},
11-
{
12-
"cirq_type": "sympy.Symbol",
13-
"name": "b"
14-
}
15-
]
3+
"cirq_type": "SympyCondition",
4+
"expr": {
5+
"cirq_type": "sympy.GreaterThan",
6+
"args": [
7+
{
8+
"cirq_type": "sympy.Symbol",
9+
"name": "a"
10+
},
11+
{
12+
"cirq_type": "sympy.Symbol",
13+
"name": "b"
14+
}
15+
]
16+
}
17+
},
18+
{
19+
"cirq_type": "SympyCondition",
20+
"expr": {
21+
"cirq_type": "sympy.Xor",
22+
"args": [
23+
{
24+
"cirq_type": "sympy.Indexed",
25+
"args": [
26+
{
27+
"cirq_type": "sympy.IndexedBase",
28+
"args": [
29+
{
30+
"cirq_type": "sympy.Symbol",
31+
"name": "a"
32+
}
33+
]
34+
},
35+
{
36+
"cirq_type": "sympy.Integer",
37+
"i": 0
38+
}
39+
]
40+
},
41+
{
42+
"cirq_type": "sympy.Indexed",
43+
"args": [
44+
{
45+
"cirq_type": "sympy.IndexedBase",
46+
"args": [
47+
{
48+
"cirq_type": "sympy.Symbol",
49+
"name": "a"
50+
}
51+
]
52+
},
53+
{
54+
"cirq_type": "sympy.Integer",
55+
"i": 1
56+
}
57+
]
58+
}
59+
]
60+
}
1661
}
17-
}
62+
]
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
cirq.SympyCondition(sympy.GreaterThan(sympy.Symbol('a'), sympy.Symbol('b')))
1+
[
2+
cirq.SympyCondition(sympy.GreaterThan(sympy.Symbol('a'), sympy.Symbol('b'))),
3+
cirq.SympyCondition(sympy.Xor(sympy.Indexed(sympy.IndexedBase(sympy.Symbol('a')), sympy.Integer(0)), sympy.Indexed(sympy.IndexedBase(sympy.Symbol('a')), sympy.Integer(1))))
4+
]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"cirq_type": "sympy.And",
3+
"args": [
4+
{
5+
"cirq_type": "sympy.Symbol",
6+
"name": "s"
7+
},
8+
{
9+
"cirq_type": "sympy.Symbol",
10+
"name": "t"
11+
}
12+
]
13+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
sympy.And(sympy.Symbol('s'), sympy.Symbol('t'))
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"cirq_type": "sympy.Indexed",
3+
"args": [
4+
{
5+
"cirq_type": "sympy.IndexedBase",
6+
"args": [
7+
{
8+
"cirq_type": "sympy.Symbol",
9+
"name": "s"
10+
}
11+
]
12+
},
13+
{
14+
"cirq_type": "sympy.Integer",
15+
"i": 1
16+
}
17+
]
18+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
sympy.Indexed(sympy.IndexedBase(sympy.Symbol('s')),sympy.Integer(1))
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"cirq_type": "sympy.IndexedBase",
3+
"args": [
4+
{
5+
"cirq_type": "sympy.Symbol",
6+
"name": "s"
7+
}
8+
]
9+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
sympy.IndexedBase(sympy.Symbol('s'))
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"cirq_type": "sympy.Not",
3+
"args": [
4+
{
5+
"cirq_type": "sympy.Symbol",
6+
"name": "s"
7+
}
8+
]
9+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
sympy.Not(sympy.Symbol('s'))
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"cirq_type": "sympy.Or",
3+
"args": [
4+
{
5+
"cirq_type": "sympy.Symbol",
6+
"name": "s"
7+
},
8+
{
9+
"cirq_type": "sympy.Symbol",
10+
"name": "t"
11+
}
12+
]
13+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
sympy.Or(sympy.Symbol('s'), sympy.Symbol('t'))
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"cirq_type": "sympy.Xor",
3+
"args": [
4+
{
5+
"cirq_type": "sympy.Symbol",
6+
"name": "s"
7+
},
8+
{
9+
"cirq_type": "sympy.Symbol",
10+
"name": "t"
11+
}
12+
]
13+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
sympy.Xor(sympy.Symbol('s'), sympy.Symbol('t'))

cirq-core/cirq/value/condition.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import abc
1616
import dataclasses
17-
from typing import Mapping, Tuple, TYPE_CHECKING, FrozenSet, Optional
17+
from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple, TYPE_CHECKING
1818

1919
import sympy
2020

@@ -142,6 +142,12 @@ class SympyCondition(Condition):
142142
This condition resolves to True iff the sympy expression resolves to a
143143
truthy value (i.e. `bool(x) == True`) when the measurement keys are
144144
substituted in as the free variables.
145+
146+
`sympy.IndexedBase` can be used for bitwise conditions. For example, the
147+
following will create a condition that is controlled by the XOR of the
148+
first two bits (big-endian) of measurement 'a'.
149+
>>> a = sympy.IndexedBase('a')
150+
>>> cond = cirq.SympyCondition(sympy.Xor(a[0], a[1]))
145151
"""
146152

147153
expr: sympy.Basic
@@ -151,6 +157,9 @@ def keys(self):
151157
return tuple(
152158
measurement_key.MeasurementKey.parse_serialized(symbol.name)
153159
for symbol in self.expr.free_symbols
160+
if isinstance(symbol, sympy.Symbol)
161+
# For bitwise ops, both Symbol ('a') and Indexed ('a[0]') are returned. We only want to
162+
# keep the former here.
154163
)
155164

156165
def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'):
@@ -167,8 +176,19 @@ def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool:
167176
if missing:
168177
raise ValueError(f'Measurement keys {missing} missing when testing classical control')
169178

170-
replacements = {str(k): classical_data.get_int(k) for k in self.keys}
171-
return bool(self.expr.subs(replacements))
179+
replacements: Dict[str, Any] = {}
180+
for symbol in self.expr.free_symbols:
181+
if isinstance(symbol, sympy.Symbol):
182+
name = symbol.name
183+
key = measurement_key.MeasurementKey.parse_serialized(name)
184+
replacements[str(key)] = classical_data.get_int(key)
185+
for symbol in self.expr.free_symbols:
186+
if isinstance(symbol, sympy.Indexed):
187+
name = symbol.base.name
188+
key = measurement_key.MeasurementKey.parse_serialized(name)
189+
replacements[str(key)] = tuple(classical_data.get_digits(key))
190+
value = self.expr.subs(replacements)
191+
return bool(value)
172192

173193
def _json_dict_(self):
174194
return json_serialization.dataclass_json_dict(self)

0 commit comments

Comments
 (0)