Skip to content

Commit b4b4676

Browse files
QwlouseThe kauldron Authors
authored andcommitted
possibly speed up state merging for CollectingState
PiperOrigin-RevId: 726539793
1 parent 205b34e commit b4b4676

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

kauldron/metrics/base_state.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,16 +248,28 @@ def compute(self: _SelfT) -> _SelfT:
248248
)
249249

250250

251-
def _merge_normalize_tuple(v0, v1):
251+
def _merge_normalize_tuple(
252+
v0: tuple[Any, ...],
253+
v1: tuple[Any, ...]
254+
) -> tuple[np.ndarray, ...]:
252255
assert isinstance(v0, tuple)
253256
assert isinstance(v1, tuple)
254-
values = v0 + v1
255-
if any(isinstance(v, jax.core.Tracer) for v in values):
257+
return _maybe_convert_to_numpy(v0) + _maybe_convert_to_numpy(v1)
258+
259+
260+
def _maybe_convert_to_numpy(v: tuple[Any, ...]) -> tuple[np.ndarray, ...]:
261+
"""Convert all elements of the tuple to numpy arrays."""
262+
# If the tuple is not of length 1, that means it came from a merge (or empty),
263+
# and we can thus skip checks and conversions.
264+
if len(v) != 1:
265+
return v
266+
267+
element = v[0]
268+
if isinstance(element, jax.core.Tracer):
256269
raise RuntimeError(
257270
"Tracer detected! CollectingState.merge should not be JIT compiled."
258271
)
259-
# TODO(epot): Should be executed asynchronously (blocking)
260-
return tuple(np.asarray(v) for v in values)
272+
return (np.asarray(element),)
261273

262274

263275
# Inherit for better tracability/debug messages (so user can search and find

0 commit comments

Comments
 (0)