
RFC: add topk and / or argpartition #629

Open
@ogrisel

Description

@ogrisel

numpy provides an indirect way to compute the indices of the k smallest (or largest) values of an array: numpy.argpartition.
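
A minimal illustration of the argpartition approach (the array values are arbitrary): the k selected indices come back in no particular order, so a follow-up sort is needed if the caller wants them ranked.

```python
import numpy as np

x = np.array([7.0, 2.0, 9.0, 4.0, 1.0, 8.0])
k = 3

# argpartition places the index of the k-th smallest element (position k - 1)
# at its sorted position; every index before it points to a smaller-or-equal value.
idx = np.argpartition(x, k - 1)[:k]
print(sorted(idx.tolist()))        # [1, 3, 4] (the order inside idx is arbitrary)
print(np.sort(x[idx]).tolist())    # [1.0, 2.0, 4.0]

# For the k largest values, partition around position n - k and keep the tail.
idx_large = np.argpartition(x, x.size - k)[-k:]
print(np.sort(x[idx_large]).tolist())  # [7.0, 8.0, 9.0]
```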

There is also a proposal to provide a higher level API, namely (arg)topk in numpy:

That NumPy PR relies on numpy.argpartition internally, but it could probably be optimized later to avoid allocating a result array as large as the input array when k is small.
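
One way to avoid the input-sized allocation for small k (not necessarily what the NumPy PR would do) is a bounded heap that only ever holds k candidates: O(n log k) time and O(k) extra memory. This is a pure-Python sketch, not a vectorized implementation, and the function name `argtopk_smallest_streaming` is hypothetical.

```python
import heapq
from typing import Iterable, List, Tuple


def argtopk_smallest_streaming(values: Iterable[float], k: int) -> List[Tuple[float, int]]:
    """Hypothetical helper: (value, index) pairs of the k smallest values, ascending.

    Python's heapq is a min-heap, so entries are stored negated to simulate a
    max-heap of size k whose root is the worst candidate kept so far.
    Among equal values, earlier indices are preferred.
    """
    if k <= 0:
        return []
    heap: List[Tuple[float, int]] = []  # entries are (-value, -index)
    for i, v in enumerate(values):
        if len(heap) < k:
            heapq.heappush(heap, (-v, -i))
        elif -v > heap[0][0]:
            # v is strictly smaller than the largest kept value: replace the root.
            heapq.heapreplace(heap, (-v, -i))
    return sorted((-nv, -ni) for nv, ni in heap)


print(argtopk_smallest_streaming([7.0, 2.0, 9.0, 4.0, 1.0, 8.0], 3))
# [(1.0, 4), (2.0, 1), (4.0, 3)]
```

Because it consumes any iterable, the same sketch also works on a generator that never materializes the full input.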

Here is a quick review of some available implementations in related libraries:

  • torch.topk (there is no torch.argpartition)
    • returns a tuple of values and indices
  • jax.lax.top_k
    • returns a tuple of values and indices
    • apparently it is quite slow on GPU
  • dask.array.topk
    • returns only the values, I did not find a way to get the indices :(
  • cupy.argpartition exists, but it internally computes a full cupy.argsort, which makes it very inefficient for large arrays and small k: O(n log n) instead of O(n).

Motivation: (arg)topk is needed by popular baseline data-science workloads (e.g. k-nearest neighbors classification in scikit-learn) and is surprisingly non-trivial to implement efficiently. For instance, on GPUs the fastest implementations are based on some kind of partial radix sort, while CPU implementations use more traditional partial sorting algorithms (as implemented in std::partial_sort or std::nth_element).
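
For reference, the selection idea behind std::nth_element can be sketched with randomized quickselect, which runs in average O(n). This is a simplified, hypothetical pure-Python version (`quickselect_argkmin` is not a real library function); common std::nth_element implementations use introselect, which adds a worst-case guarantee that this sketch omits.

```python
import random
from typing import List, Sequence


def quickselect_argkmin(values: Sequence[float], k: int, seed: int = 0) -> List[int]:
    """Hypothetical helper: indices of the k smallest values, in no particular order.

    Randomized quickselect with a three-way partition; average O(n) time.
    Assumes the values contain no NaNs.
    """
    n = len(values)
    if k <= 0:
        return []
    if k >= n:
        return list(range(n))
    rng = random.Random(seed)
    idx = list(range(n))
    lo, hi = 0, n  # invariant: lo < k < hi, the k-th boundary lies in idx[lo:hi]
    while True:
        pivot = values[idx[rng.randrange(lo, hi)]]
        window = idx[lo:hi]
        # Three-way partition of the window around the pivot value.
        less = [i for i in window if values[i] < pivot]
        equal = [i for i in window if values[i] == pivot]
        greater = [i for i in window if values[i] > pivot]
        idx[lo:hi] = less + equal + greater
        eq_start = lo + len(less)         # first position of the "equal" run
        gt_start = eq_start + len(equal)  # first position of the "greater" run
        if k < eq_start:
            hi = eq_start   # boundary falls among the smaller values
        elif k <= gt_start:
            return idx[:k]  # everything before k is <= pivot <= everything after
        else:
            lo = gt_start   # boundary falls among the larger values


print(sorted(quickselect_argkmin([7.0, 2.0, 9.0, 4.0, 1.0, 8.0], 3)))  # [1, 3, 4]
```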

Activity

ogrisel commented on May 17, 2023

Author

Note: since argsort is part of the standard Array API, it would be possible to implement a generic yet inefficient fallback in array-api-compat, while dispatching to a more efficient routine for libraries that provide one. This is what cupy.argpartition does, for instance.
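
A rough sketch of that dispatch idea, assuming the array namespace `xp` is passed in explicitly: use `argpartition` when the namespace provides it, otherwise fall back to the standard `argsort` at O(n log n). The helper name `argtopk_smallest` is hypothetical and not part of array-api-compat; numpy stands in for a real namespace.

```python
from types import SimpleNamespace

import numpy as np


def argtopk_smallest(x, k, xp=np):
    """Hypothetical helper: indices of the k smallest entries along the last axis,
    in no particular order."""
    argpartition = getattr(xp, "argpartition", None)
    if argpartition is not None:
        # Fast path (average O(n)) for namespaces that provide argpartition.
        return argpartition(x, k - 1, axis=-1)[..., :k]
    # Generic fallback using the standard argsort: correct but O(n log n).
    return xp.argsort(x, axis=-1)[..., :k]


x = np.array([[5.0, 3.0, 8.0, 1.0],
              [2.0, 9.0, 0.0, 4.0]])
fast = argtopk_smallest(x, 2)
# A stand-in namespace without argpartition exercises the fallback path.
slow = argtopk_smallest(x, 2, xp=SimpleNamespace(argsort=np.argsort))
print(np.sort(fast, axis=-1).tolist(), np.sort(slow, axis=-1).tolist())
```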

rgommers commented on May 17, 2023

Member

Thanks for the proposal @ogrisel. It's actually surprising that coverage and performance across array libraries are so spotty. I dug up the NumPy mailing list discussion, and it seemed more or less positive, just unfinished, and the name to use is a nicely-sized bikeshed.

Is this function something you already have in scikit-learn internally, or are you looking for something more efficient than the argsort or similar function you use now?

ogrisel commented on May 17, 2023

Author

In scikit-learn, for k-nearest neighbors (brute-force exact method for medium- to high-dimensional spaces), we use a routine optimized for multicore CPUs using Cython + OpenMP for pairwise distances (similar to scipy's cdist) fused with a topk reduction implemented in templated Cython. The topk reduction itself (called "argkmin" in scikit-learn) lives here:

This code can only be called as a reduction fused into the multithreaded pairwise distance computation kernel. It is orchestrated via:
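
The general shape of such a fused reduction can be sketched in plain numpy (no Cython or OpenMP, and not scikit-learn's code): walk over Y in chunks, compute the distance block, and merge it with the running best k so the full distance matrix is never held in memory. The name `chunked_argkmin` is hypothetical; the sketch assumes squared Euclidean distances and k <= len(Y).

```python
import numpy as np


def chunked_argkmin(X, Y, k, chunk_size=256):
    """Hypothetical sketch: for each row of X, the k nearest rows of Y
    (squared Euclidean distance), processing Y chunk by chunk."""
    n_x = X.shape[0]
    best_d = np.full((n_x, k), np.inf)
    best_i = np.full((n_x, k), -1, dtype=np.intp)
    for start in range(0, Y.shape[0], chunk_size):
        Yc = Y[start:start + chunk_size]
        # Squared distances between X and this chunk only: shape (n_x, len(Yc)).
        d = ((X[:, None, :] - Yc[None, :, :]) ** 2).sum(axis=-1)
        chunk_idx = np.broadcast_to(np.arange(start, start + len(Yc)), d.shape)
        # Merge the running best k with this chunk's candidates, keep the k smallest.
        cand_d = np.concatenate([best_d, d], axis=1)
        cand_i = np.concatenate([best_i, chunk_idx], axis=1)
        sel = np.argpartition(cand_d, k - 1, axis=1)[:, :k]
        best_d = np.take_along_axis(cand_d, sel, axis=1)
        best_i = np.take_along_axis(cand_i, sel, axis=1)
    # Final ordering of the k neighbors by increasing distance.
    order = np.argsort(best_d, axis=1, kind="stable")
    return np.take_along_axis(best_d, order, axis=1), np.take_along_axis(best_i, order, axis=1)


X = np.array([[0.0, 0.0], [10.0, 10.0]])
Y = np.array([[1.0, 0.0], [0.0, 3.0], [9.0, 9.0], [10.0, 11.0], [5.0, 5.0]])
d, i = chunked_argkmin(X, Y, k=2, chunk_size=2)
print(i.tolist())  # [[0, 1], [3, 2]]
```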

For CPU, I doubt that any Array API based solution will be able to compete on both speed and memory usage.

However, we are interested in implementing Array API support for an alternative numpy code-path in order to provide GPU support, e.g. via PyTorch or CuPy. The reducer used in the numpy code-path is here:

It's based on numpy.argpartition followed by numpy.argsort of the top k values.
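
A sketch of that two-step reducer on a precomputed distance matrix (the helper name `argkmin_rows` is hypothetical; this is not the scikit-learn source):

```python
import numpy as np


def argkmin_rows(dist, k):
    """Hypothetical sketch: for each row of a distance matrix, the column indices
    of the k smallest distances, sorted by increasing distance."""
    # Step 1: partial selection per row (average O(n)); the k indices are unordered.
    part = np.argpartition(dist, k - 1, axis=1)[:, :k]
    # Step 2: sort only the k selected distances, O(k log k) per row.
    part_d = np.take_along_axis(dist, part, axis=1)
    order = np.argsort(part_d, axis=1, kind="stable")
    return np.take_along_axis(part, order, axis=1)


dist = np.array([[0.5, 0.1, 0.9, 0.3],
                 [0.2, 0.8, 0.05, 0.6]])
print(argkmin_rows(dist, 2).tolist())  # [[1, 3], [2, 0]]
```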

Note that to efficiently implement k-nearest neighbors in scikit-learn using the Array API, we would also need the Array API to provide scipy.spatial.distance.cdist.
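
For Euclidean distances specifically, a cdist-like function can be sketched with basic array operations (elementwise arithmetic, `sum`, matmul, `sqrt`, `maximum`) via the expansion ||x - y||^2 = ||x||^2 - 2 x·y + ||y||^2. The name `euclidean_cdist` is hypothetical and numpy stands in for the namespace. This expansion is fast because it maps onto a matrix product, but it is less numerically accurate than computing differences directly.

```python
import numpy as np


def euclidean_cdist(X, Y, xp=np):
    """Hypothetical sketch of a cdist-like Euclidean distance using basic array ops."""
    # Expand ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2 for all pairs at once.
    x_sq = xp.sum(X * X, axis=1)[:, None]
    y_sq = xp.sum(Y * Y, axis=1)[None, :]
    sq = x_sq - 2 * (X @ Y.T) + y_sq
    # Rounding can make tiny squared distances slightly negative: clamp to 0.
    return xp.sqrt(xp.maximum(sq, 0.0))


print(euclidean_cdist(np.array([[0.0, 0.0], [3.0, 4.0]]),
                      np.array([[0.0, 0.0], [6.0, 8.0]])))
```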

I have not opened an issue to discuss cdist yet. I wanted to test the waters with topk first.

shoyer commented on Jun 1, 2023

Contributor

JAX also has an approximate top-k implementation specifically tuned for TPUs: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.approx_max_k.html

ogrisel commented on Jun 15, 2023

Author

I am not sure if we want to include non-exact methods in the spec. I have the feeling that there are many ways to compute such approximations and that they will require different and evolving parametrizations with different speed-accuracy trade-offs.

kgryte commented on Dec 14, 2023

Contributor

A PR has now been opened which proposes adding top_k and friends to the specification: #722. Please feel free to review and comment there with your concerns and feedback.

added this to the v2024 milestone on Jan 25, 2024
changed the title from "Standard API for topk and / or argpartition" to "RFC: add topk and / or argpartition" on Apr 4, 2024
added the RFC label on Apr 4, 2024

jakirkham commented on Jan 18, 2025

Member

dask.array.topk
returns only the values, I did not find a way to get the indices :(

There is dask.array.argtopk, which returns the indices of the values that topk would return.
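
Top-k fits chunked arrays well because it is a mergeable reduction: select the top k within each block, then select again among the survivors. Below is a hypothetical 1-D numpy sketch of that map-then-merge strategy (`blockwise_argtopk` is not dask's code or API). It assumes a float or signed dtype, since the k largest are found by negation, and k <= len(x).

```python
import numpy as np


def blockwise_argtopk(x, k, block_size):
    """Hypothetical sketch: indices of the k largest values of a 1-D array,
    computed block by block, then merged; sorted by decreasing value."""
    candidates = []
    for start in range(0, x.shape[0], block_size):
        block = x[start:start + block_size]
        kk = min(k, block.shape[0])
        # Map step: local top-kk inside the block, shifted to global indices.
        candidates.append(np.argpartition(-block, kk - 1)[:kk] + start)
    cand = np.concatenate(candidates)
    # Merge step: global top-k among the surviving candidates.
    best = cand[np.argpartition(-x[cand], k - 1)[:k]]
    return best[np.argsort(-x[best], kind="stable")]


x = np.array([3.0, 9.0, 1.0, 7.0, 5.0, 8.0, 2.0, 6.0])
print(blockwise_argtopk(x, 3, block_size=3).tolist())  # [1, 5, 3]
```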

kgryte commented on Jan 18, 2025

Contributor

@jakirkham Would be good to get your feedback on #722.

modified the milestones: v2024, v2025 on Feb 6, 2025

Metadata

Labels: API extension, Needs Discussion, RFC

Participants: @ogrisel, @rgommers, @shoyer, @kgryte, @jakirkham

RFC: add topk and / or argpartition · Issue #629 · data-apis/array-api