Skip to content

Allow passing a callable to de/serialization funcs #6855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 56 additions & 7 deletions cirq-google/cirq_google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, cast, Dict, List, Optional
from typing import Any, cast, Dict, List, Optional, Callable

import copy
import sympy
import tunits

Expand Down Expand Up @@ -55,14 +56,18 @@ def _recover_sweep_const(const_pb: run_context_pb2.ConstValue) -> Any:


def sweep_to_proto(
sweep: cirq.Sweep, *, out: Optional[run_context_pb2.Sweep] = None
sweep: cirq.Sweep,
*,
out: Optional[run_context_pb2.Sweep] = None,
func: Callable[..., cirq.Sweep] | None = None,
) -> run_context_pb2.Sweep:
"""Converts a Sweep to v2 protobuf message.

Args:
sweep: The sweep to convert.
out: Optional message to be populated. If not given, a new message will
be created.
func: A function called on Linspace, Points.

Returns:
Populated sweep protobuf message.
Expand Down Expand Up @@ -91,6 +96,16 @@ def sweep_to_proto(
for s in sweep.sweeps:
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr):
if func:
try:
copied_linspace: cirq.Sweep = func(copy.deepcopy(sweep))
sweep = cast(cirq.Linspace, copied_linspace)
except Exception as e:
print(
f"The function {func} was not applied to {sweep}."
f" because there was an exception thrown: {str(e)}."
)

out.single_sweep.parameter_key = sweep.key
if isinstance(sweep.start, tunits.Value):
unit = sweep.start.unit
Expand All @@ -110,6 +125,15 @@ def sweep_to_proto(
if sweep.metadata and getattr(sweep.metadata, 'units', None):
out.single_sweep.parameter.units = sweep.metadata.units
elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr):
if func:
try:
copied_points: cirq.Sweep = func(copy.deepcopy(sweep))
sweep = cast(cirq.Points, copied_points)
except Exception as e:
print(
f"The function {func} was not applied to {sweep}."
f" because there was an exception thrown: {str(e)}."
)
out.single_sweep.parameter_key = sweep.key
if len(sweep.points) == 1:
out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0]))
Expand Down Expand Up @@ -142,8 +166,16 @@ def sweep_to_proto(
return out


def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
"""Creates a Sweep from a v2 protobuf message."""
def sweep_from_proto(
msg: run_context_pb2.Sweep, func: Callable[..., cirq.Sweep] | None = None
) -> cirq.Sweep:
"""Creates a Sweep from a v2 protobuf message.

Args:
msg: Serialized sweep message.
func: A function called on Linspace, Point, and ConstValue.

