Description
🐛 Describe the bug
PyTorch's autograd system records operations on tensors to build a computational graph that is later used to compute gradients. When a tensor is modified in-place, autograd must ensure that the graph, and any values saved for the backward pass, still reflect the modified data.
https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks
Each tensor in PyTorch has an internal version counter that is incremented every time an in-place operation is performed on it.
Since Liger's Triton kernels modify tensors without going through PyTorch's in-place operators, the version counter is never incremented when a kernel mutates its input. As a result, PyTorch's "In-place correctness checks" mechanism never fires: no error is shown to the user, and backward silently uses the mutated values, producing incorrect gradients (see the reproduction below).
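For context, a minimal sketch (unrelated to Liger) of how the counter and the check behave when the in-place operation does go through PyTorch; _version is the internal counter mentioned above:
import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()          # exp() saves its output `y` for the backward pass
print(y._version)    # 0
y.add_(1)            # a PyTorch in-place op increments the version counter
print(y._version)    # 1
try:
    y.sum().backward()
except RuntimeError as e:
    # "... one of the variables needed for gradient computation has been
    # modified by an inplace operation ..."
    print(e)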
Reproduce
import torch
import torch.nn.functional as F
from liger_kernel.transformers.functional import liger_cross_entropy
def run_inplace_experiment(logits_p, logits_q, cross_entropy_fn):
    _p = logits_p.clone().detach().requires_grad_(True)
    _p.retain_grad()
    softmax = torch.nn.Softmax(dim=-1)
    p = softmax(_p)  # softmax saves its output `p` for its backward pass
    p.retain_grad()
    loss = cross_entropy_fn(p, logits_q)  # the Liger kernel mutates `p` in place here
    loss.backward(retain_graph=True)
    print(f"Cross Entropy Loss: {loss.item()}")
    print(f"Input _p: {_p}")
    print(f"Input logits_q: {logits_q}")
    print(f"Gradients of p (batch item 0): {p.grad[0]}")
    print(f"Gradients of _p (batch item 0): {_p.grad[0]}")
torch.manual_seed(0)
logits_p = torch.randn(8, 8, requires_grad=True, device="cuda")
logits_q = torch.randint(0, 8, (8,), device="cuda", dtype=torch.long)
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=F.cross_entropy)
print()
print("LIGER:")
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=liger_cross_entropy)
❯ python3 inplace_bug.py
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438, 0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
[-1.0745, -0.3631, -1.6711, 2.2655, 0.3117, -0.1842, 1.2866, 1.1820],
[-0.1271, 1.2169, 1.4353, 1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
[ 0.0697, -0.0074, 1.8969, 0.6878, -0.0779, -0.8373, 1.3506, -0.2879],
[-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504, 0.5435, 1.5150],
[ 0.0141, 0.4532, 1.6349, 0.7124, -0.1806, 1.0252, -1.4622, -0.7554],
[-0.1836, 0.3824, 0.3918, -0.0830, 0.8971, -1.1123, 0.1116, 0.4863],
[-0.5499, -0.3231, -0.5469, 0.9049, 0.2837, 0.1210, 0.4730, -1.0823]],
device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149, 0.0157, 0.0140, 0.0174, -0.1086, 0.0154, 0.0153, 0.0159],
device='cuda:0')
Gradients of _p (batch item 0): tensor([ 0.0017, 0.0029, 0.0003, 0.0055, -0.0182, 0.0024, 0.0023, 0.0032],
device='cuda:0')
LIGER:
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438, 0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
[-1.0745, -0.3631, -1.6711, 2.2655, 0.3117, -0.1842, 1.2866, 1.1820],
[-0.1271, 1.2169, 1.4353, 1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
[ 0.0697, -0.0074, 1.8969, 0.6878, -0.0779, -0.8373, 1.3506, -0.2879],
[-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504, 0.5435, 1.5150],
[ 0.0141, 0.4532, 1.6349, 0.7124, -0.1806, 1.0252, -1.4622, -0.7554],
[-0.1836, 0.3824, 0.3918, -0.0830, 0.8971, -1.1123, 0.1116, 0.4863],
[-0.5499, -0.3231, -0.5469, 0.9049, 0.2837, 0.1210, 0.4730, -1.0823]],
device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149, 0.0157, 0.0140, 0.0174, -0.1086, 0.0154, 0.0153, 0.0159],
device='cuda:0')
Gradients of _p (batch item 0): tensor([2.1320e-05, 3.4830e-05, 6.8024e-06, 6.7467e-05, 1.3247e-02, 2.9687e-05,
2.8429e-05, 3.8656e-05], device='cuda:0')
Solution
One simple solution is to perform a no-op in-place operation, such as .add_(0) or .mul_(1), on the tensor the kernel mutates. This explicitly declares to autograd that the tensor's values have been changed in-place, so the usual in-place correctness error is raised during backward instead of silently wrong gradients.
With this approach, I also suggest adding an inplace=True/False parameter to the functions that perform in-place operations, so users who hit the error can set it to False (at the cost of allocating extra tensors).
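A rough sketch of what this could look like, assuming a hypothetical wrapper liger_cross_entropy_checked around the existing liger_cross_entropy (the name, signature, and placement of the no-op are illustrative only, not the actual Liger implementation):
import torch
from liger_kernel.transformers.functional import liger_cross_entropy

def liger_cross_entropy_checked(logits, target, inplace=True):
    # Hypothetical wrapper illustrating the proposed behaviour.
    if inplace:
        # No-op in-place update: bumps the version counter of `logits` so that,
        # if an earlier op (e.g. softmax) saved `logits` for backward, autograd
        # raises its in-place correctness error instead of silently producing
        # wrong gradients.
        logits.add_(0)
    else:
        # Trade memory for safety: let the kernel mutate a copy so the caller's
        # tensor (and anything saved for backward) stays untouched.
        logits = logits.clone()
    return liger_cross_entropy(logits, target)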
Versions
Environment Report:
Operating System: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.4.1+cu121
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.45.0