Skip to content

Commit 426ba07

Browse files
senecameeksBichengYing
authored andcommitted
allow passing a callable to de/serialization funcs (quantumlib#6855)
* allow passing func to de/serialization funcs * coverage * simplify * typecheck * nit * mypy * comments * comments
1 parent fb0d2a0 commit 426ba07

File tree

2 files changed

+128
-18
lines changed

2 files changed

+128
-18
lines changed

cirq-google/cirq_google/api/v2/sweeps.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, cast, Dict, List, Optional
15+
from typing import Any, cast, Callable, Dict, List, Optional
1616

1717
import sympy
1818
import tunits
1919

2020
import cirq
21+
from cirq.study import sweeps
2122
from cirq_google.api.v2 import run_context_pb2
2223
from cirq_google.study.device_parameter import DeviceParameter
2324

@@ -55,14 +56,18 @@ def _recover_sweep_const(const_pb: run_context_pb2.ConstValue) -> Any:
5556

5657

5758
def sweep_to_proto(
58-
sweep: cirq.Sweep, *, out: Optional[run_context_pb2.Sweep] = None
59+
sweep: cirq.Sweep,
60+
*,
61+
out: Optional[run_context_pb2.Sweep] = None,
62+
sweep_transformer: Callable[[sweeps.SingleSweep], sweeps.SingleSweep] = lambda x: x,
5963
) -> run_context_pb2.Sweep:
6064
"""Converts a Sweep to v2 protobuf message.
6165
6266
Args:
6367
sweep: The sweep to convert.
6468
out: Optional message to be populated. If not given, a new message will
6569
be created.
70+
sweep_transformer: A function called on Linspace, Points.
6671
6772
Returns:
6873
Populated sweep protobuf message.
@@ -91,6 +96,7 @@ def sweep_to_proto(
9196
for s in sweep.sweeps:
9297
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
9398
elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr):
99+
sweep = cast(cirq.Linspace, sweep_transformer(sweep))
94100
out.single_sweep.parameter_key = sweep.key
95101
if isinstance(sweep.start, tunits.Value):
96102
unit = sweep.start.unit
@@ -110,6 +116,7 @@ def sweep_to_proto(
110116
if sweep.metadata and getattr(sweep.metadata, 'units', None):
111117
out.single_sweep.parameter.units = sweep.metadata.units
112118
elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr):
119+
sweep = cast(cirq.Points, sweep_transformer(sweep))
113120
out.single_sweep.parameter_key = sweep.key
114121
if len(sweep.points) == 1:
115122
out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0]))
@@ -142,8 +149,17 @@ def sweep_to_proto(
142149
return out
143150

144151

145-
def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
146-
"""Creates a Sweep from a v2 protobuf message."""
152+
def sweep_from_proto(
153+
msg: run_context_pb2.Sweep,
154+
sweep_transformer: Callable[[sweeps.SingleSweep], sweeps.SingleSweep] = lambda x: x,
155+
) -> cirq.Sweep:
156+
"""Creates a Sweep from a v2 protobuf message.
157+
158+
Args:
159+
msg: Serialized sweep message.
160+
sweep_transformer: A function called on Linspace, Point, and ConstValue.
161+
162+
"""
147163
which = msg.WhichOneof('sweep')
148164
if which is None:
149165
return cirq.UnitSweep
@@ -178,31 +194,38 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
178194
)
179195
else:
180196
metadata = None
197+
181198
if msg.single_sweep.WhichOneof('sweep') == 'linspace':
182199
unit: float | tunits.Value = 1.0
183200
if msg.single_sweep.linspace.HasField('unit'):
184201
unit = tunits.Value.from_proto(msg.single_sweep.linspace.unit)
185-
return cirq.Linspace(
186-
key=key,
187-
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
188-
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
189-
length=msg.single_sweep.linspace.num_points,
190-
metadata=metadata,
202+
return sweep_transformer(
203+
cirq.Linspace(
204+
key=key,
205+
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
206+
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
207+
length=msg.single_sweep.linspace.num_points,
208+
metadata=metadata,
209+
)
191210
)
192211
if msg.single_sweep.WhichOneof('sweep') == 'points':
193212
unit = 1.0
194213
if msg.single_sweep.points.HasField('unit'):
195214
unit = tunits.Value.from_proto(msg.single_sweep.points.unit)
196-
return cirq.Points(
197-
key=key,
198-
points=[p * unit for p in msg.single_sweep.points.points],
199-
metadata=metadata,
215+
return sweep_transformer(
216+
cirq.Points(
217+
key=key,
218+
points=[p * unit for p in msg.single_sweep.points.points],
219+
metadata=metadata,
220+
)
200221
)
201222
if msg.single_sweep.WhichOneof('sweep') == 'const_value':
202-
return cirq.Points(
203-
key=key,
204-
points=[_recover_sweep_const(msg.single_sweep.const_value)],
205-
metadata=metadata,
223+
return sweep_transformer(
224+
cirq.Points(
225+
key=key,
226+
points=[_recover_sweep_const(msg.single_sweep.const_value)],
227+
metadata=metadata,
228+
)
206229
)
207230

208231
raise ValueError(f'single sweep type not set: {msg}')

cirq-google/cirq_google/api/v2/sweeps_test.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,58 @@ def test_sweep_to_proto_points():
153153
assert list(proto.single_sweep.points.points) == [-1, 0, 1, 1.5]
154154

155155

156+
def test_sweep_to_proto_with_simple_func_succeeds():
157+
def func(sweep: sweeps.SingleSweep):
158+
if isinstance(sweep, cirq.Points):
159+
sweep.points = [point + 3 for point in sweep.points]
160+
161+
return sweep
162+
163+
sweep = cirq.Points('foo', [1, 2, 3])
164+
proto = v2.sweep_to_proto(sweep, sweep_transformer=func)
165+
166+
assert list(proto.single_sweep.points.points) == [4.0, 5.0, 6.0]
167+
168+
169+
def test_sweep_to_proto_with_func_linspace():
170+
def func(sweep: sweeps.SingleSweep):
171+
return cirq.Linspace('foo', 3 * tunits.ns, 6 * tunits.ns, 3) # type: ignore[arg-type]
172+
173+
sweep = cirq.Linspace('foo', start=1, stop=3, length=3)
174+
proto = v2.sweep_to_proto(sweep, sweep_transformer=func)
175+
176+
assert proto.single_sweep.linspace.first_point == 3.0
177+
assert proto.single_sweep.linspace.last_point == 6.0
178+
assert tunits.Value.from_proto(proto.single_sweep.linspace.unit) == tunits.ns
179+
180+
181+
def test_sweep_to_proto_with_func_const_value():
182+
def func(sweep: sweeps.SingleSweep):
183+
if isinstance(sweep, cirq.Points):
184+
sweep.points = [point + 3 for point in sweep.points]
185+
186+
return sweep
187+
188+
sweep = cirq.Points('foo', points=[1])
189+
proto = v2.sweep_to_proto(sweep, sweep_transformer=func)
190+
191+
assert proto.single_sweep.const_value.int_value == 4
192+
193+
194+
@pytest.mark.parametrize('sweep', [(cirq.Points('foo', [1, 2, 3])), (cirq.Points('foo', [1]))])
195+
def test_sweep_to_proto_with_func_round_trip(sweep):
196+
def add_tunit_func(sweep: sweeps.SingleSweep):
197+
if isinstance(sweep, cirq.Points):
198+
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]
199+
200+
return sweep
201+
202+
proto = v2.sweep_to_proto(sweep, sweep_transformer=add_tunit_func)
203+
recovered = v2.sweep_from_proto(proto)
204+
205+
assert list(recovered.points)[0] == 1 * tunits.ns
206+
207+
156208
def test_sweep_to_proto_unit():
157209
proto = v2.sweep_to_proto(cirq.UnitSweep)
158210
assert isinstance(proto, v2.run_context_pb2.Sweep)
@@ -188,6 +240,41 @@ def test_sweep_from_proto_single_sweep_type_not_set():
188240
v2.sweep_from_proto(proto)
189241

190242

243+
@pytest.mark.parametrize('sweep', [cirq.Points('foo', [1, 2, 3]), cirq.Points('foo', [1])])
244+
def test_sweep_from_proto_with_func_succeeds(sweep):
245+
def add_tunit_func(sweep: sweeps.SingleSweep):
246+
if isinstance(sweep, cirq.Points):
247+
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]
248+
249+
return sweep
250+
251+
msg = v2.sweep_to_proto(sweep)
252+
sweep = v2.sweep_from_proto(msg, sweep_transformer=add_tunit_func)
253+
254+
assert list(sweep.points)[0] == [1.0 * tunits.ns]
255+
256+
257+
@pytest.mark.parametrize('sweep', [cirq.Points('foo', [1, 2, 3]), cirq.Points('foo', [1])])
258+
def test_sweep_from_proto_with_func_round_trip(sweep):
259+
def add_tunit_func(sweep: sweeps.SingleSweep):
260+
if isinstance(sweep, cirq.Points):
261+
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]
262+
263+
return sweep
264+
265+
def strip_tunit_func(sweep: sweeps.SingleSweep):
266+
if isinstance(sweep, cirq.Points):
267+
if isinstance(sweep.points[0], tunits.Value):
268+
sweep.points = [point[point.unit] for point in sweep.points]
269+
270+
return sweep
271+
272+
msg = v2.sweep_to_proto(sweep, sweep_transformer=add_tunit_func)
273+
sweep = v2.sweep_from_proto(msg, sweep_transformer=strip_tunit_func)
274+
275+
assert list(sweep.points)[0] == 1.0
276+
277+
191278
def test_sweep_with_list_sweep():
192279
ls = cirq.study.to_sweep([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
193280
proto = v2.sweep_to_proto(ls)

0 commit comments

Comments
 (0)