@@ -55,7 +55,6 @@ def swiglu(x, y=None):
     )
 except:
     pass
-from paddle.utils import try_import
 
 from paddlenlp.transformers.conversion_utils import (
     StateDictNameMapping,
@@ -81,14 +80,16 @@ def swiglu(x, y=None):
 
 try:
     if get_env_device() == "npu":
-        from paddle.base import core
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
                 paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
     from paddle.nn.functional.flash_attention import flash_attention
 except:
     flash_attention = None
+from . import fusion_ops
+
+rms_norm_fused = fusion_ops.rms_norm_fused
 
 __all__ = [
     "LlamaModel",
@@ -215,67 +216,22 @@ def scaled_dot_product_attention(
     _, kv_seq_len, _, _ = value_states.shape
 
     if config.use_flash_attention and flash_attention:
+        return fusion_ops.fusion_flash_attention(
+            query_states,
+            config,
+            key_states,
+            value_states,
+            attention_mask,
+            output_attentions,
+            alibi,
+            sequence_parallel,
+            reshard_layer,
+            npu_is_casual,
+        )
+
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
 
-        version = paddle.version.full_version
-        if version != "0.0.0" and version <= "2.5.2":
-            if alibi is not None:
-                raise ValueError("Flash Attention doesn't support alibi")
-            attn_output, attn_weights = flash_attention(
-                query_states,
-                key_states,
-                value_states,
-                causal=True,
-                return_softmax=output_attentions,
-            )
-        else:
-            if alibi is not None:
-                alibi = alibi.reshape([bsz, num_heads, 1, -1])
-                attention_mask = attention_mask.cast(alibi.dtype) + alibi
-            if get_env_device() == "npu":
-                attn_output = core.eager._run_custom_op(
-                    "flash_attention_npu",
-                    query_states,
-                    key_states,
-                    value_states,
-                    None,
-                    attention_mask,
-                    0.0,
-                    attention_mask is None,
-                    True,
-                    False,
-                    npu_is_casual,
-                )[0]
-            else:
-                attn_output = F.scaled_dot_product_attention(
-                    query_states,
-                    key_states,
-                    value_states,
-                    attn_mask=attention_mask,
-                    is_causal=attention_mask is None,
-                )
-            attn_weights = None
-
-        if reshard_layer is not None:
-            # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
-            attn_output = reshard_layer(
-                attn_output,
-                split_axis=1,
-                concat_axis=2,
-            )
-            # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
-            assert (
-                config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
-            ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
-            q_len = q_len // config.sep_parallel_degree
-            num_heads = num_heads * config.sep_parallel_degree
-
-        if sequence_parallel:
-            attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
-        else:
-            attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
-        return (attn_output, attn_weights) if output_attentions else attn_output
     else:
         # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
         query_states = paddle.transpose(query_states, [0, 2, 1, 3])
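For reference, the branches deleted above are what the new `fusion_ops.fusion_flash_attention(...)` call is expected to cover. Below is a condensed sketch of such a helper, reconstructed from the removed code; the real `fusion_ops.py` is not part of this diff, so the module-level imports and default arguments are assumptions.

```python
# Sketch only -- reconstructed from the branches removed above, not the actual fusion_ops.py.
import paddle
import paddle.nn.functional as F
from paddlenlp.utils.tools import get_env_device

try:
    from paddle.nn.functional.flash_attention import flash_attention
except ImportError:
    flash_attention = None

try:
    from paddle.base import core  # only needed for the NPU custom-op path
except ImportError:
    core = None


def fusion_flash_attention(
    query_states,
    config,
    key_states,
    value_states,
    attention_mask,
    output_attentions,
    alibi=None,
    sequence_parallel=False,
    reshard_layer=None,
    npu_is_casual=False,
):
    bsz, q_len, num_heads, head_dim = query_states.shape
    version = paddle.version.full_version
    if version != "0.0.0" and version <= "2.5.2":
        # Older Paddle: plain flash_attention kernel, no alibi support.
        if alibi is not None:
            raise ValueError("Flash Attention doesn't support alibi")
        attn_output, attn_weights = flash_attention(
            query_states, key_states, value_states, causal=True, return_softmax=output_attentions
        )
    else:
        if alibi is not None:
            alibi = alibi.reshape([bsz, num_heads, 1, -1])
            attention_mask = attention_mask.cast(alibi.dtype) + alibi
        if get_env_device() == "npu":
            attn_output = core.eager._run_custom_op(
                "flash_attention_npu", query_states, key_states, value_states, None,
                attention_mask, 0.0, attention_mask is None, True, False, npu_is_casual,
            )[0]
        else:
            attn_output = F.scaled_dot_product_attention(
                query_states, key_states, value_states,
                attn_mask=attention_mask, is_causal=attention_mask is None,
            )
        attn_weights = None

    if reshard_layer is not None:
        # [bs, seqlen, num_head/sep, head_dim] -> [bs, seqlen/sep, num_head, head_dim]
        attn_output = reshard_layer(attn_output, split_axis=1, concat_axis=2)
        assert config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
        q_len = q_len // config.sep_parallel_degree
        num_heads = num_heads * config.sep_parallel_degree

    if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
    else:
        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
    return (attn_output, attn_weights) if output_attentions else attn_output
```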
@@ -385,11 +341,6 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     return expanded_mask
 
 
-def rms_norm_fused(x_in, w, eps):
-    fused_ln = try_import("fused_ln")
-    return fused_ln.fused_rms_norm(x_in, w, eps)[0]
-
-
 class LlamaRMSNorm(nn.Layer):
     def __init__(self, config):
         super().__init__()
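The deleted helper is not lost: the import hunk above now aliases `rms_norm_fused = fusion_ops.rms_norm_fused`, so the same `try_import("fused_ln")` shim presumably lives in `fusion_ops.py`, along these lines (a sketch, not the actual file):

```python
# Sketch of the relocated helper; assumes fusion_ops.py keeps the same
# try_import("fused_ln") mechanism that was removed from modeling.py.
from paddle.utils import try_import


def rms_norm_fused(x_in, w, eps):
    fused_ln = try_import("fused_ln")
    return fused_ln.fused_rms_norm(x_in, w, eps)[0]
```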
@@ -407,18 +358,7 @@ def __init__(self, config):
 
     def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
-            if get_env_device() == "npu":
-                return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
-            elif get_env_device() == "xpu":
-                try:
-                    import paddle_xpu_nn  # noqa: F821
-
-                    return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
-                except ImportError:
-                    raise NotImplementedError(
-                        f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
-                    )
-            return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
+            return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon)
 
         if paddle.in_dynamic_mode():
             with paddle.amp.auto_cast(False):
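Based on the device branches removed from `LlamaRMSNorm.forward`, `fusion_ops.fusion_rms_norm` presumably keeps the same NPU/XPU/default dispatch. A minimal sketch under that assumption (import paths are guesses):

```python
# Sketch only -- mirrors the NPU/XPU/default branches deleted above.
from paddle.base import core
from paddlenlp.utils.tools import get_env_device


def fusion_rms_norm(hidden_states, weight, variance_epsilon):
    if get_env_device() == "npu":
        return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
    if get_env_device() == "xpu":
        try:
            import paddle_xpu_nn  # noqa: F821

            return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0]
        except ImportError:
            raise NotImplementedError(
                f"fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature."
            )
    # Default path: the fused_ln-based helper sketched earlier.
    return rms_norm_fused(hidden_states, weight, variance_epsilon)
```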
@@ -974,45 +914,16 @@ def forward(
                 batch_size, seq_length, _, _ = query_states.shape
                 position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
             if self.use_fused_rope:
-                assert past_key_value is None, "fuse rotary not support cache kv for now"
-                batch_size, seq_length, num_heads, head_dim = query_states.shape
-                _, kv_seq_len, num_key_value_heads, _ = key_states.shape
-                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-                if get_env_device() == "npu":
-                    query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
-                    key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
-                else:
-                    # paddle version > 2.6 or develop support q and k/v with different num_heads
-                    paddle_version = float(paddle.__version__[:3])
-                    if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
-                        query_states, _, _ = fused_rotary_position_embedding(
-                            query_states,
-                            None,
-                            None,
-                            sin=sin,
-                            cos=cos,
-                            position_ids=position_ids,
-                            use_neox_rotary_style=False,
-                        )
-                        key_states, _, _ = fused_rotary_position_embedding(
-                            key_states,
-                            None,
-                            None,
-                            sin=sin,
-                            cos=cos,
-                            position_ids=position_ids,
-                            use_neox_rotary_style=False,
-                        )
-                    else:
-                        query_states, key_states, _ = fused_rotary_position_embedding(
-                            query_states,
-                            key_states,
-                            v=None,
-                            sin=sin,
-                            cos=cos,
-                            position_ids=position_ids,
-                            use_neox_rotary_style=False,
-                        )
+                query_states, key_states = fusion_ops.fusion_rope(
+                    query_states,
+                    key_states,
+                    value_states,
+                    hidden_states,
+                    position_ids,
+                    past_key_value,
+                    self.rotary_emb,
+                )
+
             else:
                 if self.config.use_long_sequence_strategies:
                     cos, sin = self.rotary_emb(seq_len=kv_seq_len)
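Likewise, the fused-RoPE branches removed here line up with the new `fusion_ops.fusion_rope(...)` call. A sketch of what that helper plausibly does, reconstructed from the deleted code (`hidden_states` appears only for signature compatibility; import paths are assumptions):

```python
# Sketch only -- reconstructed from the deleted use_fused_rope branch.
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding
from paddlenlp.utils.tools import get_env_device

try:
    from paddle.base import core
except ImportError:
    core = None


def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
    # hidden_states is accepted for signature compatibility; this sketch does not use it.
    assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)

    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
        return query_states, key_states

    # Paddle > 2.6 (and develop) supports q and k/v with different num_heads in one call;
    # older versions need separate calls when num_heads != num_key_value_heads.
    paddle_version = float(paddle.__version__[:3])
    if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
        query_states, _, _ = fused_rotary_position_embedding(
            query_states, None, None, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False
        )
        key_states, _, _ = fused_rotary_position_embedding(
            key_states, None, None, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False
        )
    else:
        query_states, key_states, _ = fused_rotary_position_embedding(
            query_states, key_states, v=None, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False
        )
    return query_states, key_states
```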