Skip to content

Commit b28bfce

Browse files
authored
Speed up parameter resolution for cirq.Duration (#6270)
1 parent c7048f5 commit b28bfce

File tree

2 files changed

+70
-37
lines changed

2 files changed

+70
-37
lines changed

cirq-core/cirq/value/duration.py

Lines changed: 68 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# limitations under the License.
1414
"""A typed time delta that supports picosecond accuracy."""
1515

16-
from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
16+
from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union, List
1717
import datetime
1818

1919
import sympy
2020
import numpy as np
2121

2222
from cirq import protocols
23-
from cirq._compat import proper_repr
23+
from cirq._compat import proper_repr, cached_method
2424
from cirq._doc import document
2525

2626
if TYPE_CHECKING:
@@ -79,48 +79,53 @@ def __init__(
7979
>>> print(cirq.Duration(micros=1.5 * sympy.Symbol('t')))
8080
(1500.0*t) ns
8181
"""
82+
self._time_vals: List[_NUMERIC_INPUT_TYPE] = [0, 0, 0, 0]
83+
self._multipliers = [1, 1000, 1000_000, 1000_000_000]
8284
if value is not None and value != 0:
8385
if isinstance(value, datetime.timedelta):
8486
# timedelta has microsecond resolution.
85-
micros += int(value / datetime.timedelta(microseconds=1))
87+
self._time_vals[2] = int(value / datetime.timedelta(microseconds=1))
8688
elif isinstance(value, Duration):
87-
picos += value._picos
89+
self._time_vals = value._time_vals
8890
else:
8991
raise TypeError(f'Not a `cirq.DURATION_LIKE`: {repr(value)}.')
90-
91-
val = picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000
92-
self._picos: _NUMERIC_OUTPUT_TYPE = float(val) if isinstance(val, np.number) else val
92+
input_vals = [picos, nanos, micros, millis]
93+
self._time_vals = _add_time_vals(self._time_vals, input_vals)
9394

9495
def _is_parameterized_(self) -> bool:
95-
return protocols.is_parameterized(self._picos)
96+
return protocols.is_parameterized(self._time_vals)
9697

9798
def _parameter_names_(self) -> AbstractSet[str]:
98-
return protocols.parameter_names(self._picos)
99+
return protocols.parameter_names(self._time_vals)
99100

100101
def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Duration':
101-
return Duration(picos=protocols.resolve_parameters(self._picos, resolver, recursive))
102+
return _duration_from_time_vals(
103+
protocols.resolve_parameters(self._time_vals, resolver, recursive)
104+
)
102105

106+
@cached_method
103107
def total_picos(self) -> _NUMERIC_OUTPUT_TYPE:
104108
"""Returns the number of picoseconds that the duration spans."""
105-
return self._picos
109+
val = sum(a * b for a, b in zip(self._time_vals, self._multipliers))
110+
return float(val) if isinstance(val, np.number) else val
106111

107112
def total_nanos(self) -> _NUMERIC_OUTPUT_TYPE:
108113
"""Returns the number of nanoseconds that the duration spans."""
109-
return self._picos / 1000
114+
return self.total_picos() / 1000
110115

111116
def total_micros(self) -> _NUMERIC_OUTPUT_TYPE:
112117
"""Returns the number of microseconds that the duration spans."""
113-
return self._picos / 1000_000
118+
return self.total_picos() / 1000_000
114119

115120
def total_millis(self) -> _NUMERIC_OUTPUT_TYPE:
116121
"""Returns the number of milliseconds that the duration spans."""
117-
return self._picos / 1000_000_000
122+
return self.total_picos() / 1000_000_000
118123

119124
def __add__(self, other) -> 'Duration':
120125
other = _attempt_duration_like_to_duration(other)
121126
if other is None:
122127
return NotImplemented
123-
return Duration(picos=self._picos + other._picos)
128+
return _duration_from_time_vals(_add_time_vals(self._time_vals, other._time_vals))
124129

125130
def __radd__(self, other) -> 'Duration':
126131
return self.__add__(other)
@@ -129,86 +134,94 @@ def __sub__(self, other) -> 'Duration':
129134
other = _attempt_duration_like_to_duration(other)
130135
if other is None:
131136
return NotImplemented
132-
return Duration(picos=self._picos - other._picos)
137+
return _duration_from_time_vals(
138+
_add_time_vals(self._time_vals, [-x for x in other._time_vals])
139+
)
133140

134141
def __rsub__(self, other) -> 'Duration':
135142
other = _attempt_duration_like_to_duration(other)
136143
if other is None:
137144
return NotImplemented
138-
return Duration(picos=other._picos - self._picos)
145+
return _duration_from_time_vals(
146+
_add_time_vals(other._time_vals, [-x for x in self._time_vals])
147+
)
139148

140149
def __mul__(self, other) -> 'Duration':
141150
if not isinstance(other, (int, float, sympy.Expr)):
142151
return NotImplemented
143-
return Duration(picos=self._picos * other)
152+
if other == 0:
153+
return _duration_from_time_vals([0] * 4)
154+
return _duration_from_time_vals([x * other for x in self._time_vals])
144155

145156
def __rmul__(self, other) -> 'Duration':
146157
return self.__mul__(other)
147158

148159
def __truediv__(self, other) -> Union['Duration', float]:
149160
if isinstance(other, (int, float, sympy.Expr)):
150-
return Duration(picos=self._picos / other)
161+
new_time_vals = [x / other for x in self._time_vals]
162+
return _duration_from_time_vals(new_time_vals)
151163

152164
other_duration = _attempt_duration_like_to_duration(other)
153165
if other_duration is not None:
154-
return self._picos / other_duration._picos
166+
return self.total_picos() / other_duration.total_picos()
155167

156168
return NotImplemented
157169

158170
def __eq__(self, other):
159171
other = _attempt_duration_like_to_duration(other)
160172
if other is None:
161173
return NotImplemented
162-
return self._picos == other._picos
174+
return self.total_picos() == other.total_picos()
163175

164176
def __ne__(self, other):
165177
other = _attempt_duration_like_to_duration(other)
166178
if other is None:
167179
return NotImplemented
168-
return self._picos != other._picos
180+
return self.total_picos() != other.total_picos()
169181

170182
def __gt__(self, other):
171183
other = _attempt_duration_like_to_duration(other)
172184
if other is None:
173185
return NotImplemented
174-
return self._picos > other._picos
186+
return self.total_picos() > other.total_picos()
175187

176188
def __lt__(self, other):
177189
other = _attempt_duration_like_to_duration(other)
178190
if other is None:
179191
return NotImplemented
180-
return self._picos < other._picos
192+
return self.total_picos() < other.total_picos()
181193

182194
def __ge__(self, other):
183195
other = _attempt_duration_like_to_duration(other)
184196
if other is None:
185197
return NotImplemented
186-
return self._picos >= other._picos
198+
return self.total_picos() >= other.total_picos()
187199

188200
def __le__(self, other):
189201
other = _attempt_duration_like_to_duration(other)
190202
if other is None:
191203
return NotImplemented
192-
return self._picos <= other._picos
204+
return self.total_picos() <= other.total_picos()
193205

194206
def __bool__(self):
195-
return bool(self._picos)
207+
return bool(self.total_picos())
196208

197209
def __hash__(self):
198-
if isinstance(self._picos, (int, float)) and self._picos % 1000000 == 0:
199-
return hash(datetime.timedelta(microseconds=self._picos / 1000000))
200-
return hash((Duration, self._picos))
210+
if isinstance(self.total_picos(), (int, float)) and self.total_picos() % 1000000 == 0:
211+
return hash(datetime.timedelta(microseconds=self.total_picos() / 1000000))
212+
return hash((Duration, self.total_picos()))
201213

202214
def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]:
215+
picos = self.total_picos()
203216
if (
204-
isinstance(self._picos, sympy.Mul)
205-
and len(self._picos.args) == 2
206-
and isinstance(self._picos.args[0], (sympy.Integer, sympy.Float))
217+
isinstance(picos, sympy.Mul)
218+
and len(picos.args) == 2
219+
and isinstance(picos.args[0], (sympy.Integer, sympy.Float))
207220
):
208-
scale = self._picos.args[0]
209-
rest = self._picos.args[1]
221+
scale = picos.args[0]
222+
rest = picos.args[1]
210223
else:
211-
scale = self._picos
224+
scale = picos
212225
rest = 1
213226

214227
if scale % 1000_000_000 == 0:
@@ -234,7 +247,7 @@ def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]:
234247
return amount * rest, unit, suffix
235248

236249
def __str__(self) -> str:
237-
if self._picos == 0:
250+
if self.total_picos() == 0:
238251
return 'Duration(0)'
239252
amount, _, suffix = self._decompose_into_amount_unit_suffix()
240253
if not isinstance(amount, (int, float, sympy.Symbol)):
@@ -257,3 +270,21 @@ def _attempt_duration_like_to_duration(value: Any) -> Optional[Duration]:
257270
if isinstance(value, (int, float)) and value == 0:
258271
return Duration()
259272
return None
273+
274+
275+
def _add_time_vals(
276+
val1: List[_NUMERIC_INPUT_TYPE], val2: List[_NUMERIC_INPUT_TYPE]
277+
) -> List[_NUMERIC_INPUT_TYPE]:
278+
ret: List[_NUMERIC_INPUT_TYPE] = []
279+
for i in range(4):
280+
if val1[i] and val2[i]:
281+
ret.append(val1[i] + val2[i])
282+
else:
283+
ret.append(val1[i] or val2[i])
284+
return ret
285+
286+
287+
def _duration_from_time_vals(time_vals: List[_NUMERIC_INPUT_TYPE]):
288+
ret = Duration()
289+
ret._time_vals = time_vals
290+
return ret

cirq-core/cirq/value/duration_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,11 @@ def test_sub():
168168
def test_mul():
169169
assert Duration(picos=2) * 3 == Duration(picos=6)
170170
assert 4 * Duration(picos=3) == Duration(picos=12)
171+
assert 0 * Duration(picos=10) == Duration()
171172

172173
t = sympy.Symbol('t')
173174
assert t * Duration(picos=3) == Duration(picos=3 * t)
175+
assert 0 * Duration(picos=t) == Duration(picos=0)
174176

175177
with pytest.raises(TypeError):
176178
_ = Duration() * Duration()

0 commit comments

Comments
 (0)