File tree Expand file tree Collapse file tree 1 file changed +17
-5
lines changed Expand file tree Collapse file tree 1 file changed +17
-5
lines changed Original file line number Diff line number Diff line change @@ -248,16 +248,28 @@ def compute(self: _SelfT) -> _SelfT:
248
248
)
249
249
250
250
251
- def _merge_normalize_tuple (v0 , v1 ):
251
+ def _merge_normalize_tuple (
252
+ v0 : tuple [Any , ...],
253
+ v1 : tuple [Any , ...]
254
+ ) -> tuple [np .ndarray , ...]:
252
255
assert isinstance (v0 , tuple )
253
256
assert isinstance (v1 , tuple )
254
- values = v0 + v1
255
- if any (isinstance (v , jax .core .Tracer ) for v in values ):
257
+ return _maybe_convert_to_numpy (v0 ) + _maybe_convert_to_numpy (v1 )
258
+
259
+
260
+ def _maybe_convert_to_numpy (v : tuple [Any , ...]) -> tuple [np .ndarray , ...]:
261
+ """Convert all elements of the tuple to numpy arrays."""
262
+ # If the tuple is not of length 1, that means it came from a merge (or empty),
263
+ # and we can thus skip checks and conversions.
264
+ if len (v ) != 1 :
265
+ return v
266
+
267
+ element = v [0 ]
268
+ if isinstance (element , jax .core .Tracer ):
256
269
raise RuntimeError (
257
270
"Tracer detected! CollectingState.merge should not be JIT compiled."
258
271
)
259
- # TODO(epot): Should be executed asynchronously (blocking)
260
- return tuple (np .asarray (v ) for v in values )
272
+ return (np .asarray (element ),)
261
273
262
274
263
275
# Inherit for better tracability/debug messages (so user can search and find
You can’t perform that action at this time.
0 commit comments