Description
Proposal
Add APIs for getting and setting elements via a list of indices.
Motivation
Currently, the array API specification does not provide a direct means of extracting and setting a list of elements along an axis. Such operations are relatively common in NumPy usage, either via "fancy indexing" or via explicit `take` and `put` APIs.
Two main arguments come to mind for supporting at least basic `take` and `put` APIs:
- Indexing does not currently support providing a list of indices to index into an array. The principal reason for not supporting fancy indexing stems from dynamic shapes and compatibility with accelerator libraries. However, use of fancy indexing is relatively common in NumPy and similar libraries where dynamically extracting rows/cols/values is possible and can be readily implemented.
- The output of a subset of APIs included in the standard cannot be readily consumed without manual workarounds if a specification-conforming library implements only the APIs in the standard. For example,
Background
The following table summarizes library implementations of such APIs:
| op | NumPy | CuPy | Dask | MXNet | Torch | TensorFlow |
|---|---|---|---|---|---|---|
| extracting elements along axis | `take` | `take` | `take` | `take` | `take`/`gather` | `gather`/`numpy.take` |
| setting elements along axis | `put` | `put` | -- | -- | `scatter` | `scatter_nd`/`tensor_scatter_nd_update` |
| extracting elements over matching 1d slices | `take_along_axis` | `take_along_axis` | -- | -- | -- | `gather_nd`/`numpy.take_along_axis` |
| setting elements over matching 1d slices | `put_along_axis` | -- | -- | -- | -- | -- |
While most libraries implement some form of `take`, fewer implement other complementary APIs.
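As a rough illustration of the four operations in the table, the following sketch exercises the NumPy variants on a small array. The array contents and index choices are arbitrary assumptions made for the example. Note that NumPy's `put` has no `axis` parameter and writes into the flattened array, which is one of the ways these APIs differ in practice.

```python
import numpy as np

x = np.arange(12).reshape(3, 4)

# Extract elements along an axis: select columns 0 and 2.
cols = np.take(x, [0, 2], axis=1)

# Set elements by flat index; np.put modifies its argument in place.
y = x.copy()
np.put(y, [0, 5], [-1, -2])

# Extract over matching 1-D slices: one column index per row.
idx = np.array([[3], [0], [1]])
picked = np.take_along_axis(x, idx, axis=1)

# Set over matching 1-D slices, again in place.
z = x.copy()
np.put_along_axis(z, idx, 0, axis=1)
```

Working on copies (`y`, `z`) keeps `x` unchanged, which makes the in-place behavior of the two setter functions easy to see.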
Activity
rgommers commented on May 6, 2021
Thanks @kgryte. A few initial thoughts:
- `take`/`put` and various `scatter`/`gather` functions in TensorFlow, PyTorch and MXNet. There's a whole host of those functions.
- `indices.size`? And `put` is an in-place operation which doesn't change the shape.
kgryte commented on May 6, 2021
@rgommers Thanks for the comments.
`indices.size` need not be fixed and could be data-dependent. For example, if we extract the indices of unique elements from an array, the number of indices cannot necessarily be known ahead of time (AOT).
asmeurer commented on May 6, 2021
A natural question is, if `take` is supported, is there any reason equivalent indexing shouldn't also be supported? Granted, `take` only represents a specific subset of general (NumPy) integer array indexing, where indexing is done on a single axis.
kgryte commented on May 6, 2021
@asmeurer I think that `take` would be an optional API, whereas indexing semantics should be universal.
rgommers commented on May 7, 2021
If the size of `indices` is variable, it's the function that produces `indices` that is data-dependent. `take` itself however is not. Compare with boolean indexing or `nonzero`: there the output size is in the range `[0, x_input.size]`; for `take` it's always `x_input.size`.
kgryte commented on May 11, 2021
@rgommers Correct; however, I could still imagine that data flows involving a `take` operation may be problematic for AOT computational graphs. While the output size is `indices.size`, an array library may not be able to statically allocate memory for the output of the `take` operation. This said, accelerator libraries do manage to support similar APIs (e.g., `scatter`/`gather`), so probably no need to further belabor this.
kgryte commented on May 11, 2021
@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), similar to boolean array indexing, we could support a limited form of integer array indexing, where the integer array index is the sole index. That is, the spec would neither condone mixing boolean with integer indices nor require broadcasting semantics among the various indices.
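A minimal sketch of what "the integer array index is the sole index" could look like, using NumPy as a stand-in for a conforming library (the array and indices are made up for illustration). In this restricted form, indexing with a single 1-D integer array matches `take` along the first axis:

```python
import numpy as np

x = np.arange(12).reshape(4, 3)
idx = np.array([3, 0, 0])  # repeated indices are allowed

# Integer array as the sole index: selects rows 3, 0, 0.
via_getitem = x[idx]

# The same selection expressed with take along axis 0.
via_take = np.take(x, idx, axis=0)
```

Both results have shape `(idx.size, 3)`, so the output shape depends only on the shapes of `x` and `idx`, not on the values stored in `idx`.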
kgryte commented on May 11, 2021
Cross-linking to a discussion regarding issues concerning out-of-bounds access in `take` APIs for accelerator libraries.
thomasjpfan commented on Dec 7, 2021
In the ML use case, it is common to want to sample with replacement or shuffle a dataset. This is commonly done by sampling an integer array and using it to subset the dataset:
For libraries that need selection with integer arrays, a workaround is to implement `take`:
Note that sampling with replacement cannot be done with a boolean mask, because some rows may be selected twice.
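The code from the original comment is not reproduced here. As a hedged sketch of the use case it describes, assuming NumPy and a made-up 5-row dataset `X`, sampling with replacement and shuffling via `take` might look like this:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dataset: 5 samples (rows), 2 features (columns).
X = np.arange(10).reshape(5, 2)
n_samples = X.shape[0]

# Sampling with replacement: row indices may repeat.
boot_idx = rng.integers(0, n_samples, size=n_samples)
X_boot = np.take(X, boot_idx, axis=0)

# Shuffling: a permutation of the row indices, each used exactly once.
perm = rng.permutation(n_samples)
X_shuffled = np.take(X, perm, axis=0)
```

A boolean mask can only mark each row as kept or dropped, so it has no way to express the repeated rows that `boot_idx` may contain.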
leofang commented on Feb 1, 2022
Hi @kmaehashi @asi1024 @emcastillo FYI. In a recent array API call we discussed the proposed take/put APIs, and there were questions regarding how CuPy currently implements these functions, as there could be data/value dependency and people were wondering if we just have to pay the synchronization cost to ensure the behavior is correct. Could you help address this? Thanks! (And sorry I dropped the ball here...)
shoyer commented on Mar 10, 2022
+1 I think "array only" integer indexing would be quite well defined, and would not be problematic for accelerators. The main challenge with NumPy's implementation of "advanced indexing" is handling mixed integer/slice/boolean cases.
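To make the contrast concrete, here is a small NumPy sketch (shapes chosen arbitrarily for illustration) comparing array-only indexing with a mixed integer/slice/array index. The mixed case follows NumPy's rule that advanced indices separated by a slice move their broadcast dimension to the front of the result:

```python
import numpy as np

x = np.zeros((2, 3, 4))
idx = np.array([0, 1])

# Array-only indexing: the result shape is idx.shape + x.shape[1:].
only_array = x[idx]

# Mixed integer / slice / array indexing: the scalar 0 and idx both count
# as advanced indices and are separated by a slice, so the indexed
# dimension is placed first.
mixed = x[0, :, idx]

# Indexing in two steps keeps the dimensions in their original order.
stepwise = x[0][:, idx]

shapes = (only_array.shape, mixed.shape, stepwise.shape)
```

Here `mixed` and `stepwise` select the same elements but come out transposed relative to each other, which is the kind of behavior a restricted "array only" rule would sidestep.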
rgommers commented on Mar 24, 2022
Here is a summary of today's discussion:
- `take` is fine, there's no problem for accelerators and all libraries listed above already have this API. Given that they all have it, there's no problem adding `take` to the standard right now.
- `__getitem__` part of indexing is equivalent to `take`. However, as @asmeurer pointed out, it would be odd to add support for integer array indexing in `__getitem__` but not in `__setitem__`. Hence we need to look at the latter.
- `put` and `__setitem__` are also equivalent - and more problematic, for multiple reasons:
  - `put` isn't widely supported across libraries, and not with the same name either.
  - `put` is explicitly an in-place function in NumPy et al., which is a problem for JAX/TensorFlow. Having a better handle on the topic of mutability looks like a hard requirement before even considering an in-place function like `put`.
- `put` to the standard. However, that's a new function that libraries don't yet have (actually some do under names like `index_put`, but it's a mixed bag). And it's not clear that this would be preferred in the long term; an in-place `put` that is guaranteed to raise when it crosses paths with a view may be better.

Given all that, the proposal is to only add `take` now, and revisit integer array indexing and `put` in the future.
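To illustrate why the in-place nature of `put` interacts with views, the sketch below uses NumPy. The helper `put_copy` is a hypothetical name invented for this example, not an existing or proposed API; it simply wraps `np.put` around a copy to show what an out-of-place variant could look like.

```python
import numpy as np

def put_copy(x, indices, values):
    """Hypothetical out-of-place put: returns a new array, leaves x untouched."""
    out = x.copy()
    np.put(out, indices, values)  # in place, but only on the copy
    return out

x = np.arange(5)
y = put_copy(x, [0, 2], [10, 20])

# In-place put on a view also changes the array the view was taken from.
base = np.arange(5)
view = base[1:]
np.put(view, [0], [99])
```

After this runs, `x` is unchanged while `base[1]` has been overwritten through `view`. Libraries with immutable arrays cannot offer that second behavior, which is the mutability concern raised in the summary.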