-
Notifications
You must be signed in to change notification settings - Fork 11
ENH: at
support for bool mask in Dask and JAX
#121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,12 @@ | |
from types import ModuleType | ||
from typing import ClassVar, cast | ||
|
||
from ._utils._compat import array_namespace, is_jax_array, is_writeable_array | ||
from ._utils._compat import ( | ||
array_namespace, | ||
is_dask_array, | ||
is_jax_array, | ||
is_writeable_array, | ||
) | ||
from ._utils._typing import Array, Index | ||
|
||
|
||
|
@@ -141,6 +146,25 @@ | |
not explicitly covered by ``array-api-compat``, are not supported by update | ||
methods. | ||
|
||
Boolean masks are supported on Dask and jitted JAX arrays exclusively | ||
when `idx` has the same shape as `x` and `y` is 0-dimensional. | ||
Note that this is support is not available in JAX's native | ||
``x.at[mask].set(y)``. | ||
|
||
This pattern:: | ||
|
||
>>> mask = m(x) | ||
>>> x[mask] = f(x[mask]) | ||
|
||
Can't be replaced by `at`, as it won't work on Dask and JAX inside jax.jit:: | ||
|
||
>>> mask = m(x) | ||
>>> x = xpx.at(x, mask).set(f(x[mask]) # Crash on Dask and jax.jit | ||
|
||
You should instead use:: | ||
|
||
>>> x = xp.where(m(x), f(x), x) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO suggest to use lazywhere |
||
|
||
Examples | ||
-------- | ||
Given either of these equivalent expressions:: | ||
|
@@ -189,6 +213,7 @@ | |
self, | ||
at_op: _AtOp, | ||
in_place_op: Callable[[Array, Array | object], Array] | None, | ||
out_of_place_op: Callable[[Array, Array], Array] | None, | ||
y: Array | object, | ||
/, | ||
copy: bool | None, | ||
|
@@ -210,6 +235,16 @@ | |
|
||
x[idx] = y | ||
|
||
out_of_place_op : Callable[[Array, Array], Array] | None | ||
Out-of-place operation to apply when idx is a boolean mask and the backend | ||
doesn't support in-place updates:: | ||
|
||
x = xp.where(idx, out_of_place_op(x, y), x) | ||
|
||
If None:: | ||
|
||
x = xp.where(idx, y, x) | ||
|
||
y : array or object | ||
Right-hand side of the operation. | ||
copy : bool or None | ||
|
@@ -223,6 +258,7 @@ | |
Updated `x`. | ||
""" | ||
x, idx = self._x, self._idx | ||
xp = array_namespace(x, y) if xp is None else xp | ||
|
||
if idx is _undef: | ||
msg = ( | ||
|
@@ -247,15 +283,41 @@ | |
else: | ||
writeable = is_writeable_array(x) | ||
|
||
# JAX inside jax.jit and Dask don't support in-place updates with boolean | ||
# mask. However we can handle the common special case of 0-dimensional y | ||
# with where(idx, y, x) instead. | ||
if ( | ||
(is_dask_array(idx) or is_jax_array(idx)) | ||
and idx.dtype == xp.bool | ||
and idx.shape == x.shape | ||
): | ||
y_xp = xp.asarray(y, dtype=x.dtype) | ||
if y_xp.ndim == 0: | ||
if out_of_place_op: | ||
# FIXME: suppress inf warnings on dask with lazywhere | ||
out = xp.where(idx, out_of_place_op(x, y_xp), x) | ||
# Undo int->float promotion on JAX after _AtOp.DIVIDE | ||
out = xp.astype(out, x.dtype, copy=False) | ||
else: | ||
out = xp.where(idx, y_xp, x) | ||
|
||
if copy: | ||
return out | ||
x[()] = out | ||
return x | ||
# else: this will work on eager JAX and crash on jax.jit and Dask | ||
|
||
if copy: | ||
if is_jax_array(x): | ||
# Use JAX's at[] | ||
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value)) | ||
return func(y) | ||
out = func(y) | ||
# Undo int->float promotion on JAX after _AtOp.DIVIDE | ||
return xp.astype(out, x.dtype, copy=False) | ||
|
||
# Emulate at[] behaviour for non-JAX arrays | ||
# with a copy followed by an update | ||
if xp is None: | ||
xp = array_namespace(x) | ||
|
||
x = xp.asarray(x, copy=True) | ||
if writeable is False: | ||
# A copy of a read-only numpy array is writeable | ||
|
@@ -283,7 +345,7 @@ | |
xp: ModuleType | None = None, | ||
) -> Array: # numpydoc ignore=PR01,RT01 | ||
"""Apply ``x[idx] = y`` and return the update array.""" | ||
return self._op(_AtOp.SET, None, y, copy=copy, xp=xp) | ||
return self._op(_AtOp.SET, None, None, y, copy=copy, xp=xp) | ||
|
||
def add( | ||
self, | ||
|
@@ -297,7 +359,7 @@ | |
# Note for this and all other methods based on _iop: | ||
# operator.iadd and operator.add subtly differ in behaviour, as | ||
# only iadd will trigger exceptions when y has an incompatible dtype. | ||
return self._op(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp) | ||
return self._op(_AtOp.ADD, operator.iadd, operator.add, y, copy=copy, xp=xp) | ||
|
||
def subtract( | ||
self, | ||
|
@@ -307,7 +369,9 @@ | |
xp: ModuleType | None = None, | ||
) -> Array: # numpydoc ignore=PR01,RT01 | ||
"""Apply ``x[idx] -= y`` and return the updated array.""" | ||
return self._op(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp) | ||
return self._op( | ||
_AtOp.SUBTRACT, operator.isub, operator.sub, y, copy=copy, xp=xp | ||
) | ||
|
||
def multiply( | ||
self, | ||
|
@@ -317,7 +381,9 @@ | |
xp: ModuleType | None = None, | ||
) -> Array: # numpydoc ignore=PR01,RT01 | ||
"""Apply ``x[idx] *= y`` and return the updated array.""" | ||
return self._op(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp) | ||
return self._op( | ||
_AtOp.MULTIPLY, operator.imul, operator.mul, y, copy=copy, xp=xp | ||
) | ||
|
||
def divide( | ||
self, | ||
|
@@ -327,7 +393,9 @@ | |
xp: ModuleType | None = None, | ||
) -> Array: # numpydoc ignore=PR01,RT01 | ||
"""Apply ``x[idx] /= y`` and return the updated array.""" | ||
return self._op(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp) | ||
return self._op( | ||
_AtOp.DIVIDE, operator.itruediv, operator.truediv, y, copy=copy, xp=xp | ||
) | ||
|
||
def power( | ||
self, | ||
|
@@ -337,7 +405,7 @@ | |
xp: ModuleType | None = None, | ||
) -> Array: # numpydoc ignore=PR01,RT01 | ||
"""Apply ``x[idx] **= y`` and return the updated array.""" | ||
return self._op(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp) | ||
return self._op(_AtOp.POWER, operator.ipow, operator.pow, y, copy=copy, xp=xp) | ||
|
||
def min( | ||
self, | ||
|
@@ -349,7 +417,7 @@ | |
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array.""" | ||
xp = array_namespace(self._x) if xp is None else xp | ||
y = xp.asarray(y) | ||
return self._op(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp) | ||
return self._op(_AtOp.MIN, xp.minimum, xp.minimum, y, copy=copy, xp=xp) | ||
|
||
def max( | ||
self, | ||
|
@@ -361,4 +429,4 @@ | |
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array.""" | ||
xp = array_namespace(self._x) if xp is None else xp | ||
y = xp.asarray(y) | ||
return self._op(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp) | ||
return self._op(_AtOp.MAX, xp.maximum, xp.maximum, y, copy=copy, xp=xp) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.