Commit fd2fc9d

[LayerNorm] Don't let torch.compile trace inside _layer_norm_bwd
1 parent: 6ba57ef

2 files changed (+32, -22 lines)
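For context, a minimal sketch of the scenario this commit targets (assuming the layer_norm_fn wrapper exported by flash_attn/ops/triton/layer_norm.py; the wrapper itself is not part of this diff): when a model using this layer norm runs under torch.compile, the backward custom op should now appear as a single opaque call instead of being decomposed and traced by Inductor.

# Hedged sketch, not from the commit: exercising the backward under torch.compile.
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn  # assumed public wrapper

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(1024, device="cuda", dtype=torch.float16, requires_grad=True)
b = torch.randn(1024, device="cuda", dtype=torch.float16, requires_grad=True)

@torch.compile
def f(x, w, b):
    return layer_norm_fn(x, w, b, eps=1e-6)

# With allow_decomposition=False (this commit), the backward pass reaches
# flash_attn::layer_norm_bwd_impl as one opaque custom op; Inductor no longer
# traces inside _layer_norm_bwd.
f(x, w, b).sum().backward()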

flash_attn/ops/triton/layer_norm.py

Lines changed: 9 additions & 5 deletions
@@ -696,7 +696,9 @@ def _layer_norm_bwd(


 @triton_op("flash_attn::layer_norm_bwd_impl", mutates_args={},
-           schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)")
+           schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)",
+           allow_decomposition=False,  # Don't let torch.compile trace inside
+           )
 def _layer_norm_bwd_impl(
     dy: Tensor,
     x: Tensor,
@@ -718,12 +720,14 @@ def _layer_norm_bwd_impl(
     is_rms_norm: bool = False,
     x_dtype: Optional[torch.dtype] = None,
     recompute_output: bool = False,
-):
+) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
     M, N = x.shape
     assert x.stride(-1) == 1
+    dy = maybe_contiguous_lastdim(dy)
     assert dy.stride(-1) == 1
     assert dy.shape == (M, N)
     if dresidual is not None:
+        dresidual = maybe_contiguous_lastdim(dresidual)
         assert dresidual.stride(-1) == 1
         assert dresidual.shape == (M, N)
     assert weight.shape == (N,)
@@ -732,6 +736,7 @@ def _layer_norm_bwd_impl(
         assert bias.stride(-1) == 1
         assert bias.shape == (N,)
     if dy1 is not None:
+        dy1 = maybe_contiguous_lastdim(dy1)
         assert weight1 is not None
         assert dy1.shape == dy.shape
         assert dy1.stride(-1) == 1
@@ -946,16 +951,15 @@ def forward(
     def backward(ctx, dy, *args):
         x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
         dy = dy.reshape(-1, dy.shape[-1])
-        dy = maybe_contiguous_lastdim(dy)
         if weight1 is not None:
             dy1, args = args[0], args[1:]
-            dy1 = maybe_contiguous_lastdim(dy1.reshape(-1, dy1.shape[-1]))
+            dy1 = dy1.reshape(-1, dy1.shape[-1])
             assert dy1.shape == x.shape
         else:
             dy1 = None
         if ctx.prenorm:
             dresidual = args[0]
-            dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1]))
+            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
             assert dresidual.shape == x.shape
         else:
             dresidual = None
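The maybe_contiguous_lastdim calls move from LayerNormFn.backward into _layer_norm_bwd_impl, so the contiguity fix-up now happens inside the custom-op boundary that torch.compile no longer traces into. The helper itself is not shown in this diff; a hedged sketch of the semantics suggested by the asserts that follow it (stride(-1) == 1):

# Hedged sketch of the helper; the real maybe_contiguous_lastdim in flash_attn
# may differ in details.
from typing import Optional
import torch
from torch import Tensor

def maybe_contiguous_lastdim(x: Optional[Tensor]) -> Optional[Tensor]:
    # Copy only when the last dimension is not already unit-stride.
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()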

flash_attn/utils/library.py

Lines changed: 23 additions & 17 deletions
@@ -14,6 +14,10 @@ def triton_op(
     *,
     mutates_args: Union[str, Iterable[str]],
     schema: Optional[str] = None,
+    # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to False,
+    # then it behaves like torch.library.custom_op instead, which doesn't decompose the operator
+    # and so inductor can't trace inside.
+    allow_decomposition=True,
 ) -> Callable:
     def dec(fn: Callable[..., object]) -> CustomOpDef:
         def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
@@ -35,23 +39,25 @@ def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
         # so we can just register it as the Fake/meta kernel.
         result.register_fake(fn)

-        # We decompose the operator when FunctionalTensorMode is active.
-        # The goal is to decompose the operator in AOTDispatcher.
-        # - With torch.compile, this means that the backend (usually Inductor)
-        #   can see a call to the triton kernel(s) and so it can directly optimize
-        #   them by inlining them into the lowering process.
-        def functional_decomp(  # type: ignore[no-untyped-def]
-            mode, op, types, args, kwargs
-        ):
-            from torch.export._trace import custom_triton_ops_decomposition_disabled
-
-            if custom_triton_ops_decomposition_disabled():
-                return mode.__torch_dispatch__(op, types, args, kwargs)
-            else:
-                with mode:
-                    return fn(*args, **kwargs)
-
-        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
+        if allow_decomposition:
+            # We decompose the operator when FunctionalTensorMode is active.
+            # The goal is to decompose the operator in AOTDispatcher.
+            # - With torch.compile, this means that the backend (usually Inductor)
+            #   can see a call to the triton kernel(s) and so it can directly optimize
+            #   them by inlining them into the lowering process.
+            def functional_decomp(  # type: ignore[no-untyped-def]
+                mode, op, types, args, kwargs
+            ):
+                from torch.export._trace import custom_triton_ops_decomposition_disabled
+
+                if custom_triton_ops_decomposition_disabled():
+                    return mode.__torch_dispatch__(op, types, args, kwargs)
+                else:
+                    with mode:
+                        return fn(*args, **kwargs)
+
+            result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
+
         return result

     if fn is None:
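The new allow_decomposition flag keeps the default (True) behavior identical to torch.library.triton_op, while False skips registering the FunctionalTensorMode decomposition, so the op behaves like a torch.library.custom_op and stays opaque to Inductor. A hedged usage sketch of the extended wrapper, mirroring the decorator usage in layer_norm.py (the operator name and body below are made up for illustration):

# Hedged sketch, not from the commit.
import torch
from torch import Tensor
from flash_attn.utils.library import triton_op

@triton_op("mylib::scale", mutates_args={},
           schema="(Tensor x, float alpha) -> Tensor",
           allow_decomposition=False)  # keep the op opaque to Inductor
def scale(x: Tensor, alpha: float) -> Tensor:
    # A real op would launch a Triton kernel here; a plain elementwise op stands in.
    return x * alpha

# Under torch.compile, mylib::scale now appears in the graph as a single
# custom-op call instead of being decomposed so Inductor can trace its internals.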
