In-place operations in triton kernel might result in incorrect gradient calculations #272

Open
@Tcc0403

🐛 Describe the bug

Related: #254, #262 (see the comments there)

PyTorch’s autograd system records operations on tensors to construct a computational graph, which is used for computing gradients. When an in-place operation is performed on a tensor, the autograd system needs to ensure that the computational graph reflects the modified values.

https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks

Each tensor in PyTorch has an internal version counter that is incremented every time an in-place operation is performed.

https://github.com/pytorch/pytorch/blob/190e09d8b6a13f789b143f0fbd1325f924550967/c10/core/TensorImpl.h#L382
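
As a small illustration of the version counter, the sketch below reads the `_version` attribute before and after an in-place and an out-of-place operation. `_version` is an internal attribute, so treat it as a debugging aid rather than a stable API.

```python
import torch

t = torch.zeros(3)
before = t._version            # version of a freshly created tensor

t.add_(1)                      # in-place op: modifies t's storage
after_inplace = t._version     # counter has been incremented

u = t + 1                      # out-of-place op: allocates a new tensor
after_out_of_place = t._version  # t's counter is unchanged by this

print(before, after_inplace, after_out_of_place)
```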

Since we don't explicitly call PyTorch in-place operations, the version counter doesn't change when a Triton kernel writes to a tensor in place. As a result, PyTorch's "In-place correctness checks" mechanism never fires: no error is shown to the user, and the gradients can be silently wrong.
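
The effect can be sketched on CPU without Triton. In the hypothetical example below, a write through `.data` stands in for a kernel store, under the assumption that both bypass the version counter in the same way; `grad_after` is an illustrative helper, not part of any library.

```python
import torch


def grad_after(mutate):
    """Build y = exp(x), mutate y, then backpropagate through the old graph."""
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = x.exp()              # exp's backward reuses its saved output y
    loss = y.sum()
    mutate(y)                # overwrite y after the graph was recorded
    version = y._version
    loss.backward()
    return x.grad, version


# Untracked write (stand-in for a Triton kernel store): the version counter
# does not change, so backward runs and silently uses the overwritten values.
silent_grad, silent_version = grad_after(lambda y: y.data.zero_())

# Tracked in-place op: the version counter changes and backward raises.
try:
    grad_after(lambda y: y.zero_())
    tracked_raised = False
except RuntimeError:
    tracked_raised = True

print(silent_grad, silent_version, tracked_raised)
```

The correct gradient here is `exp(x)`, but the untracked path returns zeros without any warning, which is the same failure mode as in the reproduction below.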

Reproduce

import torch
import torch.nn.functional as F

from liger_kernel.transformers.functional import liger_cross_entropy


def run_inplace_experiment(logits_p, logits_q, cross_entropy_fn):
    _p = logits_p.clone().detach().requires_grad_(True)
    _p.retain_grad()
    softmax = torch.nn.Softmax(dim=-1)
    p = softmax(_p)
    p.retain_grad()
    loss = cross_entropy_fn(p, logits_q)
    loss.backward(retain_graph=True)

    print(f"Cross Entropy Loss: {loss.item()}")
    print(f"Input _p: {_p}")
    print(f"Input logits_q: {logits_q}")
    print(f"Gradients of p (batch item 0): {p.grad[0]}")
    print(f"Gradients of _p (batch item 0): {_p.grad[0]}")


torch.manual_seed(0)
logits_p = torch.randn(8, 8, requires_grad=True, device="cuda")
logits_q = torch.randint(0, 8, (8,), device="cuda", dtype=torch.long)


run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=F.cross_entropy)

print()
print("LIGER:")
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=liger_cross_entropy)

$ python3 inplace_bug.py
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([ 0.0017,  0.0029,  0.0003,  0.0055, -0.0182,  0.0024,  0.0023,  0.0032],
       device='cuda:0')

LIGER:
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([2.1320e-05, 3.4830e-05, 6.8024e-06, 6.7467e-05, 1.3247e-02, 2.9687e-05,
        2.8429e-05, 3.8656e-05], device='cuda:0')

Solution

One trivial solution is to perform a no-op in-place operation, such as .add_(0) or .mul_(1), on every tensor the kernel modifies. This explicitly declares that the tensor's values have changed in place, so PyTorch bumps the version counter and raises the expected error during backward.

With this approach, I suggest adding an inplace=True/False parameter to the functions that perform in-place operations, so users can set it to False (using extra tensors) when they hit these errors.
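
A minimal CPU sketch of this proposal follows. `ToySquareLoss` and `grad_through_exp` are hypothetical names, the loss is a toy `sum(x**2)` rather than cross entropy, and the `.data` write is only a stand-in for a Triton kernel store; this is not Liger's actual implementation.

```python
import torch


class ToySquareLoss(torch.autograd.Function):
    """Hypothetical stand-in for a fused loss kernel: loss = sum(x**2)."""

    @staticmethod
    def forward(ctx, x, inplace):
        loss = (x * x).sum()
        if inplace:
            # Reuse x's storage for the gradient 2*x. The `.data` write
            # stands in for a kernel store that autograd cannot see.
            x.data.mul_(2)
            # No-op in-place op: bumps x's version counter so that
            # autograd's in-place correctness check can fire.
            x.add_(0)
            grad_x = x
        else:
            # Extra tensor for the gradient; x is left untouched.
            grad_x = 2 * x
        ctx.save_for_backward(grad_x)
        return loss

    @staticmethod
    def backward(ctx, grad_out):
        (grad_x,) = ctx.saved_tensors
        return grad_out * grad_x, None  # no gradient for the `inplace` flag


def grad_through_exp(inplace):
    x = torch.tensor([0.5, 1.0], requires_grad=True)
    y = x.exp()  # exp's backward needs y, which the loss may overwrite
    ToySquareLoss.apply(y, inplace).backward()
    return x.grad


try:
    grad_through_exp(inplace=True)
    inplace_raised = False
except RuntimeError:
    inplace_raised = True          # error instead of a silently wrong gradient

safe_grad = grad_through_exp(inplace=False)  # correct gradient, more memory
print(inplace_raised, safe_grad)
```

With `inplace=True` the user gets PyTorch's usual "modified by an inplace operation" error and can then switch to `inplace=False`, trading extra memory for a correct gradient.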

Versions

Environment Report:

Operating System: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.4.1+cu121
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.45.0

Labels: bug