Skip to content

Commit 3a835dd

Browse files
authored
Add a new sweep metadata from/to proto approach (#6925)
Continue #6869 This PR added a new path to encode and decode metadata to this corresponding proto. Notice this is a side path if the metadata in sweep is not encoded as `Metadata` class, the code path is still old and same for `from_proto` case. If the proto does not contain `metadata` proto field, the code still convert back as the old style
1 parent fe89b32 commit 3a835dd

File tree

11 files changed

+201
-45
lines changed

11 files changed

+201
-45
lines changed

cirq-google/cirq_google/api/v2/run_context.proto

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ message Metadata {
136136
optional string label = 2;
137137

138138
// If true, store this sweep as parameters instead of the independent axes.
139-
optional bool as_parameter = 3;
139+
optional bool is_const = 3;
140+
141+
// A temporary solution that we store the unit information here.
142+
optional string unit = 4;
140143
}
141144

142145
// A bundle of multiple DeviceParameters and their values.

cirq-google/cirq_google/api/v2/run_context_pb2.py

Lines changed: 18 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cirq-google/cirq_google/api/v2/run_context_pb2.pyi

Lines changed: 12 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cirq-google/cirq_google/api/v2/sweeps.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import cirq
2121
from cirq.study import sweeps
2222
from cirq_google.api.v2 import run_context_pb2
23-
from cirq_google.study.device_parameter import DeviceParameter
23+
from cirq_google.study.device_parameter import DeviceParameter, Metadata
2424

2525

2626
def _build_sweep_const(value: Any) -> run_context_pb2.ConstValue:
@@ -116,13 +116,17 @@ def sweep_to_proto(
116116
out.single_sweep.linspace.first_point = sweep.start
117117
out.single_sweep.linspace.last_point = sweep.stop
118118
out.single_sweep.linspace.num_points = sweep.length
119-
# Use duck-typing to support google-internal Parameter objects
120-
if sweep.metadata and getattr(sweep.metadata, 'path', None):
121-
out.single_sweep.parameter.path.extend(sweep.metadata.path)
122-
if sweep.metadata and getattr(sweep.metadata, 'idx', None):
123-
out.single_sweep.parameter.idx = sweep.metadata.idx
124-
if sweep.metadata and getattr(sweep.metadata, 'units', None):
125-
out.single_sweep.parameter.units = sweep.metadata.units
119+
# Encode the metadata if present
120+
if isinstance(sweep.metadata, Metadata):
121+
out.single_sweep.metadata.MergeFrom(metadata_to_proto(sweep.metadata))
122+
else:
123+
# Use duck-typing to support google-internal Parameter objects
124+
if sweep.metadata and getattr(sweep.metadata, 'path', None):
125+
out.single_sweep.parameter.path.extend(sweep.metadata.path)
126+
if sweep.metadata and getattr(sweep.metadata, 'idx', None):
127+
out.single_sweep.parameter.idx = sweep.metadata.idx
128+
if sweep.metadata and getattr(sweep.metadata, 'units', None):
129+
out.single_sweep.parameter.units = sweep.metadata.units
126130
elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr):
127131
sweep = cast(cirq.Points, sweep_transformer(sweep))
128132
out.single_sweep.parameter_key = sweep.key
@@ -135,13 +139,17 @@ def sweep_to_proto(
135139
unit.to_proto(out.single_sweep.points.unit)
136140
else:
137141
out.single_sweep.points.points.extend(sweep.points)
138-
# Use duck-typing to support google-internal Parameter objects
139-
if sweep.metadata and getattr(sweep.metadata, 'path', None):
140-
out.single_sweep.parameter.path.extend(sweep.metadata.path)
141-
if sweep.metadata and getattr(sweep.metadata, 'idx', None):
142-
out.single_sweep.parameter.idx = sweep.metadata.idx
143-
if sweep.metadata and getattr(sweep.metadata, 'units', None):
144-
out.single_sweep.parameter.units = sweep.metadata.units
142+
# Encode the metadata if present
143+
if isinstance(sweep.metadata, Metadata):
144+
out.single_sweep.metadata.MergeFrom(metadata_to_proto(sweep.metadata))
145+
else:
146+
# Use duck-typing to support google-internal Parameter objects
147+
if sweep.metadata and getattr(sweep.metadata, 'path', None):
148+
out.single_sweep.parameter.path.extend(sweep.metadata.path)
149+
if sweep.metadata and getattr(sweep.metadata, 'idx', None):
150+
out.single_sweep.parameter.idx = sweep.metadata.idx
151+
if sweep.metadata and getattr(sweep.metadata, 'units', None):
152+
out.single_sweep.parameter.units = sweep.metadata.units
145153
elif isinstance(sweep, cirq.ListSweep):
146154
sweep_dict: Dict[str, List[float]] = {}
147155
for param_resolver in sweep:
@@ -190,6 +198,7 @@ def sweep_from_proto(
190198
raise ValueError(f'invalid sweep function type: {func_type}')
191199
if which == 'single_sweep':
192200
key = msg.single_sweep.parameter_key
201+
metadata: DeviceParameter | Metadata | None
193202
if msg.single_sweep.HasField("parameter"):
194203
metadata = DeviceParameter(
195204
path=msg.single_sweep.parameter.path,
@@ -204,6 +213,8 @@ def sweep_from_proto(
204213
else None
205214
),
206215
)
216+
elif msg.single_sweep.HasField("metadata"):
217+
metadata = metadata_from_proto(msg.single_sweep.metadata)
207218
else:
208219
metadata = None
209220

@@ -245,6 +256,38 @@ def sweep_from_proto(
245256
raise ValueError(f'sweep type not set: {msg}') # pragma: no cover
246257

247258

259+
def metadata_to_proto(metadata: Metadata) -> run_context_pb2.Metadata:
260+
"""Convert the metadata dataclass to the metadata proto."""
261+
device_parameters: list[run_context_pb2.DeviceParameter] = []
262+
if params := getattr(metadata, "device_parameters", None):
263+
for param in params:
264+
path = getattr(param, "path", None)
265+
idx = getattr(param, "idx", None)
266+
device_parameters.append(run_context_pb2.DeviceParameter(path=path, idx=idx))
267+
268+
return run_context_pb2.Metadata(
269+
device_parameters=device_parameters or None, # If empty set this field as None.
270+
label=metadata.label,
271+
is_const=metadata.is_const,
272+
unit=metadata.unit,
273+
)
274+
275+
276+
def metadata_from_proto(metadata_pb: run_context_pb2.Metadata) -> Metadata:
277+
"""Convert the metadata proto to the metadata dataclass."""
278+
device_parameters: list[DeviceParameter] = []
279+
for param in metadata_pb.device_parameters:
280+
device_parameters.append(
281+
DeviceParameter(path=param.path, idx=param.idx if param.HasField("idx") else None)
282+
)
283+
return Metadata(
284+
device_parameters=device_parameters or None,
285+
label=metadata_pb.label if metadata_pb.HasField("label") else None,
286+
is_const=metadata_pb.is_const,
287+
unit=metadata_pb.unit if metadata_pb.HasField("unit") else None,
288+
)
289+
290+
248291
def run_context_to_proto(
249292
sweepable: cirq.Sweepable, repetitions: int, *, out: Optional[run_context_pb2.RunContext] = None
250293
) -> run_context_pb2.RunContext:

cirq-google/cirq_google/api/v2/sweeps_test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import cirq
2424
from cirq.study import sweeps
25-
from cirq_google.study import DeviceParameter
25+
from cirq_google.study import DeviceParameter, Metadata
2626
from cirq_google.api import v2
2727

2828

@@ -55,6 +55,27 @@ def _values(self) -> Iterator[float]:
5555
[1, 1.5, 2, 2.5, 3],
5656
metadata=DeviceParameter(path=['path', 'to', 'parameter'], idx=2, units='GHz'),
5757
),
58+
cirq.Points(
59+
'a',
60+
[1, 1.5, 2, 2.5, 3],
61+
metadata=Metadata(
62+
device_parameters=[DeviceParameter(path=['path', 'to', 'parameter'], idx=2)],
63+
label="bb",
64+
),
65+
),
66+
cirq.Points(
67+
'a',
68+
[1],
69+
metadata=Metadata(
70+
device_parameters=[
71+
DeviceParameter(path=['path', 'to', 'parameter']),
72+
DeviceParameter(path=['path', 'to', 'parameter2']),
73+
],
74+
label="bb",
75+
is_const=True,
76+
),
77+
),
78+
cirq.Linspace('a', 0, 10, 100, metadata=Metadata(is_const=True)),
5879
cirq.Points(
5980
'b',
6081
[1, 1.5, 2, 2.5, 3],

cirq-google/cirq_google/json_resolver_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,6 @@ def _old_xmon(*args, **kwargs):
7878
'cirq.google.GridDevice': cirq_google.GridDevice,
7979
'cirq.google.GoogleCZTargetGateset': cirq_google.GoogleCZTargetGateset,
8080
'cirq.google.DeviceParameter': cirq_google.study.device_parameter.DeviceParameter,
81+
'cirq.google.Metadata': cirq_google.study.device_parameter.Metadata,
8182
'InternalGate': cirq_google.InternalGate,
8283
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"cirq_type":"cirq.google.Metadata",
3+
"device_parameters":[
4+
{
5+
"cirq_type":"cirq.google.DeviceParameter",
6+
"path":[
7+
"test",
8+
"subdir"
9+
],
10+
"idx":null,
11+
"value":null,
12+
"units":null
13+
}
14+
],
15+
"is_const":true,
16+
"label":"fake_label",
17+
"unit":null
18+
}

0 commit comments

Comments
 (0)