Skip to content

Commit be8b04b

Browse files
authored
Pass sweep_transformer recursively (#6951)
* fix bug * rm print
1 parent df776a0 commit be8b04b

File tree

2 files changed

+77
-6
lines changed

2 files changed

+77
-6
lines changed

cirq-google/cirq_google/api/v2/sweeps.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,19 +82,27 @@ def sweep_to_proto(
8282
elif isinstance(sweep, cirq.Product):
8383
out.sweep_function.function_type = run_context_pb2.SweepFunction.PRODUCT
8484
for factor in sweep.factors:
85-
sweep_to_proto(factor, out=out.sweep_function.sweeps.add())
85+
sweep_to_proto(
86+
factor, out=out.sweep_function.sweeps.add(), sweep_transformer=sweep_transformer
87+
)
8688
elif isinstance(sweep, cirq.ZipLongest):
8789
out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP_LONGEST
8890
for s in sweep.sweeps:
89-
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
91+
sweep_to_proto(
92+
s, out=out.sweep_function.sweeps.add(), sweep_transformer=sweep_transformer
93+
)
9094
elif isinstance(sweep, cirq.Zip):
9195
out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
9296
for s in sweep.sweeps:
93-
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
97+
sweep_to_proto(
98+
s, out=out.sweep_function.sweeps.add(), sweep_transformer=sweep_transformer
99+
)
94100
elif isinstance(sweep, cirq.Concat):
95101
out.sweep_function.function_type = run_context_pb2.SweepFunction.CONCAT
96102
for s in sweep.sweeps:
97-
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
103+
sweep_to_proto(
104+
s, out=out.sweep_function.sweeps.add(), sweep_transformer=sweep_transformer
105+
)
98106
elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr):
99107
sweep = cast(cirq.Linspace, sweep_transformer(sweep))
100108
out.single_sweep.parameter_key = sweep.key
@@ -143,7 +151,11 @@ def sweep_to_proto(
143151
sweep_dict[cast(str, key)].append(cast(float, param_resolver.value_of(key)))
144152
out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
145153
for key in sweep_dict:
146-
sweep_to_proto(cirq.Points(key, sweep_dict[key]), out=out.sweep_function.sweeps.add())
154+
sweep_to_proto(
155+
cirq.Points(key, sweep_dict[key]),
156+
out=out.sweep_function.sweeps.add(),
157+
sweep_transformer=sweep_transformer,
158+
)
147159
else:
148160
raise ValueError(f'cannot convert to v2 Sweep proto: {sweep}')
149161
return out
@@ -164,7 +176,7 @@ def sweep_from_proto(
164176
if which is None:
165177
return cirq.UnitSweep
166178
if which == 'sweep_function':
167-
factors = [sweep_from_proto(m) for m in msg.sweep_function.sweeps]
179+
factors = [sweep_from_proto(m, sweep_transformer) for m in msg.sweep_function.sweeps]
168180
func_type = msg.sweep_function.function_type
169181
if func_type == run_context_pb2.SweepFunction.PRODUCT:
170182
return cirq.Product(*factors)

cirq-google/cirq_google/api/v2/sweeps_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import math
15+
from copy import deepcopy
1516
from typing import Iterator
1617

1718
import pytest
@@ -275,6 +276,64 @@ def strip_tunit_func(sweep: sweeps.SingleSweep):
275276
assert list(sweep.points)[0] == 1.0
276277

277278

279+
@pytest.mark.parametrize(
280+
'sweep',
281+
[
282+
cirq.Concat(cirq.Points('a', [1, 2, 3]), cirq.Points('a', [4])),
283+
cirq.Points('a', [1, 2, 3]) * cirq.Points('b', [4, 5, 6]),
284+
cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [1])),
285+
cirq.Zip(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6])),
286+
],
287+
)
288+
def test_sweep_to_proto_with_func_on_resursive_sweep_succeeds(sweep):
289+
def add_tunit_func(sweep: sweeps.SingleSweep):
290+
if isinstance(sweep, cirq.Points):
291+
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]
292+
293+
return sweep
294+
295+
msg = v2.sweep_to_proto(sweep, sweep_transformer=add_tunit_func)
296+
297+
assert msg.sweep_function.sweeps[0].single_sweep.points.unit == tunits.ns.to_proto()
298+
299+
300+
@pytest.mark.parametrize(
301+
'expected_sweep',
302+
[
303+
cirq.Concat(cirq.Points('a', [1.0, 2.0, 3.0]), cirq.Points('a', [4.0])),
304+
cirq.Points('a', [1.0, 2.0, 3.0]) * cirq.Points('b', [4.0, 5.0, 6.0]),
305+
cirq.ZipLongest(cirq.Points('a', [1.0, 2.0, 3.0]), cirq.Points('b', [1.0])),
306+
cirq.Zip(cirq.Points('a', [1.0, 2.0, 3.0]), cirq.Points('b', [4.0, 5.0, 6.0])),
307+
cirq.Points('a', [1, 2, 3])
308+
+ cirq.Points(
309+
'b',
310+
[4, 5, 6],
311+
metadata=DeviceParameter(path=['path', 'to', 'parameter'], idx=2, units='GHz'),
312+
),
313+
],
314+
)
315+
def test_sweep_from_proto_with_func_on_resursive_sweep_succeeds(expected_sweep):
316+
def add_tunit_func(sweep_to_transform: sweeps.SingleSweep):
317+
sweep = deepcopy(sweep_to_transform)
318+
if isinstance(sweep, cirq.Points):
319+
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]
320+
321+
return sweep
322+
323+
def strip_tunit_func(sweep_to_transform: sweeps.SingleSweep):
324+
sweep = deepcopy(sweep_to_transform)
325+
if isinstance(sweep, cirq.Points):
326+
if isinstance(sweep.points[0], tunits.Value):
327+
sweep.points = [point[point.unit] for point in sweep.points]
328+
329+
return sweep
330+
331+
msg = v2.sweep_to_proto(expected_sweep, sweep_transformer=add_tunit_func)
332+
round_trip_sweep = v2.sweep_from_proto(msg, strip_tunit_func)
333+
334+
assert round_trip_sweep == expected_sweep
335+
336+
278337
def test_sweep_with_list_sweep():
279338
ls = cirq.study.to_sweep([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
280339
proto = v2.sweep_to_proto(ls)

0 commit comments

Comments
 (0)