
LigerGEGLUMLP error with torch.compile #96

Open
@Luke-Chesley


🐛 Describe the bug

When using LigerGEGLUMLP with torch.compile, I get the following error:

UserWarning: Encountered an exception in identify_mutated_tensors, assuming every input is mutated:
Traceback (most recent call last):
  File "~/.local/lib/python3.10/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 595, in identify_mutated_tensors
    ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 115, in generate_ttir
    raise Exception("Incorrect number of arguments passed to kernel")
Exception: Incorrect number of arguments passed to kernel

BackendCompilerFailed: backend='inductor' raised:
CompilationError: at 21:18:
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # tanh approximation form of GELU is computed with:
    # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
                  ^
NameError('tanh is not defined')
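The kernel excerpt in the error computes the tanh approximation of GELU described in its comment. As a framework-free sanity check of that formula (a sketch in plain Python `math`, not Liger's code), the constants can be compared against the exact erf-based GELU. `geglu_row` is a hypothetical helper showing the GEGLU product gelu(a) * b that the kernel applies to `a_row` and `b_row`.

```python
import math

# sqrt(2 / pi); should equal the 0.7978845608028654 literal in the kernel excerpt
SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)


def gelu_tanh(a: float) -> float:
    """tanh form of GELU: 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3)))."""
    a_cubed = a * a * a
    tanh_arg = SQRT_2_OVER_PI * (a + 0.044715 * a_cubed)
    return 0.5 * a * (1.0 + math.tanh(tanh_arg))


def gelu_exact(a: float) -> float:
    """Exact GELU via the Gaussian CDF: 0.5 * a * (1 + erf(a / sqrt(2)))."""
    return 0.5 * a * (1.0 + math.erf(a / math.sqrt(2.0)))


def geglu_row(a_row, b_row):
    """Hypothetical helper: elementwise GEGLU on one row, gelu_tanh(a) * b."""
    return [gelu_tanh(a) * b for a, b in zip(a_row, b_row)]


if __name__ == "__main__":
    # Print both forms side by side; they agree to roughly 1e-4 in this range.
    for a in (-3.0, -1.0, 0.0, 0.5, 2.0):
        print(f"{a:5.1f}  tanh-approx={gelu_tanh(a):.6f}  exact={gelu_exact(a):.6f}")
```

The math itself is fine; the failure above is that the name `tanh` cannot be resolved when the kernel is compiled under torch.compile.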

With the torch.compile line commented out, the module runs without any errors.

Reproduce

from dataclasses import dataclass

import torch
from liger_kernel.transformers.geglu import LigerGEGLUMLP

torch.set_float32_matmul_precision('high')

@dataclass
class Config:
    hidden_size: int = 768
    intermediate_size: int = 768 * 4
    hidden_act: str = 'gelu_pytorch_tanh'

cfg = Config()

geglu = LigerGEGLUMLP(cfg).cuda()

# torch.compile is lazy: compilation (and the error) happens on the first forward call
geglu = torch.compile(geglu)

x = torch.randn(1, 1024, 768, device='cuda')  # (batch, seq_len, hidden_size)
out = geglu(x)  # raises BackendCompilerFailed

Versions

Python Version: 3.10.12
CUDA Version: 12.1
PyTorch Version: 2.3.0+cu121
Triton Version: 2.3.0
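The values above can also be gathered programmatically. A minimal sketch, assuming a hypothetical helper `collect_versions` (not part of Liger); note that `torch.version.cuda` reports the CUDA version PyTorch was built with, not the installed driver.

```python
import importlib
import platform


def collect_versions() -> dict:
    """Hypothetical helper: gather the fields listed under "Versions"."""
    info = {"Python Version": platform.python_version()}
    try:
        torch = importlib.import_module("torch")
        info["PyTorch Version"] = str(torch.__version__)
        # CUDA version PyTorch was built against; None on CPU-only builds
        info["CUDA Version"] = torch.version.cuda or "none (CPU-only build)"
    except ImportError:
        info["PyTorch Version"] = "not installed"
        info["CUDA Version"] = "unknown"
    try:
        triton = importlib.import_module("triton")
        info["Triton Version"] = triton.__version__
    except ImportError:
        info["Triton Version"] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key}: {value}")
```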

Labels: bug