RFC: add APIs for getting elements via a list of indices (i.e., take, take_along_axis, etc) #177

Open
@kgryte

Description

@kgryte
Contributor

Proposal

Add APIs for getting and setting elements via a list of indices.

Motivation

Currently, the array API specification does not provide a direct means of extracting and setting a list of elements along an axis. Such operations are relatively common in NumPy usage either via "fancy indexing" or via explicit take and put APIs.

Two main arguments come to mind for supporting at least basic take and put APIs:

  1. Indexing does not currently support providing a list of indices to index into an array. The principal reason for not supporting fancy indexing stems from dynamic shapes and compatibility with accelerator libraries. However, use of fancy indexing is relatively common in NumPy and similar libraries where dynamically extracting rows/cols/values is possible and can be readily implemented.

  2. The output of a subset of APIs currently included in the standard cannot be readily consumed without manual workarounds if a specification-conforming library implements only the APIs in the standard. For example,

    • argsort returns an array of indices. In NumPy, the output of this function can be consumed by put_along_axis and take_along_axis.
    • unique can return an array of indices if return_index is True.
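The second point can be sketched with NumPy as follows. This is only an illustration with made-up values, assuming `argsort` and `take_along_axis` behave as documented in NumPy:

```python
import numpy as np

x = np.asarray([[3, 1, 2],
                [9, 7, 8]])

# argsort returns, for each row, the indices that would sort that row.
order = np.argsort(x, axis=1)

# take_along_axis consumes those indices directly. Without such an API
# (or fancy indexing), a conforming library has no direct way to apply
# `order` back to `x`.
sorted_x = np.take_along_axis(x, order, axis=1)
```

Here `order` has the same shape as `x`, so each row is reordered by its own indices.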

Background

The following table summarizes library implementations of such APIs:

| op | NumPy | CuPy | Dask | MXNet | Torch | TensorFlow |
| --- | --- | --- | --- | --- | --- | --- |
| extracting elements along axis | take | take | take | take | take/gather | gather/numpy.take |
| setting elements along axis | put | put | -- | -- | scatter | scatter_nd/tensor_scatter_nd_update |
| extracting elements over matching 1d slices | take_along_axis | take_along_axis | -- | -- | -- | gather_nd/numpy.take_along_axis |
| setting elements over matching 1d slices | put_along_axis | -- | -- | -- | -- | -- |

While most libraries implement some form of take, fewer implement other complementary APIs.
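To make the two "extracting" rows of the table concrete, here is a small NumPy sketch (values are illustrative only) contrasting `take` with `take_along_axis`:

```python
import numpy as np

x = np.asarray([[10, 11, 12],
                [20, 21, 22]])

# take: a single 1-D list of indices, applied identically along the axis.
cols = np.take(x, np.asarray([2, 0]), axis=1)

# take_along_axis: an index array of the same rank as x, so each 1-D
# slice (here, each row) is indexed with its own entries.
idx = np.asarray([[0],
                  [2]])
per_row = np.take_along_axis(x, idx, axis=1)
```

`cols` picks columns 2 and 0 from every row, while `per_row` picks column 0 from the first row and column 2 from the second.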

Activity

rgommers

rgommers commented on May 6, 2021

@rgommers
Member

Thanks @kgryte. A few initial thoughts:

  • There is also overlap between take/put and various scatter/gather functions in TensorFlow, PyTorch and MXNet. There's a whole host of those functions.
  • Is there really an issue with shape determinism? I'm probably missing something here, but isn't the output size along the given dimension equal to indices.size? And put is an inplace operation which doesn't change the shape.
  • If we do want to add these, we may consider putting them in the second version of the API. Just thinking that we should at some point stop making the API a permanently moving target.
kgryte

kgryte commented on May 6, 2021

@kgryte
Contributor, Author

@rgommers Thanks for the comments.

  1. Correct. I've updated the table with Torch and TF scatter and gather methods.
  2. Correct me if I am wrong, but indices.size need not be fixed and could be data-dependent. For example, if we extract the indices of unique elements from an array, the number of indices cannot necessarily be known AOT.
  3. Not opposed to delaying until V2 (2022).
asmeurer

asmeurer commented on May 6, 2021

@asmeurer
Member

A natural question: if take is supported, is there any reason equivalent indexing shouldn't also be supported? Granted, take only represents a specific subset of general (NumPy) integer array indexing, where indexing is done on a single axis.

kgryte

kgryte commented on May 6, 2021

@kgryte
Contributor, Author

@asmeurer I think that take would be an optional API; whereas indexing semantics should be universal.

rgommers

rgommers commented on May 7, 2021

@rgommers
Member

Correct me if I am wrong, but indices.size need not be fixed and could be data-dependent. For example, if we extract the indices of unique elements from an array, the number of indices cannot necessarily be known AOT.

If the size of indices is variable, it's the function that produces indices that is data-dependent; take itself, however, is not. Compare with boolean indexing or nonzero: there the output size is in the range [0, x_input.size], whereas for take it's always indices.size.
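The distinction can be sketched in NumPy (values are illustrative): the output length of `take` along the chosen axis follows from the shape of `indices` alone, whereas the output length of `nonzero` depends on the values in the input.

```python
import numpy as np

x = np.asarray([5, 0, 7, 0, 9])
indices = np.asarray([4, 4, 0, 1])

# take: the output length along the axis equals indices.size,
# whatever values x happens to hold.
taken = np.take(x, indices, axis=0)

# nonzero: the output length depends on the data itself.
nz_some = np.nonzero(x)[0]
nz_none = np.nonzero(np.zeros_like(x))[0]
```

Two inputs of the same shape give `nonzero` results of different lengths, while `taken` always has `indices.size` elements.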

kgryte

kgryte commented on May 11, 2021

@kgryte
Contributor, Author

@rgommers Correct; however, I still could imagine that data flows involving a take operation may still be problematic for AOT computational graphs. While the output size is indices.size, an array library may not be able to statically allocate memory for the output of the take operation. This said, accelerator libraries do manage to support similar APIs (e.g., scatter/gather), so probably no need to further belabor this.

kgryte

kgryte commented on May 11, 2021

@kgryte
Contributor, Author

@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), we could, similar to boolean array indexing, support a limited form of integer array indexing, where the integer array index is the sole index. Meaning, the spec would not condone mixing boolean with integer indices or require broadcasting semantics among the various indices.
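As a rough illustration of that limited form, and assuming it would follow NumPy's existing behavior for a single 1-D integer array index, the result matches `take` along the first axis:

```python
import numpy as np

x = np.asarray([[1, 2],
                [3, 4],
                [5, 6]])
idx = np.asarray([2, 0, 2])

# The integer array is the sole index: no slices, booleans, or other
# integer arrays are mixed in, so no broadcasting between indices arises.
via_indexing = x[idx]

# Equivalent explicit form.
via_take = np.take(x, idx, axis=0)
```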

kgryte

kgryte commented on May 11, 2021

@kgryte
Contributor, Author

Cross-linking to a discussion regarding issues concerning out-of-bounds access in take APIs for accelerator libraries.

added this to the v2022 milestone on Oct 4, 2021
thomasjpfan

thomasjpfan commented on Dec 7, 2021

@thomasjpfan

In the ML use case, it is common to want to sample with replacement or shuffle a dataset. This is commonly done by sampling an integer array and using it to subset the dataset:

import numpy.array_api as xp

X = xp.asarray([[1, 2, 3, 4], [2, 3, 4, 5],
                [4, 5, 6, 10], [5, 6, 8, 20]], dtype=xp.float64)

sample_indices = xp.asarray([0, 0, 1, 3])

# Does not work
# X[sample_indices, :]

For libraries that need selection with integer arrays, a workaround is to implement take:

def take(X, indices, *, axis):
    # Simple implementation that only works for axis in {0, 1}
    if axis == 0:
        selected = [X[i] for i in indices]
    else:  # axis == 1
        selected = [X[:, i] for i in indices]
    return xp.stack(selected, axis=axis)

take(X, sample_indices, axis=0)

Note that sampling with replacement cannot be done with a boolean mask, because some rows may be selected twice.
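A short sketch of why the mask falls short, written against plain NumPy (rather than `numpy.array_api`) so that the comparison itself can be expressed:

```python
import numpy as np

X = np.asarray([[1, 2, 3, 4], [2, 3, 4, 5],
                [4, 5, 6, 10], [5, 6, 8, 20]], dtype=np.float64)
sample_indices = np.asarray([0, 0, 1, 3])

# Integer-based selection keeps the duplicate of row 0.
sampled = np.take(X, sample_indices, axis=0)

# A boolean mask can only mark a row as present or absent, so the
# repeated selection of row 0 is lost.
mask = np.zeros(X.shape[0], dtype=bool)
mask[sample_indices] = True
masked = X[mask]
```

`sampled` has four rows, with row 0 appearing twice, while `masked` has only three.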

leofang

leofang commented on Feb 1, 2022

@leofang
Contributor

Hi @kmaehashi @asi1024 @emcastillo FYI. In a recent array API call we discussed the proposed take/put APIs, and there were questions regarding how CuPy currently implements these functions, as there could be data/value dependency and people were wondering if we just have to pay the synchronization cost to ensure the behavior is correct. Could you help address this? Thanks! (And sorry I dropped the ball here...)

shoyer

shoyer commented on Mar 10, 2022

@shoyer
Contributor

@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), we could, similar to boolean array indexing, support a limited form of integer array indexing, where the integer array index is the sole index. Meaning, the spec would not condone mixing boolean with integer indices or require broadcasting semantics among the various indices.

+1 I think "array only" integer indexing would be quite well defined, and would not be problematic for accelerators. The main challenge with NumPy's implementation of "advanced indexing" is handling mixed integer/slice/boolean cases.

rgommers

rgommers commented on Mar 24, 2022

@rgommers
Member

Here is a summary of today's discussion:

  • Implementing take is fine, there's no problem for accelerators and all libraries listed above already have this API. Given that they all have it, there's no problem adding take to the standard right now.
  • The __getitem__ part of indexing is equivalent to take. However, as @asmeurer pointed out, it would be odd to add support for integer array indexing in __getitem__ but not in __setitem__. Hence we need to look at the latter.
  • put and __setitem__ are also equivalent - and more problematic, for multiple reasons:
    • as the table in the issue description shows, put isn't widely supported across libraries, and not with the same name either.
    • put is explicitly an in-place function in NumPy et al., which is a problem for JAX/TensorFlow. Having a better handle on the topic of mutability looks like a hard requirement before even considering an in-place function like put.
    • @oleksandr-pavlyk suggested adding a new out-of-place version of put to the standard. However, that's a new function that libraries don't yet have (actually some do under names like index_put, but it's a mixed bag). And it's not clear that this would be preferred in the long term; an in-place put that is guaranteed to raise when it crosses paths with a view may be better.

Given all that, the proposal is to only add take now, and revisit integer array indexing and put in the future.
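For reference, the out-of-place variant of put mentioned in the summary could look roughly like the sketch below. The name `put_copy` is hypothetical and not taken from any library or from the standard; it simply wraps NumPy's in-place `np.put` (which indexes into the flattened array) around a copy:

```python
import numpy as np

def put_copy(x, indices, values):
    """Hypothetical out-of-place put: return a modified copy, leave x as is."""
    out = x.copy()
    # np.put writes in place, addressing `out` as if it were flattened.
    np.put(out, indices, values)
    return out

x = np.asarray([1, 2, 3, 4])
y = put_copy(x, [0, 2], [10, 30])
```

The input `x` is left untouched, which sidesteps the mutability concern at the cost of an extra allocation.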