"""
which = msg.WhichOneof('sweep')
if which is None:
return cirq.UnitSweep
Expand Down Expand Up @@ -178,11 +210,13 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
)
else:
metadata = None

sweep: cirq.Sweep | None = None
if msg.single_sweep.WhichOneof('sweep') == 'linspace':
unit: float | tunits.Value = 1.0
if msg.single_sweep.linspace.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.linspace.unit)
return cirq.Linspace(
sweep = cirq.Linspace(
key=key,
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
Expand All @@ -193,17 +227,32 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
unit = 1.0
if msg.single_sweep.points.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.points.unit)
return cirq.Points(
sweep = cirq.Points(
key=key,
points=[p * unit for p in msg.single_sweep.points.points],
metadata=metadata,
)
if msg.single_sweep.WhichOneof('sweep') == 'const_value':
return cirq.Points(
sweep = cirq.Points(
key=key,
points=[_recover_sweep_const(msg.single_sweep.const_value)],
metadata=metadata,
)
# Allow for a function to modify a copy of the the sweep. If there are
# no exceptions cirq.Point is modified.
try:
if func and sweep:
copied_sweep: cirq.Sweep = func(copy.deepcopy(sweep))
return copied_sweep
except Exception as e:
print(
f"The function {func} was not applied to {sweep}."
f" because there was an exception thrown: {str(e)}."
)

print(sweep)
if sweep:
return sweep

raise ValueError(f'single sweep type not set: {msg}')

Expand Down
120 changes: 120 additions & 0 deletions cirq-google/cirq_google/api/v2/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,85 @@ def test_sweep_to_proto_points():
assert list(proto.single_sweep.points.points) == [-1, 0, 1, 1.5]


def test_sweep_to_proto_with_simple_func_succeeds():
def func(sweep: cirq.Sweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point + 3 for point in sweep.points]

return sweep

sweep = cirq.Points('foo', [1, 2, 3])
proto = v2.sweep_to_proto(sweep, func=func)

assert list(proto.single_sweep.points.points) == [4.0, 5.0, 6.0]


def test_sweep_to_proto_with_func_linspace():
def func(sweep: cirq.Sweep):
return cirq.Linspace('foo', 3 * tunits.ns, 6 * tunits.ns, 3) # type: ignore[arg-type]

sweep = cirq.Linspace('foo', start=1, stop=3, length=3)
proto = v2.sweep_to_proto(sweep, func=func)

assert proto.single_sweep.linspace.first_point == 3.0
assert proto.single_sweep.linspace.last_point == 6.0
assert tunits.Value.from_proto(proto.single_sweep.linspace.unit) == tunits.ns


def test_sweep_to_proto_with_func_const_value():
def func(sweep: cirq.Sweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point + 3 for point in sweep.points]

return sweep

sweep = cirq.Points('foo', points=[1])
proto = v2.sweep_to_proto(sweep, func=func)

assert proto.single_sweep.const_value.int_value == 4


@pytest.mark.parametrize('sweep', [(cirq.Points('foo', [1, 2, 3])), (cirq.Points('foo', [1]))])
def test_sweep_to_proto_with_func_round_trip(sweep):
def add_tunit_func(sweep: cirq.Sweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]

return sweep

proto = v2.sweep_to_proto(sweep, func=add_tunit_func)
recovered = v2.sweep_from_proto(proto)

assert list(recovered.points)[0] == 1 * tunits.ns


@pytest.mark.parametrize(
('sweep', 'expected_points'),
[(cirq.Points('foo', [1, 2, 3]), [1, 2, 3]), (cirq.Points('foo', [1]), [1])],
)
def test_sweep_to_proto_points_with_invalid_func_round_trip(sweep, expected_points):
def raise_error_func(sweep: cirq.Sweep):
raise ValueError("err")

proto = v2.sweep_to_proto(sweep, func=raise_error_func)
recovered = v2.sweep_from_proto(proto)

assert list(recovered.points) == expected_points


def test_sweep_to_proto_linspace_with_invalid_func_round_trip():
def raise_error_func(sweep: cirq.Sweep):
raise ValueError("err")

sweep = cirq.Linspace('foo', start=0, stop=3, length=3)
proto = v2.sweep_to_proto(sweep, func=raise_error_func)
recovered = v2.sweep_from_proto(proto)

assert recovered.start == 0.0
assert recovered.stop == 3.0
assert recovered.length == 3.0


def test_sweep_to_proto_unit():
proto = v2.sweep_to_proto(cirq.UnitSweep)
assert isinstance(proto, v2.run_context_pb2.Sweep)
Expand Down Expand Up @@ -188,6 +267,47 @@ def test_sweep_from_proto_single_sweep_type_not_set():
v2.sweep_from_proto(proto)


@pytest.mark.parametrize('sweep', [cirq.Points('foo', [1, 2, 3]), cirq.Points('foo', [1])])
def test_sweep_from_proto_with_func_succeeds(sweep):
def add_tunit_func(sweep: cirq.Sweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]

return sweep

msg = v2.sweep_to_proto(sweep)
sweep = v2.sweep_from_proto(msg, func=add_tunit_func)

assert list(sweep.points)[0] == [1.0 * tunits.ns]


@pytest.mark.parametrize(
('sweep', 'expected_points'),
[(cirq.Points('foo', [1, 2, 3]), [1, 2, 3]), (cirq.Points('foo', [1]), [1])],
)
def test_sweep_from_proto_with_invalid_func_round_trip(sweep, expected_points):
def raise_error_func(sweep: cirq.Sweep):
raise ValueError("err")

proto = v2.sweep_to_proto(sweep)
recovered = v2.sweep_from_proto(proto, func=raise_error_func)

assert list(recovered.points) == expected_points


def test_sweep_from_proto_linspace_with_invalid_func_round_trip():
def raise_error_func(sweep: cirq.Sweep):
raise ValueError("err")

sweep = cirq.Linspace('foo', start=0, stop=3, length=3)
proto = v2.sweep_to_proto(sweep)
recovered = v2.sweep_from_proto(proto, func=raise_error_func)

assert recovered.start == 0.0
assert recovered.stop == 3.0
assert recovered.length == 3.0


def test_sweep_with_list_sweep():
ls = cirq.study.to_sweep([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
proto = v2.sweep_to_proto(ls)
Expand Down