
xpu devices support llama-7b basic mode inference (turn on BlockAtten… #8588


Merged · 3 commits · Jun 13, 2024
12 changes: 12 additions & 0 deletions llm/docs/inference.md
@@ -83,7 +83,10 @@ PaddleNLP provides high-performance custom operators for the Transformer family to improve

```shell
git clone https://github.com/PaddlePaddle/PaddleNLP
# Install the custom operators on GPU devices
cd ./paddlenlp/csrc && python setup_cuda.py install
# Install the custom operators on XPU devices
cd ./paddlenlp/csrc/xpu/src && sh cmake_build.sh
```

### 2.3 High-performance inference with BlockAttention disabled
@@ -163,6 +166,9 @@ python predictor.py --model_name_or_path ./inference --inference_model --quant_
# Reference command for dynamic-graph model inference
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn

# Reference command for dynamic-graph model inference on XPU devices
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn --device xpu

# Reference command for Weight Only Int8 dynamic-graph inference
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --quant_type weight_only_int8 --block_attn

@@ -179,6 +185,9 @@ python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_
# Reference command for dynamic-to-static export
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn

# Reference command for dynamic-to-static export on XPU devices
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn --device xpu

# Reference command for Weight Only Int8 dynamic-to-static export
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --quant_type weight_only_int8 --block_attn

@@ -194,6 +203,9 @@ python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --infere
# Reference command for static-graph inference
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn

# Reference command for static-graph inference on XPU devices
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn --device xpu

# Reference command for Weight Only Int8 static-graph inference
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --quant_type weight_only_int8 --block_attn

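As a quick sanity check after running either install command above, the custom-operator package should be importable from Python. This is a minimal sketch, not part of the documentation change; the package name `paddlenlp_ops` and the two build scripts are the ones used elsewhere in this PR.

```python
# Run in the same environment that will execute predictor.py / export_model.py.
try:
    import paddlenlp_ops  # noqa: F401
    print("paddlenlp_ops is available")
except ImportError:
    print("paddlenlp_ops is missing; rebuild with setup_cuda.py (GPU) or cmake_build.sh (XPU)")
```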
27 changes: 25 additions & 2 deletions llm/predictor.py
@@ -650,6 +650,11 @@ def _create_predictor(self, predictor_args: PredictorArgument):
if predictor_args.device in paddle.device.get_all_custom_device_type():
device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
config.enable_custom_device(predictor_args.device, device_id)
elif predictor_args.device == "xpu":
raise ValueError(
"you should export xpu static model with --block_attn flag and use predictor with --block_attn too"
"https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/inference.md"
)
else:
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
config.enable_use_gpu(100, device_id)
@@ -920,7 +925,9 @@ def _preprocess(self, source):
source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]

for i, text in enumerate(source):
add_special_tokens = self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer))
add_special_tokens = self.tokenizer.chat_template is None or isinstance(
self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)
)
add_special_tokens = add_special_tokens if not self.benchmark else False
tokens = self.tokenizer(
text,
@@ -1076,6 +1083,15 @@ def _create_predictor(self, predictor_args: PredictorArgument):
if predictor_args.device in paddle.device.get_all_custom_device_type():
device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
config.enable_custom_device(predictor_args.device, device_id)
elif predictor_args.device == "xpu":
config.enable_xpu()
device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
config.set_xpu_device_id(device_id)
xpu_config = paddle.inference.XpuConfig()
xpu_config.device_id = device_id
xpu_config.l3_size = 63 * 1024 * 1024
xpu_config.l3_autotune_size = 63 * 1024 * 1024
config.set_xpu_config(xpu_config)
else:
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
config.enable_use_gpu(100, device_id)
@@ -1331,6 +1347,11 @@ def create_predictor(
tensor_parallel_rank=tensor_parallel_rank,
)
else:
if predictor_args.device == "xpu":
raise ValueError(
"you should run xpu dynamic model with --block_attn flag"
"https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/inference.md"
)
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel as LlamaInferenceModel,
)
@@ -1588,7 +1609,9 @@ def predict():

def benchmark(predictor, predictor_args, model_args):
# Just construct a simple benchmark input. We pad input to the src_length.
benchmark_texts = [predictor.tokenizer.pad_token * predictor_args.src_length for _ in range(predictor_args.batch_size)]
benchmark_texts = [
predictor.tokenizer.pad_token * predictor_args.src_length for _ in range(predictor_args.batch_size)
]

batch_benchmark_texts = batchfy_text(benchmark_texts, predictor_args.batch_size)
print("***********Start Benchmark**********")
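For readers who want the XPU static-graph setup in isolation, the new branch in `_create_predictor` reduces to the following configuration sequence. This is a standalone sketch assembled from the diff; the helper name and model file paths are illustrative and not part of the PR.

```python
import os

import paddle.inference as paddle_infer


def build_xpu_inference_config(model_dir: str) -> paddle_infer.Config:
    # Placeholder paths for the exported static model.
    config = paddle_infer.Config(
        os.path.join(model_dir, "model.pdmodel"),
        os.path.join(model_dir, "model.pdiparams"),
    )

    config.enable_xpu()
    device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
    config.set_xpu_device_id(device_id)

    # Same L3 cache settings as the diff: reserve 63 MiB of on-chip L3
    # and let autotune use all of it.
    xpu_config = paddle_infer.XpuConfig()
    xpu_config.device_id = device_id
    xpu_config.l3_size = 63 * 1024 * 1024
    xpu_config.l3_autotune_size = 63 * 1024 * 1024
    config.set_xpu_config(xpu_config)
    return config
```

The L3 settings matter for throughput on XPU: the on-chip L3 cache is used as a scratch buffer for the fused kernels, and the autotune size lets the runtime decide how much of it each operator may claim.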
5 changes: 4 additions & 1 deletion paddlenlp/experimental/transformers/bloom/modeling.py
@@ -19,7 +19,6 @@
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.nn.quant import weight_quantize
from paddlenlp_ops import get_padding_offset, get_padding_offset_v2

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedBlockMultiTransformer,
@@ -219,6 +218,8 @@
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset

ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
@@ -592,6 +593,8 @@
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset_v2

ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
3 changes: 2 additions & 1 deletion paddlenlp/experimental/transformers/chatglm/modeling.py
@@ -18,7 +18,6 @@
from paddle import nn
from paddle.distributed import fleet
from paddle.nn.quant import weight_quantize
from paddlenlp_ops import get_padding_offset

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedMultiTransformerConfig,
@@ -273,6 +272,8 @@
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset

ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
3 changes: 2 additions & 1 deletion paddlenlp/experimental/transformers/chatglm_v2/modeling.py
@@ -19,7 +19,6 @@
import paddle.distributed.fleet as fleet
import paddle.nn as nn
from paddle.nn.quant import weight_quantize
from paddlenlp_ops import get_padding_offset

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedMultiTransformerBase,
@@ -202,6 +201,8 @@
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset

ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
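The three modeling files above apply the same change: the module-level `from paddlenlp_ops import ...` is removed and the import is deferred into `remove_padding`, so the model modules can be imported even when the native ops are not built for the current device. A minimal sketch of that lazy-import pattern follows; the class name is illustrative, and the body mirrors the diff.

```python
import paddle


class ExampleInferenceModel:  # illustrative stand-in for the modified model classes
    def remove_padding(self, input_ids, seq_lens_this_time):
        cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
        token_num = paddle.sum(seq_lens_this_time)

        # Deferred import: paddlenlp_ops is only required when this method actually runs.
        from paddlenlp_ops import get_padding_offset

        ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
            input_ids, cum_offsets_now, token_num, seq_lens_this_time
        )
        return ids_remove_padding, padding_offset, cum_offsets
```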
133 changes: 87 additions & 46 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -15,7 +15,7 @@

import paddle
import paddle.distributed as dist
from paddle.framework import LayerHelper, in_dynamic_mode
from paddle.framework import LayerHelper, core, in_dynamic_mode

from paddle.incubate.nn.functional import (
fused_layer_norm,
fused_rms_norm,
@@ -29,23 +29,24 @@
from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
from paddlenlp.utils.log import logger

if is_paddlenlp_ops_available():
if not is_paddlenlp_ops_available():
logger.warning(

"The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
"you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
)

from paddlenlp_ops import rebuild_padding_v2

if core.is_compiled_with_cuda():

from paddlenlp_ops import (
dequant_int8,
encode_rotary_qk,
qkv_transpose_split,
quant_int8,
rebuild_padding,
rebuild_padding_v2,
transpose_remove_padding,
write_cache_kv,
)
else:
logger.warning(
"The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
"you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
)


__all__ = [
"FusedMultiTransformerConfig",
@@ -1348,6 +1349,9 @@
class FusedBlockMultiTransformer(FusedMultiTransformerBase):
def __init__(self, config: FusedMultiTransformerConfig):
super().__init__(config)
if not core.is_compiled_with_cuda():
self.cache_k_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")
self.cache_v_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")

def compute_attn(
self,
@@ -1375,43 +1379,80 @@
v_quant_scales = self.cache_v_scales
k_dequant_scales = self.cache_k_out_scales
v_dequant_scales = self.cache_v_out_scales

fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
qkv_out,
caches[2 * i],
caches[2 * i + 1],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
kwargs.get("padding_offsets", None),
kwargs.get("cum_offsets", None),
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("block_tables", None),
pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache
pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache
k_quant_scales[i] if k_quant_scales is not None else None,
v_quant_scales[i] if v_quant_scales is not None else None,
k_dequant_scales[i] if k_dequant_scales is not None else None,
v_dequant_scales[i] if v_dequant_scales is not None else None,
None, # qkv_out_scales
None, # qkv_bias
None, # out_shifts
None, # out_smooths
kwargs.get("max_enc_len_this_time", None),
kwargs.get("max_dec_len_this_time", None),
rotary_embs,
attn_mask,
kwargs.get("tgt_mask", None),
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
)[0]

if not core.is_compiled_with_cuda():
fmha_out = paddle.incubate.nn.functional.block_multihead_attention_xpu(

qkv_out,
caches[2 * i],
caches[2 * i + 1],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
kwargs.get("padding_offsets", None),
kwargs.get("cum_offsets", None),
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("block_tables", None),
self.cache_k_per_batch_maxs,
self.cache_v_per_batch_maxs,
pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache
pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache
k_quant_scales[i] if k_quant_scales is not None else None,
v_quant_scales[i] if v_quant_scales is not None else None,
k_dequant_scales[i] if k_dequant_scales is not None else None,
v_dequant_scales[i] if v_dequant_scales is not None else None,
None, # qkv_out_scales
None, # qkv_bias
None, # out_shifts
None, # out_smooths
kwargs.get("max_enc_len_this_time", None),
kwargs.get("max_dec_len_this_time", None),
rotary_embs,
attn_mask,
kwargs.get("tgt_mask", None),
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
)[0]
else:
fmha_out = paddle.incubate.nn.functional.block_multihead_attention(

qkv_out,
caches[2 * i],
caches[2 * i + 1],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
kwargs.get("padding_offsets", None),
kwargs.get("cum_offsets", None),
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("block_tables", None),
pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache
pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache
k_quant_scales[i] if k_quant_scales is not None else None,
v_quant_scales[i] if v_quant_scales is not None else None,
k_dequant_scales[i] if k_dequant_scales is not None else None,
v_dequant_scales[i] if v_dequant_scales is not None else None,
None, # qkv_out_scales
None, # qkv_bias
None, # out_shifts
None, # out_smooths
kwargs.get("max_enc_len_this_time", None),
kwargs.get("max_dec_len_this_time", None),
rotary_embs,
attn_mask,
kwargs.get("tgt_mask", None),
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
)[0]
out_linear_out = self.compute_out_linear(fmha_out, i)

return out_linear_out
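The attention dispatch added here keys off the Paddle build rather than a runtime device string: a non-CUDA build (the XPU case) routes to `block_multihead_attention_xpu` and passes the two `cache_*_per_batch_maxs` buffers allocated in `__init__`, while a CUDA build keeps the original `block_multihead_attention` call. A condensed sketch of that branch, with the long argument lists left out; the real call sites in the diff pass roughly thirty positional arguments each.

```python
import paddle
from paddle.framework import core


class BlockAttentionDispatchSketch:
    """Condensed illustration of the branch added in FusedBlockMultiTransformer (not the real class)."""

    def __init__(self):
        if not core.is_compiled_with_cuda():
            # XPU path keeps per-batch cache maxima; shape [10, 6] is taken from the diff.
            self.cache_k_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")
            self.cache_v_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")

    def attention_op(self):
        # Select the fused kernel based on the build, not on predictor_args.device.
        if not core.is_compiled_with_cuda():
            return paddle.incubate.nn.functional.block_multihead_attention_xpu
        return paddle.incubate.nn.functional.block_multihead_attention
```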