Skip to content

Support serialization of sweeps with tunits #6829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cirq-google/cirq_google/api/v2/run_context.proto
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
syntax = "proto3";

import "cirq_google/api/v2/program.proto";
import "tunits/proto/tunits.proto";

package cirq.google.api.v2;

Expand Down Expand Up @@ -209,6 +210,8 @@ message SingleSweep {
message Points {
// The values.
repeated float points = 1;

tunits.Value unit = 2;
}

// A range of evenly-spaced values.
Expand All @@ -225,6 +228,8 @@ message Linspace {
// greater than zero. If it is 1, the first_point and last_point must be
// the same.
int64 num_points = 3;

tunits.Value unit = 4;
}

// A constant value.
Expand All @@ -236,5 +241,6 @@ message ConstValue {
float float_value = 2;
int64 int_value = 3;
string string_value = 4;
tunits.Value with_unit_value = 5;
}
}
59 changes: 30 additions & 29 deletions cirq-google/cirq_google/api/v2/run_context_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 20 additions & 5 deletions cirq-google/cirq_google/api/v2/run_context_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 34 additions & 7 deletions cirq-google/cirq_google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, cast, Dict, List, Optional

import sympy
import tunits

import cirq
from cirq_google.api.v2 import run_context_pb2
Expand All @@ -31,6 +32,8 @@ def _build_sweep_const(value: Any) -> run_context_pb2.ConstValue:
return run_context_pb2.ConstValue(int_value=value)
elif isinstance(value, str):
return run_context_pb2.ConstValue(string_value=value)
elif isinstance(value, tunits.Value):
return run_context_pb2.ConstValue(with_unit_value=value.to_proto())
else:
raise ValueError(
f"Unsupported type for serializing const sweep: {value=} and {type(value)=}"
Expand All @@ -47,6 +50,8 @@ def _recover_sweep_const(const_pb: run_context_pb2.ConstValue) -> Any:
return const_pb.int_value
if const_pb.WhichOneof('value') == 'string_value':
return const_pb.string_value
if const_pb.WhichOneof('value') == 'with_unit_value':
return tunits.Value.from_proto(const_pb.with_unit_value)


def sweep_to_proto(
Expand Down Expand Up @@ -87,9 +92,16 @@ def sweep_to_proto(
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr):
out.single_sweep.parameter_key = sweep.key
out.single_sweep.linspace.first_point = sweep.start
out.single_sweep.linspace.last_point = sweep.stop
out.single_sweep.linspace.num_points = sweep.length
if isinstance(sweep.start, tunits.Value):
unit = sweep.start.unit
out.single_sweep.linspace.first_point = sweep.start[unit]
out.single_sweep.linspace.last_point = sweep.stop[unit]
out.single_sweep.linspace.num_points = sweep.length
unit.to_proto(out.single_sweep.linspace.unit)
else:
out.single_sweep.linspace.first_point = sweep.start
out.single_sweep.linspace.last_point = sweep.stop
out.single_sweep.linspace.num_points = sweep.length
# Use duck-typing to support google-internal Parameter objects
if sweep.metadata and getattr(sweep.metadata, 'path', None):
out.single_sweep.parameter.path.extend(sweep.metadata.path)
Expand All @@ -102,7 +114,12 @@ def sweep_to_proto(
if len(sweep.points) == 1:
out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0]))
else:
out.single_sweep.points.points.extend(sweep.points)
if isinstance(sweep.points[0], tunits.Value):
unit = sweep.points[0].unit
out.single_sweep.points.points.extend(p[unit] for p in sweep.points)
unit.to_proto(out.single_sweep.points.unit)
else:
out.single_sweep.points.points.extend(sweep.points)
# Use duck-typing to support google-internal Parameter objects
if sweep.metadata and getattr(sweep.metadata, 'path', None):
out.single_sweep.parameter.path.extend(sweep.metadata.path)
Expand Down Expand Up @@ -162,15 +179,25 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
else:
metadata = None
if msg.single_sweep.WhichOneof('sweep') == 'linspace':
unit: float | tunits.Value = 1.0
if msg.single_sweep.linspace.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.linspace.unit)
return cirq.Linspace(
key=key,
start=msg.single_sweep.linspace.first_point,
stop=msg.single_sweep.linspace.last_point,
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
length=msg.single_sweep.linspace.num_points,
metadata=metadata,
)
if msg.single_sweep.WhichOneof('sweep') == 'points':
return cirq.Points(key=key, points=msg.single_sweep.points.points, metadata=metadata)
unit = 1.0
if msg.single_sweep.points.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.points.unit)
return cirq.Points(
key=key,
points=[p * unit for p in msg.single_sweep.points.points],
metadata=metadata,
)
if msg.single_sweep.WhichOneof('sweep') == 'const_value':
return cirq.Points(
key=key,
Expand Down
Loading