|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | import math
|
| 15 | +from copy import deepcopy |
15 | 16 | from typing import Iterator
|
16 | 17 |
|
17 | 18 | import pytest
|
@@ -275,6 +276,64 @@ def strip_tunit_func(sweep: sweeps.SingleSweep):
|
275 | 276 | assert list(sweep.points)[0] == 1.0
|
276 | 277 |
|
277 | 278 |
|
| 279 | +@pytest.mark.parametrize( |
| 280 | + 'sweep', |
| 281 | + [ |
| 282 | + cirq.Concat(cirq.Points('a', [1, 2, 3]), cirq.Points('a', [4])), |
| 283 | + cirq.Points('a', [1, 2, 3]) * cirq.Points('b', [4, 5, 6]), |
| 284 | + cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [1])), |
| 285 | + cirq.Zip(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6])), |
| 286 | + ], |
| 287 | +) |
| 288 | +def test_sweep_to_proto_with_func_on_resursive_sweep_succeeds(sweep): |
| 289 | + def add_tunit_func(sweep: sweeps.SingleSweep): |
| 290 | + if isinstance(sweep, cirq.Points): |
| 291 | + sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc] |
| 292 | + |
| 293 | + return sweep |
| 294 | + |
| 295 | + msg = v2.sweep_to_proto(sweep, sweep_transformer=add_tunit_func) |
| 296 | + |
| 297 | + assert msg.sweep_function.sweeps[0].single_sweep.points.unit == tunits.ns.to_proto() |
| 298 | + |
| 299 | + |
| 300 | +@pytest.mark.parametrize( |
| 301 | + 'expected_sweep', |
| 302 | + [ |
| 303 | + cirq.Concat(cirq.Points('a', [1.0, 2.0, 3.0]), cirq.Points('a', [4.0])), |
| 304 | + cirq.Points('a', [1.0, 2.0, 3.0]) * cirq.Points('b', [4.0, 5.0, 6.0]), |
| 305 | + cirq.ZipLongest(cirq.Points('a', [1.0, 2.0, 3.0]), cirq.Points('b', [1.0])), |
| 306 | + cirq.Zip(cirq.Points('a', [1.0, 2.0, 3.0]), cirq.Points('b', [4.0, 5.0, 6.0])), |
| 307 | + cirq.Points('a', [1, 2, 3]) |
| 308 | + + cirq.Points( |
| 309 | + 'b', |
| 310 | + [4, 5, 6], |
| 311 | + metadata=DeviceParameter(path=['path', 'to', 'parameter'], idx=2, units='GHz'), |
| 312 | + ), |
| 313 | + ], |
| 314 | +) |
| 315 | +def test_sweep_from_proto_with_func_on_resursive_sweep_succeeds(expected_sweep): |
| 316 | + def add_tunit_func(sweep_to_transform: sweeps.SingleSweep): |
| 317 | + sweep = deepcopy(sweep_to_transform) |
| 318 | + if isinstance(sweep, cirq.Points): |
| 319 | + sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc] |
| 320 | + |
| 321 | + return sweep |
| 322 | + |
| 323 | + def strip_tunit_func(sweep_to_transform: sweeps.SingleSweep): |
| 324 | + sweep = deepcopy(sweep_to_transform) |
| 325 | + if isinstance(sweep, cirq.Points): |
| 326 | + if isinstance(sweep.points[0], tunits.Value): |
| 327 | + sweep.points = [point[point.unit] for point in sweep.points] |
| 328 | + |
| 329 | + return sweep |
| 330 | + |
| 331 | + msg = v2.sweep_to_proto(expected_sweep, sweep_transformer=add_tunit_func) |
| 332 | + round_trip_sweep = v2.sweep_from_proto(msg, strip_tunit_func) |
| 333 | + |
| 334 | + assert round_trip_sweep == expected_sweep |
| 335 | + |
| 336 | + |
278 | 337 | def test_sweep_with_list_sweep():
|
279 | 338 | ls = cirq.study.to_sweep([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
|
280 | 339 | proto = v2.sweep_to_proto(ls)
|
|
0 commit comments