@@ -217,7 +217,7 @@ def __post_init__(self):
217
217
object .__setattr__ (self , k , (val ,))
218
218
219
219
@property
220
- def _accumulated_fields (self ) -> dict [str , Array ]:
220
+ def _accumulated_fields (self ) -> dict [str , tuple [ Array , ...] ]:
221
221
return {
222
222
f .name : getattr (self , f .name )
223
223
for f in dataclasses .fields (self )
@@ -248,16 +248,30 @@ def compute(self: _SelfT) -> _SelfT:
248
248
)
249
249
250
250
251
- def _merge_normalize_tuple (v0 , v1 ):
251
+ def _merge_normalize_tuple (
252
+ v0 : tuple [jax .Array , ...], v1 : tuple [jax .Array , ...]
253
+ ) -> tuple [jax .Array , ...]:
252
254
assert isinstance (v0 , tuple )
253
255
assert isinstance (v1 , tuple )
254
- values = v0 + v1
255
- if any (isinstance (v , jax .core .Tracer ) for v in values ):
256
+ return _maybe_copy_to_cpu (v0 ) + _maybe_copy_to_cpu (v1 )
257
+
258
+
259
+ def _maybe_copy_to_cpu (v : tuple [jax .Array , ...]) -> tuple [jax .Array , ...]:
260
+ """Convert all elements of the tuple to numpy arrays."""
261
+ # If the tuple is not of length 1, that means it came from a merge (or empty),
262
+ # and we can thus skip checks and conversions.
263
+ if len (v ) != 1 :
264
+ return v
265
+
266
+ element = v [0 ]
267
+ if isinstance (element , jax .core .Tracer ):
256
268
raise RuntimeError (
257
269
"Tracer detected! CollectingState.merge should not be JIT compiled."
258
270
)
259
- # TODO(epot): Should be executed asynchronously (blocking)
260
- return tuple (np .asarray (v ) for v in values )
271
+ # start async copy to host
272
+ cpu_device = jax .devices ("cpu" )[0 ]
273
+ cpu_element = jax .device_put (element , device = cpu_device , donate = True )
274
+ return (cpu_element ,)
261
275
262
276
263
277
# Inherit for better tracability/debug messages (so user can search and find
0 commit comments