
Commit bbf945b

Authored by Galaxy1458, dynamicheart, DrownFish19, ziangqin-baidu, and SylarTiaNII
Release/2.8 (#8437)
* [XPU] llama add xpu support (#8282)
  * [XPU] llama add xpu support
  * fix
  * use try import
  * fix
  * refine
  * refine
  * refine
  * refine
* update (#8399)
* [LLM] Support fuse attention q, k, v weights (#8202)
  1. add use-interface & fuse action
  1.1 modify 1., code order
  2. switch to name_mapping
  3. solve tp branch
  3.2 follow hui, handle qkv separately
  3.3 handle pdparams
  3.4 from torch
  3.5 abandon low_cpu_mem_usage
  3.6 solve shard branch
  * 3.6.1 solve shard branch after rebase develop
  * code clean
  * remove debug comment
  * Redefine fuse and split functions
  * Redefine fuse and split functions
  * comment and fix
  * update method
  * update QKV fuse and split
  * support fuse weights in multi-files
  * add precision compare
  * simplify function call
  * support use_fast_ffn
  * clean modeling and configuration
  * add test for gpt and opt
  * fix tp_actions get
  * add fast_ffn test
  * add Qwen2Moe
  * Revert "add Qwen2Moe" (reverts commit 113b883)
  * add test for split
  * update doc
  * update filter_dict_keys
  Co-authored-by: Zii <[email protected]>
* [LLM] Fix fuse or split with same key (#8378)
  * fix fuse or split with same key
  * fix
  * fix eps
  * update format
* [LLM] add decay steps option for finetuning (#8251)
* [LLM] add memory stats to logger of trainer (#8269)
* [Distributed] fix lora (#8325)
* [LLM] fix lora target modules on llama (#8372)
* [Distributed] metric calculation supports tp logits (#8370)
  * Update model_utils.py
  * Update model_utils.py
  * Update model_utils.py

Co-authored-by: Jianbang Yang <[email protected]>
Co-authored-by: DrownFish19 <[email protected]>
Co-authored-by: Zii <[email protected]>
Co-authored-by: Tian <[email protected]>
1 parent: 8879f79 · commit: bbf945b

File tree

16 files changed (+886, −47 lines)


llm/finetune_generation.py

Lines changed: 3 additions & 3 deletions
@@ -140,7 +140,7 @@ def main():
     if not training_args.autotuner_benchmark:
         model = AutoModelForCausalLMPipe.from_pretrained(
             model_args.model_name_or_path,
-            tensor_parallel_output=False,
+            tensor_parallel_output=training_args.tensor_parallel_output,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             tensor_parallel_rank=training_args.tensor_parallel_rank,
             use_flash_attention=model_args.use_flash_attention,
@@ -152,7 +152,7 @@ def main():
         # NOTE(gongenlei): new add autotuner_benchmark
         model_config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
-            tensor_parallel_output=False,
+            tensor_parallel_output=training_args.tensor_parallel_output,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             tensor_parallel_rank=training_args.tensor_parallel_rank,
             dtype=dtype,
@@ -163,7 +163,7 @@ def main():
     else:
         model_config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
-            tensor_parallel_output=False,
+            tensor_parallel_output=training_args.tensor_parallel_output,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             tensor_parallel_rank=training_args.tensor_parallel_rank,
             dtype=dtype,

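Note: the change above stops hard-coding tensor_parallel_output=False and instead honors the new training argument. When the flag is on, the tensor-parallel LM head skips its final all-gather, so each rank returns only its vocabulary shard of the logits. The sketch below only illustrates the resulting shapes; the sizes are invented, not from this PR.

# Illustrative only: logits shapes with tensor_parallel_output off vs. on.
import paddle

batch, seq, vocab_size, tp_degree = 2, 8, 32000, 4

# tensor_parallel_output=False: shards are gathered, every rank sees full logits.
full_logits = paddle.randn([batch, seq, vocab_size])                 # [2, 8, 32000]

# tensor_parallel_output=True: each rank keeps only its vocab shard.
shard_logits = paddle.randn([batch, seq, vocab_size // tp_degree])   # [2, 8, 8000]

# Anything that reduces over the vocab axis (argmax, top-k) now needs a
# gather first, which is exactly what the prediction_step change below adds.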
llm/run_pretrain.py

Lines changed: 11 additions & 0 deletions
@@ -46,6 +46,7 @@
 )
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.tools import get_env_device


 def add_start_docstrings(*docstr):
@@ -483,6 +484,16 @@ def main():
            config.num_attention_heads % config.sep_parallel_degree == 0
        ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        try:
+            from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+            LinearConfig.enable_accumulate_steps_opt()
+            LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+        except ImportError:
+            # Fine to run without the accumulate_steps optimization
+            pass
+
     print("Final pre-training config:", config)

     # Set the dtype for loading model

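The guarded import above is the whole trick: paddle_xpu ships only with Kunlun XPU toolchains, so on any other device the ImportError branch silently skips the optimization. A minimal, self-contained sketch of the same pattern follows; the helper name is ours, not the PR's.

# Hypothetical helper distilling the guarded-import pattern above.
def maybe_enable_xpu_accumulate_opt(accumulation_steps: int) -> bool:
    """Return True only if the paddle_xpu optimization was switched on."""
    if accumulation_steps <= 1:
        return False  # nothing to accumulate, nothing to optimize
    try:
        from paddle_xpu.layers.nn.linear import LinearConfig
    except ImportError:
        return False  # not an XPU toolchain; train without the optimization
    LinearConfig.enable_accumulate_steps_opt()
    LinearConfig.set_accumulate_steps(accumulation_steps)
    return True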
llm/utils.py

Lines changed: 9 additions & 0 deletions
@@ -125,9 +125,11 @@ def get_lora_target_modules(model):
             ".*v_proj.*",
             ".*k_proj.*",
             ".*o_proj.*",
+            ".*qkv_proj.*",
             ".*gate_proj.*",
             ".*down_proj.*",
             ".*up_proj.*",
+            ".*gate_up_fused_proj.*",
         ]
     elif model.base_model_prefix == "opt":
         target_modules = [
@@ -209,6 +211,13 @@ def prediction_step(
             # keepdim in order to maintain the same shape as logits
             if isinstance(logits, (list, tuple)):
                 logits = logits[0]
+            # all-gather the logits when tensor_parallel_output is enabled
+            if self.args.tensor_parallel_degree > 1 and self.args.tensor_parallel_output:
+                hcg = fleet.get_hybrid_communicate_group()
+                model_parallel_group = hcg.get_model_parallel_group()
+                gathered_logits = []
+                dist.all_gather(gathered_logits, logits, group=model_parallel_group)
+                logits = paddle.concat(gathered_logits, axis=-1)
             return (loss, logits.argmax(axis=-1, keepdim=True), labels)

         loss = None

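Because each model-parallel rank holds a vocab_size / tp_degree slice of the logits when tensor_parallel_output is on, prediction_step must rebuild the full distribution before taking argmax; concatenating on axis=-1 stitches the vocab shards back together. Below is a standalone sketch of that gather, using the default communication group instead of the fleet hybrid group and invented shapes, so it is a demo rather than the PR's code.

# Run with: python -m paddle.distributed.launch --devices 0,1 gather_demo.py
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
vocab_shard = 8                                   # vocab_size // world_size
logits = paddle.randn([2, 4, vocab_shard])        # [batch, seq, vocab_shard]

gathered = []
dist.all_gather(gathered, logits)                 # one tensor per rank
full_logits = paddle.concat(gathered, axis=-1)    # [2, 4, vocab_shard * world_size]
pred_ids = full_logits.argmax(axis=-1, keepdim=True)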
paddlenlp/peft/lora/lora_layers.py

Lines changed: 1 addition & 1 deletion
@@ -539,7 +539,7 @@ def forward(self, input: paddle.Tensor):
             result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
         else:
             res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group)
-            result_mp = res_mp + self.bias
+            result_mp = (res_mp + self.bias) if self.bias is not None else res_mp

         if not self.merged:
             input_a = self.lora_dropout(input) @ self.lora_A

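The one-line fix exists because MC2ColumnParallelCoreLinear.apply returns the matmul result without bias, and adding None to a Paddle tensor raises a TypeError; guarding on self.bias is not None covers layers built without a bias. In isolation, the pattern is just:

import paddle

def add_optional_bias(x: paddle.Tensor, bias) -> paddle.Tensor:
    # bias is None for layers constructed with bias_attr=False
    return (x + bias) if bias is not None else x

print(add_optional_bias(paddle.ones([2, 3]), None).shape)  # no crash: [2, 3]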
paddlenlp/trainer/trainer.py

Lines changed: 16 additions & 2 deletions
@@ -39,6 +39,8 @@
 import paddle.distributed as dist
 import paddle.nn as nn
 from packaging import version
+from paddle import framework
+from paddle.base import core
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
     HybridParallelOptimizer,
@@ -1257,6 +1259,20 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)

+            divisor = 2**30
+            # TODO(@gexiao): replace these codes with unified APIs in Paddle
+            current_device = framework._current_expected_place_()
+            if str(current_device) != "Place(cpu)":
+                device_id = current_device.get_device_id()
+                current_memory_allocated = core.device_memory_stat_current_value("Allocated", device_id)
+                current_memory_reserved = core.device_memory_stat_current_value("Reserved", device_id)
+                max_memory_allocated = core.device_memory_stat_peak_value("Allocated", device_id)
+                max_memory_reserved = core.device_memory_stat_peak_value("Reserved", device_id)
+                logs["current_memory_allocated"] = current_memory_allocated / divisor
+                logs["current_memory_reserved"] = current_memory_reserved / divisor
+                logs["max_memory_allocated"] = max_memory_allocated / divisor
+                logs["max_memory_reserved"] = max_memory_reserved / divisor
+
             total_train_batch_size = (
                 self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
             )
@@ -1614,8 +1630,6 @@ def _load_rng_state(self, checkpoint):
         random.setstate(checkpoint_rng_state["python"])
         np.random.set_state(checkpoint_rng_state["numpy"])

-        core = paddle.framework.core
-
         core.default_cpu_generator().set_state(checkpoint_rng_state["cpu"])
         if core.is_compiled_with_cuda():
             if not len(checkpoint_rng_state["cuda"]) == core.get_cuda_device_count():

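The logging block converts byte counts to GiB (divisor = 2**30) and reaches for internal core APIs so the stats work on any non-CPU place, not just CUDA. For a quick check outside the Trainer, the public CUDA-only equivalents look like the sketch below (CUDA builds only; a sketch, not the PR's code).

import paddle

if paddle.device.is_compiled_with_cuda():
    gib = 2**30  # bytes -> GiB, same divisor as the Trainer change
    print("current allocated (GiB):", paddle.device.cuda.memory_allocated() / gib)
    print("current reserved  (GiB):", paddle.device.cuda.memory_reserved() / gib)
    print("peak allocated    (GiB):", paddle.device.cuda.max_memory_allocated() / gib)
    print("peak reserved     (GiB):", paddle.device.cuda.max_memory_reserved() / gib)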
paddlenlp/trainer/training_args.py

Lines changed: 4 additions & 0 deletions
@@ -787,6 +787,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
+    tensor_parallel_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to output logits sharded across tensor-parallel ranks"},
+    )

     def __post_init__(self):
         env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))

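Because TrainingArguments is a dataclass parsed by PaddleNLP's PdArgumentParser, the new field automatically surfaces as a --tensor_parallel_output command-line flag. A self-contained sketch with a stand-in dataclass (MiniArgs is ours, not the real TrainingArguments):

from dataclasses import dataclass, field
from typing import Optional

from paddlenlp.trainer import PdArgumentParser

@dataclass
class MiniArgs:  # stand-in for TrainingArguments, illustration only
    tensor_parallel_output: Optional[bool] = field(
        default=False,
        metadata={"help": "keep logits sharded across tensor-parallel ranks"},
    )

(args,) = PdArgumentParser(MiniArgs).parse_args_into_dataclasses(
    ["--tensor_parallel_output", "true"]
)
print(args.tensor_parallel_output)  # True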