Skip to content

Commit f596c43

Browse files
authored
Make qubits a memoized property of Moment (#6894)
Moment already has a _qubit_to_op dict. Maintaining a separate _qubits frozenset is inefficient when constructing circuits, as it needs copied each time an op is added to a moment. In particular, for wide moments, this change seems to result in about a 3x speedup. On my laptop, creating a moment with X gates on 10_000 qubits takes 4s before this change, and 1.3s after.
1 parent 7f66b42 commit f596c43

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

cirq-core/cirq/circuits/moment.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""A simplified time-slice of operations within a sequenced circuit."""
1616

1717
import itertools
18+
from functools import cached_property
1819
from types import NotImplementedType
1920
from typing import (
2021
AbstractSet,
@@ -113,7 +114,6 @@ def __init__(self, *contents: 'cirq.OP_TREE', _flatten_contents: bool = True) ->
113114
raise ValueError(f'Overlapping operations: {self.operations}')
114115
self._qubit_to_op[q] = op
115116

116-
self._qubits = frozenset(self._qubit_to_op.keys())
117117
self._measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None
118118
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None
119119

@@ -135,9 +135,9 @@ def from_ops(cls, *ops: 'cirq.Operation') -> 'cirq.Moment':
135135
def operations(self) -> Tuple['cirq.Operation', ...]:
136136
return self._operations
137137

138-
@property
138+
@cached_property
139139
def qubits(self) -> FrozenSet['cirq.Qid']:
140-
return self._qubits
140+
return frozenset(self._qubit_to_op)
141141

142142
def operates_on_single_qubit(self, qubit: 'cirq.Qid') -> bool:
143143
"""Determines if the moment has operations touching the given qubit.
@@ -157,7 +157,7 @@ def operates_on(self, qubits: Iterable['cirq.Qid']) -> bool:
157157
Returns:
158158
Whether this moment has operations involving the qubits.
159159
"""
160-
return not self._qubits.isdisjoint(qubits)
160+
return not self._qubit_to_op.keys().isdisjoint(qubits)
161161

162162
def operation_at(self, qubit: raw_types.Qid) -> Optional['cirq.Operation']:
163163
"""Returns the operation on a certain qubit for the moment.
@@ -185,14 +185,13 @@ def with_operation(self, operation: 'cirq.Operation') -> 'cirq.Moment':
185185
Raises:
186186
ValueError: If the operation given overlaps a current operation in the moment.
187187
"""
188-
if any(q in self._qubits for q in operation.qubits):
188+
if any(q in self._qubit_to_op for q in operation.qubits):
189189
raise ValueError(f'Overlapping operations: {operation}')
190190

191191
# Use private variables to facilitate a quick copy.
192192
m = Moment(_flatten_contents=False)
193193
m._operations = self._operations + (operation,)
194194
m._sorted_operations = None
195-
m._qubits = self._qubits.union(operation.qubits)
196195
m._qubit_to_op = {**self._qubit_to_op, **{q: operation for q in operation.qubits}}
197196

198197
m._measurement_key_objs = self._measurement_key_objs_().union(
@@ -222,14 +221,11 @@ def with_operations(self, *contents: 'cirq.OP_TREE') -> 'cirq.Moment':
222221
m = Moment(_flatten_contents=False)
223222
# Use private variables to facilitate a quick copy.
224223
m._qubit_to_op = self._qubit_to_op.copy()
225-
qubits = set(self._qubits)
226224
for op in flattened_contents:
227-
if any(q in qubits for q in op.qubits):
225+
if any(q in m._qubit_to_op for q in op.qubits):
228226
raise ValueError(f'Overlapping operations: {op}')
229-
qubits.update(op.qubits)
230227
for q in op.qubits:
231228
m._qubit_to_op[q] = op
232-
m._qubits = frozenset(qubits)
233229

234230
m._operations = self._operations + flattened_contents
235231
m._sorted_operations = None
@@ -450,7 +446,9 @@ def expand_to(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Moment':
450446
@_compat.cached_method()
451447
def _has_kraus_(self) -> bool:
452448
"""Returns True if self has a Kraus representation and self uses <= 10 qubits."""
453-
return all(protocols.has_kraus(op) for op in self.operations) and len(self.qubits) <= 10
449+
return (
450+
all(protocols.has_kraus(op) for op in self.operations) and len(self._qubit_to_op) <= 10
451+
)
454452

455453
def _kraus_(self) -> Sequence[np.ndarray]:
456454
r"""Returns Kraus representation of self.
@@ -475,7 +473,7 @@ def _kraus_(self) -> Sequence[np.ndarray]:
475473
if not self._has_kraus_():
476474
return NotImplemented
477475

478-
qubits = sorted(self.qubits)
476+
qubits = sorted(self._qubit_to_op)
479477
n = len(qubits)
480478
if n < 1:
481479
return (np.array([[1 + 0j]]),)
@@ -602,7 +600,7 @@ def to_text_diagram(
602600
"""
603601

604602
# Figure out where to place everything.
605-
qs = set(self.qubits) | set(extra_qubits)
603+
qs = self._qubit_to_op.keys() | set(extra_qubits)
606604
points = {xy_breakdown_func(q) for q in qs}
607605
x_keys = sorted({pt[0] for pt in points}, key=_SortByValFallbackToType)
608606
y_keys = sorted({pt[1] for pt in points}, key=_SortByValFallbackToType)

0 commit comments

Comments
 (0)