Commit 05acad5

[NPU]Custom fusion operator unification (#8431)
* update
* add llama-npu-opt-script
* Update dev_opt_lora.sh
* Update dev_opt_ppt.sh
* Update dev_opt_lora.sh
* Update dev_opt_ppt.sh
* Update dev_opt_sft.sh
* Rename dev_opt_lora.sh to llama_npu_opt_lora.sh
* Update dev_opt_ppt.sh
* Rename dev_opt_ppt.sh to llama_npu_opt_ppt.sh
* Update llama_npu_opt_lora.sh
* Update and rename dev_opt_sft.sh to llama_npu_opt_sft.sh
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* add funsion ops
* update
* Update fusion_ops.py
* update
1 parent 53ad2da commit 05acad5
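
For orientation, the change consolidates the device-specific fused kernels (rotary position embedding, RMSNorm, flash attention) that were previously inlined in llama/modeling.py into a single fusion_ops module, and modeling.py now delegates to it. The sketch below shows the resulting call pattern. It is illustrative and not part of the commit: the module path is inferred from the "from . import fusion_ops" line in the diff, the tensor shapes and dtype are assumptions, and the CPU/GPU branch still needs the optional fused_ln extension installed.

import paddle

# Module added by this commit; path inferred from "from . import fusion_ops"
# inside paddlenlp/transformers/llama/modeling.py.
from paddlenlp.transformers.llama import fusion_ops

# Illustrative tensors; shapes and dtype are assumptions, not taken from the diff.
hidden_states = paddle.randn([2, 16, 64], dtype="float32")
weight = paddle.ones([64], dtype="float32")

# Single entry point with device-dependent dispatch inside:
#   NPU  -> core.eager._run_custom_op("rms_norm_npu", ...)
#   XPU  -> paddle_xpu_nn.xpu_rms_norm(...)
#   else -> fused_ln.fused_rms_norm(...) (requires the fused_ln custom op)
out = fusion_ops.fusion_rms_norm(hidden_states, weight, 1e-6)
print(out.shape)  # [2, 16, 64]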

File tree

paddlenlp/transformers/llama/fusion_ops.py
paddlenlp/transformers/llama/modeling.py

2 files changed: +216 -116 lines changed
paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y


from paddle.utils import try_import

from paddlenlp.utils.tools import get_env_device

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None
try:
    if get_env_device() == "npu":
        from paddle.base import core

        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None


def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
    assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
    else:
        # paddle version > 2.6 or develop support q and k/v with different num_heads
        paddle_version = float(paddle.__version__[:3])
        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
            query_states, _, _ = fused_rotary_position_embedding(
                query_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
            key_states, _, _ = fused_rotary_position_embedding(
                key_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
    return query_states, key_states


def rms_norm_fused(x_in, w, eps):
    fused_ln = try_import("fused_ln")
    return fused_ln.fused_rms_norm(x_in, w, eps)[0]


def fusion_rms_norm(hidden_states, weight, variance_epsilon):
    if get_env_device() == "npu":
        return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "xpu":
        try:
            import paddle_xpu_nn  # noqa: F821

            return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0]
        except ImportError:
            raise NotImplementedError(
                f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
            )
    return rms_norm_fused(hidden_states, weight, variance_epsilon)


def fusion_flash_attention(
    query_states,
    config,
    key_states,
    value_states,
    attention_mask,
    output_attentions,
    alibi=None,
    sequence_parallel=False,
    reshard_layer=None,
    npu_is_casual=False,
):
    bsz, q_len, num_heads, head_dim = query_states.shape
    _, kv_seq_len, _, _ = value_states.shape
    version = paddle.version.full_version
    if version != "0.0.0" and version <= "2.5.2":
        if alibi is not None:
            raise ValueError("Flash Attention doesn't support alibi")
        attn_output, attn_weights = flash_attention(
            query_states,
            key_states,
            value_states,
            causal=True,
            return_softmax=output_attentions,
        )
    else:
        if alibi is not None:
            alibi = alibi.reshape([bsz, num_heads, 1, -1])
            attention_mask = attention_mask.cast(alibi.dtype) + alibi
        if get_env_device() == "npu":
            attn_output = core.eager._run_custom_op(
                "flash_attention_npu",
                query_states,
                key_states,
                value_states,
                None,
                attention_mask,
                0.0,
                attention_mask is None,
                True,
                False,
                npu_is_casual,
            )[0]
        else:
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                is_causal=attention_mask is None,
            )
        attn_weights = None

    if reshard_layer is not None:
        # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
        attn_output = reshard_layer(
            attn_output,
            split_axis=1,
            concat_axis=2,
        )
        # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
        assert (
            config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
        ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
        q_len = q_len // config.sep_parallel_degree
        num_heads = num_heads * config.sep_parallel_degree

    if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
    else:
        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
    return (attn_output, attn_weights) if output_attentions else attn_output
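
Not part of the committed file: as a sanity check for readers, the sketch below is a plain-Paddle reference of the attention math that both the flash_attention_npu custom op and the F.scaled_dot_product_attention fallback above are meant to compute, softmax(QK^T / sqrt(head_dim)) V with either an additive mask or causal masking. The function name and shape conventions are assumptions made here for illustration.

import paddle
import paddle.nn.functional as F


def sdpa_reference(q, k, v, attn_mask=None, causal=False):
    # q, k, v: [batch, seq_len, num_heads, head_dim], the layout used throughout fusion_ops.
    q, k, v = [paddle.transpose(t, [0, 2, 1, 3]) for t in (q, k, v)]
    scores = paddle.matmul(q, k, transpose_y=True) / (q.shape[-1] ** 0.5)
    if causal:
        # Mask out positions after the query position (assumes q and k have equal length).
        seq_len = scores.shape[-1]
        scores = scores + paddle.triu(paddle.full([seq_len, seq_len], float("-inf")), diagonal=1)
    if attn_mask is not None:
        # Additive mask, mirroring the attn_mask / is_causal split in the fallback branch.
        scores = scores + attn_mask.cast(scores.dtype)
    out = paddle.matmul(F.softmax(scores, axis=-1), v)
    # Back to [batch, seq_len, num_heads, head_dim].
    return paddle.transpose(out, [0, 2, 1, 3])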

paddlenlp/transformers/llama/modeling.py

Lines changed: 27 additions & 116 deletions
@@ -55,7 +55,6 @@ def swiglu(x, y=None):
     )
 except:
     pass
-from paddle.utils import try_import
 
 from paddlenlp.transformers.conversion_utils import (
     StateDictNameMapping,
@@ -81,14 +80,16 @@ def swiglu(x, y=None):
 
 try:
     if get_env_device() == "npu":
-        from paddle.base import core
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
                 paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
     from paddle.nn.functional.flash_attention import flash_attention
 except:
     flash_attention = None
+from . import fusion_ops
+
+rms_norm_fused = fusion_ops.rms_norm_fused
 
 __all__ = [
     "LlamaModel",
@@ -215,67 +216,22 @@ def scaled_dot_product_attention(
     _, kv_seq_len, _, _ = value_states.shape
 
     if config.use_flash_attention and flash_attention:
+        return fusion_ops.fusion_flash_attention(
+            query_states,
+            config,
+            key_states,
+            value_states,
+            attention_mask,
+            output_attentions,
+            alibi,
+            sequence_parallel,
+            reshard_layer,
+            npu_is_casual,
+        )
+
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
 
-        version = paddle.version.full_version
-        if version != "0.0.0" and version <= "2.5.2":
-            if alibi is not None:
-                raise ValueError("Flash Attention doesn't support alibi")
-            attn_output, attn_weights = flash_attention(
-                query_states,
-                key_states,
-                value_states,
-                causal=True,
-                return_softmax=output_attentions,
-            )
-        else:
-            if alibi is not None:
-                alibi = alibi.reshape([bsz, num_heads, 1, -1])
-                attention_mask = attention_mask.cast(alibi.dtype) + alibi
-            if get_env_device() == "npu":
-                attn_output = core.eager._run_custom_op(
-                    "flash_attention_npu",
-                    query_states,
-                    key_states,
-                    value_states,
-                    None,
-                    attention_mask,
-                    0.0,
-                    attention_mask is None,
-                    True,
-                    False,
-                    npu_is_casual,
-                )[0]
-            else:
-                attn_output = F.scaled_dot_product_attention(
-                    query_states,
-                    key_states,
-                    value_states,
-                    attn_mask=attention_mask,
-                    is_causal=attention_mask is None,
-                )
-            attn_weights = None
-
-        if reshard_layer is not None:
-            # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
-            attn_output = reshard_layer(
-                attn_output,
-                split_axis=1,
-                concat_axis=2,
-            )
-            # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
-            assert (
-                config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
-            ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
-            q_len = q_len // config.sep_parallel_degree
-            num_heads = num_heads * config.sep_parallel_degree
-
-        if sequence_parallel:
-            attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
-        else:
-            attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
-        return (attn_output, attn_weights) if output_attentions else attn_output
     else:
         # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
         query_states = paddle.transpose(query_states, [0, 2, 1, 3])
@@ -385,11 +341,6 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     return expanded_mask
 
 
-def rms_norm_fused(x_in, w, eps):
-    fused_ln = try_import("fused_ln")
-    return fused_ln.fused_rms_norm(x_in, w, eps)[0]
-
-
 class LlamaRMSNorm(nn.Layer):
     def __init__(self, config):
         super().__init__()
@@ -407,18 +358,7 @@ def __init__(self, config):
 
     def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
-            if get_env_device() == "npu":
-                return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
-            elif get_env_device() == "xpu":
-                try:
-                    import paddle_xpu_nn  # noqa: F821
-
-                    return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
-                except ImportError:
-                    raise NotImplementedError(
-                        f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
-                    )
-            return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
+            return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon)
 
         if paddle.in_dynamic_mode():
             with paddle.amp.auto_cast(False):
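
For reference, and not part of the diff: fusion_ops.fusion_rms_norm and the device kernels it dispatches to compute ordinary RMSNorm, x / sqrt(mean(x^2) + eps) scaled by a learned weight with no mean-centering. A minimal sketch follows; the float32 upcast is an assumption about how reduced-precision inputs are handled on the eager path.

import paddle


def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # RMSNorm: normalize by the root mean square over the last axis, then scale by weight.
    variance = hidden_states.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    normed = paddle.rsqrt(variance + variance_epsilon) * hidden_states.astype("float32")
    return (normed * weight.astype("float32")).astype(hidden_states.dtype)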
@@ -974,45 +914,16 @@ def forward(
             batch_size, seq_length, _, _ = query_states.shape
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
         if self.use_fused_rope:
-            assert past_key_value is None, "fuse rotary not support cache kv for now"
-            batch_size, seq_length, num_heads, head_dim = query_states.shape
-            _, kv_seq_len, num_key_value_heads, _ = key_states.shape
-            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-            if get_env_device() == "npu":
-                query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
-                key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
-            else:
-                # paddle version > 2.6 or develop support q and k/v with different num_heads
-                paddle_version = float(paddle.__version__[:3])
-                if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
-                    query_states, _, _ = fused_rotary_position_embedding(
-                        query_states,
-                        None,
-                        None,
-                        sin=sin,
-                        cos=cos,
-                        position_ids=position_ids,
-                        use_neox_rotary_style=False,
-                    )
-                    key_states, _, _ = fused_rotary_position_embedding(
-                        key_states,
-                        None,
-                        None,
-                        sin=sin,
-                        cos=cos,
-                        position_ids=position_ids,
-                        use_neox_rotary_style=False,
-                    )
-                else:
-                    query_states, key_states, _ = fused_rotary_position_embedding(
-                        query_states,
-                        key_states,
-                        v=None,
-                        sin=sin,
-                        cos=cos,
-                        position_ids=position_ids,
-                        use_neox_rotary_style=False,
-                    )
+            query_states, key_states = fusion_ops.fusion_rope(
+                query_states,
+                key_states,
+                value_states,
+                hidden_states,
+                position_ids,
+                past_key_value,
+                self.rotary_emb,
+            )
+
         else:
             if self.config.use_long_sequence_strategies:
                 cos, sin = self.rotary_emb(seq_len=kv_seq_len)