Skip to content

Commit def1b08

Browse files
authored
Add Concat to Sweeps (#6819)
* add Concat to sweeps * json repr * json testing * add test for empty sweep * comments * coverage for keys property
1 parent 35733fd commit def1b08

File tree

7 files changed

+184
-0
lines changed

7 files changed

+184
-0
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@
494494
)
495495

496496
from cirq.study import (
497+
Concat as Concat,
497498
dict_to_product_sweep as dict_to_product_sweep,
498499
dict_to_zip_sweep as dict_to_zip_sweep,
499500
ExpressionMap as ExpressionMap,

cirq-core/cirq/json_resolver_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def _symmetricalqidpair(qids):
120120
'CliffordState': cirq.CliffordState,
121121
'CliffordTableau': cirq.CliffordTableau,
122122
'CNotPowGate': cirq.CNotPowGate,
123+
'Concat': cirq.Concat,
123124
'ConstantQubitNoiseModel': cirq.ConstantQubitNoiseModel,
124125
'ControlledGate': cirq.ControlledGate,
125126
'ControlledOperation': cirq.ControlledOperation,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"cirq_type": "Concat",
3+
"sweeps": [
4+
{
5+
"cirq_type": "Linspace",
6+
"key": "a",
7+
"start": 0,
8+
"stop": 1,
9+
"length": 2
10+
},
11+
{
12+
"cirq_type": "Linspace",
13+
"key": "a",
14+
"start": 0,
15+
"stop": 2,
16+
"length": 4
17+
}
18+
]
19+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cirq.Concat(cirq.Linspace('a', start=0, stop=1, length=2), cirq.Linspace('a', start=0, stop=2, length=4))

cirq-core/cirq/study/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
)
3737

