Skip to content

Fix CUDA kernel index data type in deeplearning/fbgemm/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu +10 #3846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

r-barnes
Copy link
Contributor

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/936

CUDA kernel variables matching the type (thread|block|grid).(Idx|Dim).(x|y|z) have the data type uint.

Many programmers mistakenly use implicit casts to turn these data types into int. In fact, the CUDA Programming Guide it self is inconsistent and incorrect in its use of data types in programming examples.

The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding >~2B items.

While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them.

Reviewed By: sryap, dtolnay

Differential Revision: D71355405

…/embedding_inplace_ops/embedding_inplace_update.cu +10

Summary:
X-link: facebookresearch/FBGEMM#936

CUDA kernel variables matching the type `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables).

Many programmers mistakenly use implicit casts to turn these data types into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) it self is inconsistent and incorrect in its use of data types in programming examples.

The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding >~2B items.

While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them.

Reviewed By: sryap, dtolnay

Differential Revision: D71355405
Copy link

netlify bot commented Mar 18, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

Name Link
🔨 Latest commit 58fd4d6
🔍 Latest deploy log https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/67d9e16d03e71c00080e4e55
😎 Deploy Preview https://deploy-preview-3846--pytorch-fbgemm-docs.netlify.app
📱 Preview on mobile
Toggle QR Code...

QR Code

Use your smartphone camera to open QR code link.

To edit notification comments on pull requests, go to your Netlify site configuration.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D71355405

@facebook-github-bot
Copy link
Contributor

This pull request has been merged in 5c7c234.

q10 pushed a commit to q10/FBGEMM that referenced this pull request Apr 10, 2025
…/embedding_inplace_ops/embedding_inplace_update.cu +10 (pytorch#936)

Summary:
X-link: pytorch#3846

Pull Request resolved: facebookresearch/FBGEMM#936

CUDA kernel variables matching the type `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables).

Many programmers mistakenly use implicit casts to turn these data types into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) it self is inconsistent and incorrect in its use of data types in programming examples.

The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding >~2B items.

While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them.

Reviewed By: sryap, dtolnay

Differential Revision: D71355405

fbshipit-source-id: a11a15fd70f24250604c2a9c3951f34123f70679
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants