Skip to content

Commit e61e3be

Browse files
authored
Add zip_longest support to the cirq_google sweep proto (#6815)
1 parent 495d913 commit e61e3be

File tree

5 files changed

+49
-21
lines changed

5 files changed

+49
-21
lines changed

cirq-google/cirq_google/api/v2/run_context.proto

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ message SweepFunction {
8585
// "a": 1.0, "b": 3.0
8686
// Note: if one sweep is shorter, the others will be truncated.
8787
ZIP = 2;
88+
89+
// A zip product of parameter sweeps with length as the longest one.
90+
//
91+
// Suppose we zip_longest([sweep.points(a, [1, 2]), sweep.points(b, [3])]),
92+
// the iterator will produce: {a: 1, b: 3} and {a: 2, b: 3}.
93+
// The shorter sweeps will be filled by repeating their last value.
94+
ZIP_LONGEST = 3;
8895
}
8996

9097
FunctionType function_type = 1;

cirq-google/cirq_google/api/v2/run_context_pb2.py

Lines changed: 21 additions & 21 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cirq-google/cirq_google/api/v2/run_context_pb2.pyi

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cirq-google/cirq_google/api/v2/sweeps.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def sweep_to_proto(
7373
out.sweep_function.function_type = run_context_pb2.SweepFunction.PRODUCT
7474
for factor in sweep.factors:
7575
sweep_to_proto(factor, out=out.sweep_function.sweeps.add())
76+
elif isinstance(sweep, cirq.ZipLongest):
77+
out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP_LONGEST
78+
for s in sweep.sweeps:
79+
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
7680
elif isinstance(sweep, cirq.Zip):
7781
out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
7882
for s in sweep.sweeps:
@@ -129,6 +133,8 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
129133
return cirq.Product(*factors)
130134
if func_type == run_context_pb2.SweepFunction.ZIP:
131135
return cirq.Zip(*factors)
136+
if func_type == run_context_pb2.SweepFunction.ZIP_LONGEST:
137+
return cirq.ZipLongest(*factors)
132138

133139
raise ValueError(f'invalid sweep function type: {func_type}')
134140
if which == 'single_sweep':

cirq-google/cirq_google/api/v2/sweeps_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def _values(self) -> Iterator[float]:
6868
+ (cirq.Points('g', [1, 2]) * cirq.Points('h', [-1, 0, 1]))
6969
)
7070
),
71+
cirq.ZipLongest(cirq.Points('a', [1.0, 2.0, 3.0]), cirq.Points('b', [1.0])),
7172
# Sweep with constant. Type ignore is because cirq.Points type annotated with floats.
7273
cirq.Points('a', [None]), # type: ignore[list-item]
7374
cirq.Points('a', [None]) * cirq.Points('b', [1, 2, 3]), # type: ignore[list-item]

0 commit comments

Comments
 (0)