@@ -696,7 +696,9 @@ def _layer_norm_bwd(
 @triton_op("flash_attn::layer_norm_bwd_impl", mutates_args={},
-           schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)")
+           schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)",
+           allow_decomposition=False,  # Don't let torch.compile trace inside
+           )
 def _layer_norm_bwd_impl(
     dy: Tensor,
     x: Tensor,
@@ -718,12 +720,14 @@ def _layer_norm_bwd_impl(
     is_rms_norm: bool = False,
     x_dtype: Optional[torch.dtype] = None,
     recompute_output: bool = False,
-):
+) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
     M, N = x.shape
     assert x.stride(-1) == 1
+    dy = maybe_contiguous_lastdim(dy)
     assert dy.stride(-1) == 1
     assert dy.shape == (M, N)
     if dresidual is not None:
+        dresidual = maybe_contiguous_lastdim(dresidual)
         assert dresidual.stride(-1) == 1
         assert dresidual.shape == (M, N)
     assert weight.shape == (N,)
@@ -732,6 +736,7 @@ def _layer_norm_bwd_impl(
         assert bias.stride(-1) == 1
         assert bias.shape == (N,)
     if dy1 is not None:
+        dy1 = maybe_contiguous_lastdim(dy1)
         assert weight1 is not None
         assert dy1.shape == dy.shape
         assert dy1.stride(-1) == 1
@@ -946,16 +951,15 @@ def forward(
     def backward(ctx, dy, *args):
         x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
         dy = dy.reshape(-1, dy.shape[-1])
-        dy = maybe_contiguous_lastdim(dy)
         if weight1 is not None:
             dy1, args = args[0], args[1:]
-            dy1 = maybe_contiguous_lastdim(dy1.reshape(-1, dy1.shape[-1]))
+            dy1 = dy1.reshape(-1, dy1.shape[-1])
             assert dy1.shape == x.shape
         else:
             dy1 = None
         if ctx.prenorm:
             dresidual = args[0]
-            dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1]))
+            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
             assert dresidual.shape == x.shape
         else:
             dresidual = None
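
Note: the calls added in this diff rely on a maybe_contiguous_lastdim helper defined elsewhere in layer_norm.py. A minimal sketch of the assumed behavior (not necessarily the repo's exact implementation): force the last dimension to be contiguous only when it is not already, so the assert x.stride(-1) == 1 checks inside _layer_norm_bwd_impl hold without redundant copies.

import torch

def maybe_contiguous_lastdim(x):
    # Assumed behavior: copy only when the last dim is non-contiguous.
    return x if x is None or x.stride(-1) == 1 else x.contiguous()

# Usage: a transposed view has stride(-1) != 1 and gets copied;
# an already-contiguous tensor is returned unchanged.
t = torch.randn(4, 8).t()
assert maybe_contiguous_lastdim(t).stride(-1) == 1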