12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import Any , cast , Dict , List , Optional
15
+ from typing import Any , cast , Callable , Dict , List , Optional
16
16
17
17
import sympy
18
18
import tunits
19
19
20
20
import cirq
21
+ from cirq .study import sweeps
21
22
from cirq_google .api .v2 import run_context_pb2
22
23
from cirq_google .study .device_parameter import DeviceParameter
23
24
@@ -55,14 +56,18 @@ def _recover_sweep_const(const_pb: run_context_pb2.ConstValue) -> Any:
55
56
56
57
57
58
def sweep_to_proto (
58
- sweep : cirq .Sweep , * , out : Optional [run_context_pb2 .Sweep ] = None
59
+ sweep : cirq .Sweep ,
60
+ * ,
61
+ out : Optional [run_context_pb2 .Sweep ] = None ,
62
+ sweep_transformer : Callable [[sweeps .SingleSweep ], sweeps .SingleSweep ] = lambda x : x ,
59
63
) -> run_context_pb2 .Sweep :
60
64
"""Converts a Sweep to v2 protobuf message.
61
65
62
66
Args:
63
67
sweep: The sweep to convert.
64
68
out: Optional message to be populated. If not given, a new message will
65
69
be created.
70
+ sweep_transformer: A function called on Linspace, Points.
66
71
67
72
Returns:
68
73
Populated sweep protobuf message.
@@ -91,6 +96,7 @@ def sweep_to_proto(
91
96
for s in sweep .sweeps :
92
97
sweep_to_proto (s , out = out .sweep_function .sweeps .add ())
93
98
elif isinstance (sweep , cirq .Linspace ) and not isinstance (sweep .key , sympy .Expr ):
99
+ sweep = cast (cirq .Linspace , sweep_transformer (sweep ))
94
100
out .single_sweep .parameter_key = sweep .key
95
101
if isinstance (sweep .start , tunits .Value ):
96
102
unit = sweep .start .unit
@@ -110,6 +116,7 @@ def sweep_to_proto(
110
116
if sweep .metadata and getattr (sweep .metadata , 'units' , None ):
111
117
out .single_sweep .parameter .units = sweep .metadata .units
112
118
elif isinstance (sweep , cirq .Points ) and not isinstance (sweep .key , sympy .Expr ):
119
+ sweep = cast (cirq .Points , sweep_transformer (sweep ))
113
120
out .single_sweep .parameter_key = sweep .key
114
121
if len (sweep .points ) == 1 :
115
122
out .single_sweep .const_value .MergeFrom (_build_sweep_const (sweep .points [0 ]))
@@ -142,8 +149,17 @@ def sweep_to_proto(
142
149
return out
143
150
144
151
145
- def sweep_from_proto (msg : run_context_pb2 .Sweep ) -> cirq .Sweep :
146
- """Creates a Sweep from a v2 protobuf message."""
152
+ def sweep_from_proto (
153
+ msg : run_context_pb2 .Sweep ,
154
+ sweep_transformer : Callable [[sweeps .SingleSweep ], sweeps .SingleSweep ] = lambda x : x ,
155
+ ) -> cirq .Sweep :
156
+ """Creates a Sweep from a v2 protobuf message.
157
+
158
+ Args:
159
+ msg: Serialized sweep message.
160
+ sweep_transformer: A function called on Linspace, Point, and ConstValue.
161
+
162
+ """
147
163
which = msg .WhichOneof ('sweep' )
148
164
if which is None :
149
165
return cirq .UnitSweep
@@ -178,31 +194,38 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
178
194
)
179
195
else :
180
196
metadata = None
197
+
181
198
if msg .single_sweep .WhichOneof ('sweep' ) == 'linspace' :
182
199
unit : float | tunits .Value = 1.0
183
200
if msg .single_sweep .linspace .HasField ('unit' ):
184
201
unit = tunits .Value .from_proto (msg .single_sweep .linspace .unit )
185
- return cirq .Linspace (
186
- key = key ,
187
- start = msg .single_sweep .linspace .first_point * unit , # type: ignore[arg-type]
188
- stop = msg .single_sweep .linspace .last_point * unit , # type: ignore[arg-type]
189
- length = msg .single_sweep .linspace .num_points ,
190
- metadata = metadata ,
202
+ return sweep_transformer (
203
+ cirq .Linspace (
204
+ key = key ,
205
+ start = msg .single_sweep .linspace .first_point * unit , # type: ignore[arg-type]
206
+ stop = msg .single_sweep .linspace .last_point * unit , # type: ignore[arg-type]
207
+ length = msg .single_sweep .linspace .num_points ,
208
+ metadata = metadata ,
209
+ )
191
210
)
192
211
if msg .single_sweep .WhichOneof ('sweep' ) == 'points' :
193
212
unit = 1.0
194
213
if msg .single_sweep .points .HasField ('unit' ):
195
214
unit = tunits .Value .from_proto (msg .single_sweep .points .unit )
196
- return cirq .Points (
197
- key = key ,
198
- points = [p * unit for p in msg .single_sweep .points .points ],
199
- metadata = metadata ,
215
+ return sweep_transformer (
216
+ cirq .Points (
217
+ key = key ,
218
+ points = [p * unit for p in msg .single_sweep .points .points ],
219
+ metadata = metadata ,
220
+ )
200
221
)
201
222
if msg .single_sweep .WhichOneof ('sweep' ) == 'const_value' :
202
- return cirq .Points (
203
- key = key ,
204
- points = [_recover_sweep_const (msg .single_sweep .const_value )],
205
- metadata = metadata ,
223
+ return sweep_transformer (
224
+ cirq .Points (
225
+ key = key ,
226
+ points = [_recover_sweep_const (msg .single_sweep .const_value )],
227
+ metadata = metadata ,
228
+ )
206
229
)
207
230
208
231
raise ValueError (f'single sweep type not set: { msg } ' )
0 commit comments