Skip to content

possibly speed up state merging for CollectingState #1022

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions kauldron/metrics/base_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def __post_init__(self):
object.__setattr__(self, k, (val,))

@property
def _accumulated_fields(self) -> dict[str, Array]:
def _accumulated_fields(self) -> dict[str, tuple[Array, ...]]:
return {
f.name: getattr(self, f.name)
for f in dataclasses.fields(self)
Expand Down Expand Up @@ -248,16 +248,30 @@ def compute(self: _SelfT) -> _SelfT:
)


def _merge_normalize_tuple(v0, v1):
def _merge_normalize_tuple(
v0: tuple[jax.Array, ...], v1: tuple[jax.Array, ...]
) -> tuple[jax.Array, ...]:
assert isinstance(v0, tuple)
assert isinstance(v1, tuple)
values = v0 + v1
if any(isinstance(v, jax.core.Tracer) for v in values):
return _maybe_copy_to_cpu(v0) + _maybe_copy_to_cpu(v1)


def _maybe_copy_to_cpu(v: tuple[jax.Array, ...]) -> tuple[jax.Array, ...]:
"""Convert all elements of the tuple to numpy arrays."""
# If the tuple is not of length 1, that means it came from a merge (or empty),
# and we can thus skip checks and conversions.
if len(v) != 1:
return v

element = v[0]
if isinstance(element, jax.core.Tracer):
raise RuntimeError(
"Tracer detected! CollectingState.merge should not be JIT compiled."
)
# TODO(epot): Should be executed asynchronously (blocking)
return tuple(np.asarray(v) for v in values)
# start async copy to host
cpu_device = jax.devices("cpu")[0]
cpu_element = jax.device_put(element, device=cpu_device, donate=True)
return (cpu_element,)


# Inherit for better tracability/debug messages (so user can search and find
Expand Down