3838
from cirq.study.sweeps import (
39+
Concat as Concat,
3940
Linspace as Linspace,
4041
ListSweep as ListSweep,
4142
Points as Points,

cirq-core/cirq/study/sweeps.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,63 @@ def _from_json_dict_(cls, factors, **kwargs):
276276
return Product(*factors)
277277

278278

279+
class Concat(Sweep):
280+
"""Concatenates multiple to a new sweep.
281+
282+
All sweeps must share the same descriptors.
283+
284+
If one sweep assigns 'a' to the values 0, 1, 2, and another sweep assigns
285+
'a' to the values 3, 4, 5, the concatenation produces a sweep assigning
286+
'a' to the values 0, 1, 2, 3, 4, 5 in sequence.
287+
"""
288+
289+
def __init__(self, *sweeps: Sweep) -> None:
290+
if not sweeps:
291+
raise ValueError("Concat requires at least one sweep.")
292+
293+
# Validate consistency across sweeps
294+
first_sweep = sweeps[0]
295+
for sweep in sweeps[1:]:
296+
if sweep.keys != first_sweep.keys:
297+
raise ValueError("All sweeps must have the same descriptors.")
298+
299+
self.sweeps = sweeps
300+
301+
def __eq__(self, other):
302+
if not isinstance(other, Concat):
303+
return NotImplemented
304+
return self.sweeps == other.sweeps
305+
306+
def __hash__(self):
307+
return hash(tuple(self.sweeps))
308+
309+
@property
310+
def keys(self) -> List['cirq.TParamKey']:
311+
return self.sweeps[0].keys
312+
313+
def __len__(self) -> int:
314+
return sum(len(sweep) for sweep in self.sweeps)
315+
316+
def param_tuples(self) -> Iterator[Params]:
317+
for sweep in self.sweeps:
318+
yield from sweep.param_tuples()
319+
320+
def __repr__(self) -> str:
321+
sweeps_repr = ', '.join(repr(sweep) for sweep in self.sweeps)
322+
return f'cirq.Concat({sweeps_repr})'
323+
324+
def __str__(self) -> str:
325+
sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
326+
return f'Concat({sweeps_repr})'
327+
328+
def _json_dict_(self) -> Dict[str, Any]:
329+
return protocols.obj_to_dict_helper(self, ['sweeps'])
330+
331+
@classmethod
332+
def _from_json_dict_(cls, sweeps, **kwargs):
333+
return Concat(*sweeps)
334+
335+
279336
class Zip(Sweep):
280337
"""Zip product (direct sum) of one or more sweeps.
281338

cirq-core/cirq/study/sweeps_test.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ def test_equality():
246246
et.make_equality_group(lambda: cirq.Linspace('b', 0, 10, 11))
247247
et.make_equality_group(lambda: cirq.Points('a', list(range(11))))
248248
et.make_equality_group(lambda: cirq.Points('b', list(range(11))))
249+
et.make_equality_group(lambda: cirq.Concat(cirq.Linspace('a', 0, 10, 11)))
250+
et.make_equality_group(lambda: cirq.Concat(cirq.Linspace('b', 0, 10, 11)))
249251

250252
# Product and Zip sweeps can also be equated.
251253
et.make_equality_group(lambda: cirq.Linspace('a', 0, 5, 6) * cirq.Linspace('b', 10, 15, 6))
@@ -373,3 +375,105 @@ def test_dict_to_zip_sweep():
373375
assert cirq.dict_to_zip_sweep({'t': [0, 1], 's': [2, 3], 'r': 4}) == (
374376
cirq.Zip(cirq.Points('t', [0, 1]), cirq.Points('s', [2, 3]), cirq.Points('r', [4]))
375377
)
378+
379+
380+
def test_concat_linspace():
381+
sweep1 = cirq.Linspace('a', 0.34, 9.16, 4)
382+
sweep2 = cirq.Linspace('a', 10, 20, 4)
383+
concat_sweep = cirq.Concat(sweep1, sweep2)
384+
385+
assert len(concat_sweep) == 8
386+
assert concat_sweep.keys == ['a']
387+
params = list(concat_sweep.param_tuples())
388+
assert len(params) == 8
389+
assert params[0] == (('a', 0.34),)
390+
assert params[3] == (('a', 9.16),)
391+
assert params[4] == (('a', 10.0),)
392+
assert params[7] == (('a', 20.0),)
393+
394+
395+
def test_concat_points():
396+
sweep1 = cirq.Points('a', [1, 2])
397+
sweep2 = cirq.Points('a', [3, 4, 5])
398+
concat_sweep = cirq.Concat(sweep1, sweep2)
399+
400+
assert concat_sweep.keys == ['a']
401+
assert len(concat_sweep) == 5
402+
params = list(concat_sweep)
403+
assert len(params) == 5
404+
assert _values(concat_sweep, 'a') == [1, 2, 3, 4, 5]
405+
406+
407+
def test_concat_many_points():
408+
sweep1 = cirq.Points('a', [1, 2])
409+
sweep2 = cirq.Points('a', [3, 4, 5])
410+
sweep3 = cirq.Points('a', [6, 7, 8])
411+
concat_sweep = cirq.Concat(sweep1, sweep2, sweep3)
412+
413+
assert len(concat_sweep) == 8
414+
params = list(concat_sweep)
415+
assert len(params) == 8
416+
assert _values(concat_sweep, 'a') == [1, 2, 3, 4, 5, 6, 7, 8]
417+
418+
419+
def test_concat_mixed():
420+
sweep1 = cirq.Linspace('a', 0, 1, 3)
421+
sweep2 = cirq.Points('a', [2, 3])
422+
concat_sweep = cirq.Concat(sweep1, sweep2)
423+
424+
assert len(concat_sweep) == 5
425+
assert _values(concat_sweep, 'a') == [0.0, 0.5, 1.0, 2, 3]
426+
427+
428+
def test_concat_inconsistent_keys():
429+
sweep1 = cirq.Linspace('a', 0, 1, 3)
430+
sweep2 = cirq.Points('b', [2, 3])
431+
432+
with pytest.raises(ValueError, match="All sweeps must have the same descriptors"):
433+
cirq.Concat(sweep1, sweep2)
434+
435+
436+
def test_concat_sympy_symbol():
437+
a = sympy.Symbol('a')
438+
sweep1 = cirq.Linspace(a, 0, 1, 3)
439+
sweep2 = cirq.Points(a, [2, 3])
440+
concat_sweep = cirq.Concat(sweep1, sweep2)
441+
442+
assert len(concat_sweep) == 5
443+
assert _values(concat_sweep, 'a') == [0.0, 0.5, 1.0, 2, 3]
444+
445+
446+
def test_concat_repr_and_str():
447+
sweep1 = cirq.Linspace('a', 0, 1, 3)
448+
sweep2 = cirq.Points('a', [2, 3])
449+
concat_sweep = cirq.Concat(sweep1, sweep2)
450+
451+
expected_repr = (
452+
"cirq.Concat(cirq.Linspace('a', start=0, stop=1, length=3), cirq.Points('a', [2, 3]))"
453+
)
454+
expected_str = "Concat(cirq.Linspace('a', start=0, stop=1, length=3), cirq.Points('a', [2, 3]))"
455+
456+
assert repr(concat_sweep) == expected_repr
457+
assert str(concat_sweep) == expected_str
458+
459+
460+
def test_concat_large_sweep():
461+
sweep1 = cirq.Points('a', list(range(101)))
462+
sweep2 = cirq.Points('a', list(range(101, 202)))
463+
concat_sweep = cirq.Concat(sweep1, sweep2)
464+
465+
assert len(concat_sweep) == 202
466+
assert _values(concat_sweep, 'a') == list(range(101)) + list(range(101, 202))
467+
468+
469+
def test_concat_different_keys_raises():
470+
sweep1 = cirq.Linspace('a', 0, 1, 3)
471+
sweep2 = cirq.Points('b', [2, 3])
472+
473+
with pytest.raises(ValueError, match="All sweeps must have the same descriptors."):
474+
_ = cirq.Concat(sweep1, sweep2)
475+
476+
477+
def test_concat_empty_sweep_raises():
478+
with pytest.raises(ValueError, match="Concat requires at least one sweep."):
479+
_ = cirq.Concat()

0 commit comments

Comments
 (0)