Skip to content

Commit 2c993f8

Browse files
authored
pass metadata to sweeps (#6644)
* pass metadata to sweeps * format * format the formatter * mypy
1 parent 72f0542 commit 2c993f8

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

cirq-core/cirq/study/sweepable.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Defines which types are Sweepable."""
1616

17-
from typing import Iterable, Iterator, List, Sequence, Union, cast
17+
from typing import Iterable, Iterator, List, Optional, Sequence, Union, cast
1818
import warnings
1919
from typing_extensions import Protocol
2020

@@ -44,12 +44,12 @@ def to_resolvers(sweepable: Sweepable) -> Iterator[ParamResolver]:
4444
yield from sweep
4545

4646

47-
def to_sweeps(sweepable: Sweepable) -> List[Sweep]:
47+
def to_sweeps(sweepable: Sweepable, metadata: Optional[dict] = None) -> List[Sweep]:
4848
"""Converts a Sweepable to a list of Sweeps."""
4949
if sweepable is None:
5050
return [UnitSweep]
5151
if isinstance(sweepable, ParamResolver):
52-
return [_resolver_to_sweep(sweepable)]
52+
return [_resolver_to_sweep(sweepable, metadata)]
5353
if isinstance(sweepable, Sweep):
5454
return [sweepable]
5555
if isinstance(sweepable, dict):
@@ -63,9 +63,13 @@ def to_sweeps(sweepable: Sweepable) -> List[Sweep]:
6363
stacklevel=2,
6464
)
6565
product_sweep = dict_to_product_sweep(sweepable)
66-
return [_resolver_to_sweep(resolver) for resolver in product_sweep]
66+
return [_resolver_to_sweep(resolver, metadata) for resolver in product_sweep]
6767
if isinstance(sweepable, Iterable) and not isinstance(sweepable, str):
68-
return [sweep for item in sweepable for sweep in to_sweeps(item)] # type: ignore[arg-type]
68+
return [
69+
sweep
70+
for item in sweepable
71+
for sweep in to_sweeps(item, metadata) # type: ignore[arg-type]
72+
]
6973
raise TypeError(f'Unrecognized sweepable type: {type(sweepable)}.\nsweepable: {sweepable}')
7074

7175

@@ -98,8 +102,13 @@ def to_sweep(
98102
raise TypeError(f'Unexpected sweep-like value: {sweep_or_resolver_list}')
99103

100104

101-
def _resolver_to_sweep(resolver: ParamResolver) -> Sweep:
105+
def _resolver_to_sweep(resolver: ParamResolver, metadata: Optional[dict]) -> Sweep:
102106
params = resolver.param_dict
103107
if not params:
104108
return UnitSweep
105-
return Zip(*[Points(key, [cast(float, value)]) for key, value in params.items()])
109+
return Zip(
110+
*[
111+
Points(key, [cast(float, value)], metadata=metadata.get(key) if metadata else None)
112+
for key, value in params.items()
113+
]
114+
)

cirq-core/cirq/study/sweepable_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,30 @@ def test_to_sweep_resolver_list(r_list_gen):
147147
def test_to_sweep_type_error():
148148
with pytest.raises(TypeError, match='Unexpected sweep'):
149149
cirq.to_sweep(5)
150+
151+
152+
def test_to_sweeps_with_param_dict_appends_metadata():
153+
params = {'a': 1, 'b': 2, 'c': 3}
154+
unit_map = {'a': 'ns', 'b': 'ns'}
155+
156+
sweep = cirq.to_sweeps(params, unit_map)
157+
158+
assert sweep == [
159+
cirq.Zip(
160+
cirq.Points('a', [1], metadata='ns'),
161+
cirq.Points('b', [2], metadata='ns'),
162+
cirq.Points('c', [3]),
163+
)
164+
]
165+
166+
167+
def test_to_sweeps_with_param_list_appends_metadata():
168+
resolvers = [cirq.ParamResolver({'a': 2}), cirq.ParamResolver({'a': 1})]
169+
unit_map = {'a': 'ns'}
170+
171+
sweeps = cirq.study.to_sweeps(resolvers, unit_map)
172+
173+
assert sweeps == [
174+
cirq.Zip(cirq.Points('a', [2], metadata='ns')),
175+
cirq.Zip(cirq.Points('a', [1], metadata='ns')),
176+
]

0 commit comments

Comments
 (0)