diff --git a/cirq-google/cirq_google/api/v2/run_context.proto b/cirq-google/cirq_google/api/v2/run_context.proto index a66a744a2c8..c826a0ad5ea 100644 --- a/cirq-google/cirq_google/api/v2/run_context.proto +++ b/cirq-google/cirq_google/api/v2/run_context.proto @@ -85,6 +85,13 @@ message SweepFunction { // "a": 1.0, "b": 3.0 // Note: if one sweep is shorter, the others will be truncated. ZIP = 2; + + // A zip product of parameter sweeps with length as the longest one. + // + // Suppose we zip_longest([sweep.points(a, [1, 2]), sweep.points(b, [3])]), + // the iterator will produce: {a: 1, b: 3} and {a: 2, b: 3}. + // The shorter sweeps will be filled by repeating their last value. + ZIP_LONGEST = 3; } FunctionType function_type = 1; diff --git a/cirq-google/cirq_google/api/v2/run_context_pb2.py b/cirq-google/cirq_google/api/v2/run_context_pb2.py index 5bc2c4ac6a9..145bb0aee9c 100644 --- a/cirq-google/cirq_google/api/v2/run_context_pb2.py +++ b/cirq-google/cirq_google/api/v2/run_context_pb2.py @@ -14,7 +14,7 @@ from . import program_pb2 as cirq__google_dot_api_dot_v2_dot_program__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$cirq_google/api/v2/run_context.proto\x12\x12\x63irq.google.api.v2\x1a cirq_google/api/v2/program.proto\"\x98\x01\n\nRunContext\x12<\n\x10parameter_sweeps\x18\x01 \x03(\x0b\x32\".cirq.google.api.v2.ParameterSweep\x12L\n\x1a\x64\x65vice_parameters_override\x18\x02 \x01(\x0b\x32(.cirq.google.api.v2.DeviceParametersDiff\"O\n\x0eParameterSweep\x12\x13\n\x0brepetitions\x18\x01 \x01(\x05\x12(\n\x05sweep\x18\x02 \x01(\x0b\x32\x19.cirq.google.api.v2.Sweep\"\x86\x01\n\x05Sweep\x12;\n\x0esweep_function\x18\x01 \x01(\x0b\x32!.cirq.google.api.v2.SweepFunctionH\x00\x12\x37\n\x0csingle_sweep\x18\x02 \x01(\x0b\x32\x1f.cirq.google.api.v2.SingleSweepH\x00\x42\x07\n\x05sweep\"\xc6\x01\n\rSweepFunction\x12\x45\n\rfunction_type\x18\x01 \x01(\x0e\x32..cirq.google.api.v2.SweepFunction.FunctionType\x12)\n\x06sweeps\x18\x02 \x03(\x0b\x32\x19.cirq.google.api.v2.Sweep\"C\n\x0c\x46unctionType\x12\x1d\n\x19\x46UNCTION_TYPE_UNSPECIFIED\x10\x00\x12\x0b\n\x07PRODUCT\x10\x01\x12\x07\n\x03ZIP\x10\x02\"W\n\x0f\x44\x65viceParameter\x12\x0c\n\x04path\x18\x01 \x03(\t\x12\x10\n\x03idx\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x12\n\x05units\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x06\n\x04_idxB\x08\n\x06_units\"\xcf\x03\n\x14\x44\x65viceParametersDiff\x12\x46\n\x06groups\x18\x01 \x03(\x0b\x32\x36.cirq.google.api.v2.DeviceParametersDiff.ResourceGroup\x12>\n\x06params\x18\x02 \x03(\x0b\x32..cirq.google.api.v2.DeviceParametersDiff.Param\x12\x0c\n\x04strs\x18\x04 \x03(\t\x1a-\n\rResourceGroup\x12\x0e\n\x06parent\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\x05\x1a\x36\n\x0cGenericValue\x12\x17\n\x0ftype_descriptor\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x1a\xb9\x01\n\x05Param\x12\x16\n\x0eresource_group\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\x05\x12-\n\x05value\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.ArgValueH\x00\x12N\n\rgeneric_value\x18\x04 \x01(\x0b\x32\x35.cirq.google.api.v2.DeviceParametersDiff.GenericValueH\x00\x42\x0b\n\tparam_val\"\xfc\x01\n\x0bSingleSweep\x12\x15\n\rparameter_key\x18\x01 \x01(\t\x12,\n\x06points\x18\x02 \x01(\x0b\x32\x1a.cirq.google.api.v2.PointsH\x00\x12\x30\n\x08linspace\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.LinspaceH\x00\x12\x35\n\x0b\x63onst_value\x18\x05 \x01(\x0b\x32\x1e.cirq.google.api.v2.ConstValueH\x00\x12\x36\n\tparameter\x18\x04 \x01(\x0b\x32#.cirq.google.api.v2.DeviceParameterB\x07\n\x05sweep\"\x18\n\x06Points\x12\x0e\n\x06points\x18\x01 \x03(\x02\"G\n\x08Linspace\x12\x13\n\x0b\x66irst_point\x18\x01 \x01(\x02\x12\x12\n\nlast_point\x18\x02 \x01(\x02\x12\x12\n\nnum_points\x18\x03 \x01(\x03\"l\n\nConstValue\x12\x11\n\x07is_none\x18\x01 \x01(\x08H\x00\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x03H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x42\x07\n\x05valueB2\n\x1d\x63om.google.cirq.google.api.v2B\x0fRunContextProtoP\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$cirq_google/api/v2/run_context.proto\x12\x12\x63irq.google.api.v2\x1a cirq_google/api/v2/program.proto\"\x98\x01\n\nRunContext\x12<\n\x10parameter_sweeps\x18\x01 \x03(\x0b\x32\".cirq.google.api.v2.ParameterSweep\x12L\n\x1a\x64\x65vice_parameters_override\x18\x02 \x01(\x0b\x32(.cirq.google.api.v2.DeviceParametersDiff\"O\n\x0eParameterSweep\x12\x13\n\x0brepetitions\x18\x01 \x01(\x05\x12(\n\x05sweep\x18\x02 \x01(\x0b\x32\x19.cirq.google.api.v2.Sweep\"\x86\x01\n\x05Sweep\x12;\n\x0esweep_function\x18\x01 \x01(\x0b\x32!.cirq.google.api.v2.SweepFunctionH\x00\x12\x37\n\x0csingle_sweep\x18\x02 \x01(\x0b\x32\x1f.cirq.google.api.v2.SingleSweepH\x00\x42\x07\n\x05sweep\"\xd7\x01\n\rSweepFunction\x12\x45\n\rfunction_type\x18\x01 \x01(\x0e\x32..cirq.google.api.v2.SweepFunction.FunctionType\x12)\n\x06sweeps\x18\x02 \x03(\x0b\x32\x19.cirq.google.api.v2.Sweep\"T\n\x0c\x46unctionType\x12\x1d\n\x19\x46UNCTION_TYPE_UNSPECIFIED\x10\x00\x12\x0b\n\x07PRODUCT\x10\x01\x12\x07\n\x03ZIP\x10\x02\x12\x0f\n\x0bZIP_LONGEST\x10\x03\"W\n\x0f\x44\x65viceParameter\x12\x0c\n\x04path\x18\x01 \x03(\t\x12\x10\n\x03idx\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x12\n\x05units\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x06\n\x04_idxB\x08\n\x06_units\"\xcf\x03\n\x14\x44\x65viceParametersDiff\x12\x46\n\x06groups\x18\x01 \x03(\x0b\x32\x36.cirq.google.api.v2.DeviceParametersDiff.ResourceGroup\x12>\n\x06params\x18\x02 \x03(\x0b\x32..cirq.google.api.v2.DeviceParametersDiff.Param\x12\x0c\n\x04strs\x18\x04 \x03(\t\x1a-\n\rResourceGroup\x12\x0e\n\x06parent\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\x05\x1a\x36\n\x0cGenericValue\x12\x17\n\x0ftype_descriptor\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x1a\xb9\x01\n\x05Param\x12\x16\n\x0eresource_group\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\x05\x12-\n\x05value\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.ArgValueH\x00\x12N\n\rgeneric_value\x18\x04 \x01(\x0b\x32\x35.cirq.google.api.v2.DeviceParametersDiff.GenericValueH\x00\x42\x0b\n\tparam_val\"\xfc\x01\n\x0bSingleSweep\x12\x15\n\rparameter_key\x18\x01 \x01(\t\x12,\n\x06points\x18\x02 \x01(\x0b\x32\x1a.cirq.google.api.v2.PointsH\x00\x12\x30\n\x08linspace\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.LinspaceH\x00\x12\x35\n\x0b\x63onst_value\x18\x05 \x01(\x0b\x32\x1e.cirq.google.api.v2.ConstValueH\x00\x12\x36\n\tparameter\x18\x04 \x01(\x0b\x32#.cirq.google.api.v2.DeviceParameterB\x07\n\x05sweep\"\x18\n\x06Points\x12\x0e\n\x06points\x18\x01 \x03(\x02\"G\n\x08Linspace\x12\x13\n\x0b\x66irst_point\x18\x01 \x01(\x02\x12\x12\n\nlast_point\x18\x02 \x01(\x02\x12\x12\n\nnum_points\x18\x03 \x01(\x03\"l\n\nConstValue\x12\x11\n\x07is_none\x18\x01 \x01(\x08H\x00\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x03H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x42\x07\n\x05valueB2\n\x1d\x63om.google.cirq.google.api.v2B\x0fRunContextProtoP\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -29,25 +29,25 @@ _globals['_SWEEP']._serialized_start=331 _globals['_SWEEP']._serialized_end=465 _globals['_SWEEPFUNCTION']._serialized_start=468 - _globals['_SWEEPFUNCTION']._serialized_end=666 + _globals['_SWEEPFUNCTION']._serialized_end=683 _globals['_SWEEPFUNCTION_FUNCTIONTYPE']._serialized_start=599 - _globals['_SWEEPFUNCTION_FUNCTIONTYPE']._serialized_end=666 - _globals['_DEVICEPARAMETER']._serialized_start=668 - _globals['_DEVICEPARAMETER']._serialized_end=755 - _globals['_DEVICEPARAMETERSDIFF']._serialized_start=758 - _globals['_DEVICEPARAMETERSDIFF']._serialized_end=1221 - _globals['_DEVICEPARAMETERSDIFF_RESOURCEGROUP']._serialized_start=932 - _globals['_DEVICEPARAMETERSDIFF_RESOURCEGROUP']._serialized_end=977 - _globals['_DEVICEPARAMETERSDIFF_GENERICVALUE']._serialized_start=979 - _globals['_DEVICEPARAMETERSDIFF_GENERICVALUE']._serialized_end=1033 - _globals['_DEVICEPARAMETERSDIFF_PARAM']._serialized_start=1036 - _globals['_DEVICEPARAMETERSDIFF_PARAM']._serialized_end=1221 - _globals['_SINGLESWEEP']._serialized_start=1224 - _globals['_SINGLESWEEP']._serialized_end=1476 - _globals['_POINTS']._serialized_start=1478 - _globals['_POINTS']._serialized_end=1502 - _globals['_LINSPACE']._serialized_start=1504 - _globals['_LINSPACE']._serialized_end=1575 - _globals['_CONSTVALUE']._serialized_start=1577 - _globals['_CONSTVALUE']._serialized_end=1685 + _globals['_SWEEPFUNCTION_FUNCTIONTYPE']._serialized_end=683 + _globals['_DEVICEPARAMETER']._serialized_start=685 + _globals['_DEVICEPARAMETER']._serialized_end=772 + _globals['_DEVICEPARAMETERSDIFF']._serialized_start=775 + _globals['_DEVICEPARAMETERSDIFF']._serialized_end=1238 + _globals['_DEVICEPARAMETERSDIFF_RESOURCEGROUP']._serialized_start=949 + _globals['_DEVICEPARAMETERSDIFF_RESOURCEGROUP']._serialized_end=994 + _globals['_DEVICEPARAMETERSDIFF_GENERICVALUE']._serialized_start=996 + _globals['_DEVICEPARAMETERSDIFF_GENERICVALUE']._serialized_end=1050 + _globals['_DEVICEPARAMETERSDIFF_PARAM']._serialized_start=1053 + _globals['_DEVICEPARAMETERSDIFF_PARAM']._serialized_end=1238 + _globals['_SINGLESWEEP']._serialized_start=1241 + _globals['_SINGLESWEEP']._serialized_end=1493 + _globals['_POINTS']._serialized_start=1495 + _globals['_POINTS']._serialized_end=1519 + _globals['_LINSPACE']._serialized_start=1521 + _globals['_LINSPACE']._serialized_end=1592 + _globals['_CONSTVALUE']._serialized_start=1594 + _globals['_CONSTVALUE']._serialized_end=1702 # @@protoc_insertion_point(module_scope) diff --git a/cirq-google/cirq_google/api/v2/run_context_pb2.pyi b/cirq-google/cirq_google/api/v2/run_context_pb2.pyi index 896804c19d1..99618d75629 100644 --- a/cirq-google/cirq_google/api/v2/run_context_pb2.pyi +++ b/cirq-google/cirq_google/api/v2/run_context_pb2.pyi @@ -156,6 +156,13 @@ class SweepFunction(google.protobuf.message.Message): "a": 1.0, "b": 3.0 Note: if one sweep is shorter, the others will be truncated. """ + ZIP_LONGEST: SweepFunction._FunctionType.ValueType # 3 + """A zip product of parameter sweeps with length as the longest one. + + Suppose we zip_longest([sweep.points(a, [1, 2]), sweep.points(b, [3])]), + the iterator will produce: {a: 1, b: 3} and {a: 2, b: 3}. + The shorter sweeps will be filled by repeating their last value. + """ class FunctionType(_FunctionType, metaclass=_FunctionTypeEnumTypeWrapper): """The type of sweep function.""" @@ -193,6 +200,13 @@ class SweepFunction(google.protobuf.message.Message): "a": 1.0, "b": 3.0 Note: if one sweep is shorter, the others will be truncated. """ + ZIP_LONGEST: SweepFunction.FunctionType.ValueType # 3 + """A zip product of parameter sweeps with length as the longest one. + + Suppose we zip_longest([sweep.points(a, [1, 2]), sweep.points(b, [3])]), + the iterator will produce: {a: 1, b: 3} and {a: 2, b: 3}. + The shorter sweeps will be filled by repeating their last value. + """ FUNCTION_TYPE_FIELD_NUMBER: builtins.int SWEEPS_FIELD_NUMBER: builtins.int diff --git a/cirq-google/cirq_google/api/v2/sweeps.py b/cirq-google/cirq_google/api/v2/sweeps.py index d3b71fe718a..d2a39b926d8 100644 --- a/cirq-google/cirq_google/api/v2/sweeps.py +++ b/cirq-google/cirq_google/api/v2/sweeps.py @@ -73,6 +73,10 @@ def sweep_to_proto( out.sweep_function.function_type = run_context_pb2.SweepFunction.PRODUCT for factor in sweep.factors: sweep_to_proto(factor, out=out.sweep_function.sweeps.add()) + elif isinstance(sweep, cirq.ZipLongest): + out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP_LONGEST + for s in sweep.sweeps: + sweep_to_proto(s, out=out.sweep_function.sweeps.add()) elif isinstance(sweep, cirq.Zip): out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP for s in sweep.sweeps: @@ -129,6 +133,8 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep: return cirq.Product(*factors) if func_type == run_context_pb2.SweepFunction.ZIP: return cirq.Zip(*factors) + if func_type == run_context_pb2.SweepFunction.ZIP_LONGEST: + return cirq.ZipLongest(*factors) raise ValueError(f'invalid sweep function type: {func_type}') if which == 'single_sweep': diff --git a/cirq-google/cirq_google/api/v2/sweeps_test.py b/cirq-google/cirq_google/api/v2/sweeps_test.py index 0e16ecffe66..0535e60f5c4 100644 --- a/cirq-google/cirq_google/api/v2/sweeps_test.py +++ b/cirq-google/cirq_google/api/v2/sweeps_test.py @@ -68,6 +68,7 @@ def _values(self) -> Iterator[float]: + (cirq.Points('g', [1, 2]) * cirq.Points('h', [-1, 0, 1])) ) ), + cirq.ZipLongest(cirq.Points('a', [1.0, 2.0, 3.0]), cirq.Points('b', [1.0])), # Sweep with constant. Type ignore is because cirq.Points type annotated with floats. cirq.Points('a', [None]), # type: ignore[list-item] cirq.Points('a', [None]) * cirq.Points('b', [1, 2, 3]), # type: ignore[list-item]