Skip to content

Commit 48ca2ea

Browse files
QwlouseThe kauldron Authors
authored andcommitted
possibly speed up state merging for CollectingState
PiperOrigin-RevId: 726539793
1 parent 205b34e commit 48ca2ea

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

kauldron/metrics/base_state.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def __post_init__(self):
217217
object.__setattr__(self, k, (val,))
218218

219219
@property
220-
def _accumulated_fields(self) -> dict[str, Array]:
220+
def _accumulated_fields(self) -> dict[str, tuple[Array, ...]]:
221221
return {
222222
f.name: getattr(self, f.name)
223223
for f in dataclasses.fields(self)
@@ -248,16 +248,30 @@ def compute(self: _SelfT) -> _SelfT:
248248
)
249249

250250

251-
def _merge_normalize_tuple(v0, v1):
251+
def _merge_normalize_tuple(
252+
v0: tuple[jax.Array, ...], v1: tuple[jax.Array, ...]
253+
) -> tuple[jax.Array, ...]:
252254
assert isinstance(v0, tuple)
253255
assert isinstance(v1, tuple)
254-
values = v0 + v1
255-
if any(isinstance(v, jax.core.Tracer) for v in values):
256+
return _maybe_copy_to_cpu(v0) + _maybe_copy_to_cpu(v1)
257+
258+
259+
def _maybe_copy_to_cpu(v: tuple[jax.Array, ...]) -> tuple[jax.Array, ...]:
260+
"""Convert all elements of the tuple to numpy arrays."""
261+
# If the tuple is not of length 1, that means it came from a merge (or empty),
262+
# and we can thus skip checks and conversions.
263+
if len(v) != 1:
264+
return v
265+
266+
element = v[0]
267+
if isinstance(element, jax.core.Tracer):
256268
raise RuntimeError(
257269
"Tracer detected! CollectingState.merge should not be JIT compiled."
258270
)
259-
# TODO(epot): Should be executed asynchronously (blocking)
260-
return tuple(np.asarray(v) for v in values)
271+
# start async copy to host
272+
cpu_device = jax.devices("cpu")[0]
273+
cpu_element = jax.device_put(element, device=cpu_device, donate=True)
274+
return (cpu_element,)
261275

262276

263277
# Inherit for better tracability/debug messages (so user can search and find

0 commit comments

Comments
 (0)