
[LLM] Support prefix tuning and lora for qwen2 #8601


Merged: 15 commits, merged Jun 20, 2024
Changes from 10 commits
20 changes: 19 additions & 1 deletion llm/utils.py
@@ -67,7 +67,7 @@ def get_prefix_tuning_params(model):
num_hidden_layers = model.config.num_layers
hidden_size = model.config.hidden_size
postprocess_past_key_value = chatglm_postprocess_past_key_value
multi_query_group_num = model.config.multi_query_group_num
multi_query_group_num = model.config.multi_query_group_num # num_key_value_heads
elif model.base_model_prefix == "bloom":
from paddlenlp.peft.prefix import bloom_postprocess_past_key_value

@@ -92,6 +92,14 @@ def get_prefix_tuning_params(model):
hidden_size = model.config.hidden_size
postprocess_past_key_value = qwen_postprocess_past_key_value
multi_query_group_num = None
elif model.base_model_prefix == "qwen2":
from paddlenlp.peft.prefix import qwen_postprocess_past_key_value

num_attention_heads = model.config.num_attention_heads
num_hidden_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
postprocess_past_key_value = qwen_postprocess_past_key_value
multi_query_group_num = model.config.num_key_value_heads # num_key_value_heads
else:
raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}. ")
return dict(
@@ -150,6 +158,16 @@ def get_lora_target_modules(model):
".*mlp.w2.*",
".*mlp.c_proj.*",
]
elif model.base_model_prefix == "qwen2":
target_modules = [
".*q_proj.*",
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
".*gate_proj.*",
".*down_proj.*",
".*up_proj.*",
]
elif model.base_model_prefix == "mixtral":
target_modules = [
".*q_proj.*",
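A note on the two new qwen2 branches above: for prefix tuning, `multi_query_group_num` is taken from `config.num_key_value_heads` because Qwen2 uses grouped-query attention, so the prefix key/value cache carries the smaller KV head count rather than `num_attention_heads`. For LoRA, the returned patterns are regexes over sublayer names. A minimal, self-contained sketch of how such patterns are typically matched (the layer names below are made up for illustration, not taken from this PR):

```python
import re

# Illustrative qwen2-style sublayer names (hypothetical).
layer_names = [
    "qwen2.layers.0.self_attn.q_proj",
    "qwen2.layers.0.self_attn.k_proj",
    "qwen2.layers.0.self_attn.o_proj",
    "qwen2.layers.0.mlp.gate_proj",
    "qwen2.layers.0.mlp.up_proj",
    "qwen2.layers.0.input_layernorm",
]

target_modules = [
    ".*q_proj.*", ".*k_proj.*", ".*v_proj.*", ".*o_proj.*",
    ".*gate_proj.*", ".*down_proj.*", ".*up_proj.*",
]

# A LoRA wrapper replaces every sublayer whose name matches one of the patterns;
# the layernorm is left untouched here.
matched = [name for name in layer_names if any(re.fullmatch(p, name) for p in target_modules)]
print(matched)
```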
11 changes: 11 additions & 0 deletions paddlenlp/transformers/model_utils.py
@@ -2356,6 +2356,17 @@
)
pass

# Note:
# 1. PipelineLayer creates the parameters for each layer and calls
#    `_synchronize_shared_weights()` to synchronize the shared parameters.
# 2. When the model `state_dict` is set, `_synchronize_shared_weights()` is called again to
#    synchronize the shared parameters.
#    However, if the state dict contains only one copy of the shared parameters, the copies can
#    end up out of sync with the original shared parameters, so we synchronize once more below.

if isinstance(model, PipelineLayer):
Collaborator comment: Nice 👍🏻

model._synchronize_shared_weights()

Codecov: added line paddlenlp/transformers/model_utils.py#L2368 was not covered by tests.

if paddle.in_dynamic_mode():
return model

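A toy illustration of the note above, not the `PipelineLayer` internals (the tensors and names are made up): when a checkpoint stores only one copy of a tied weight, the other copy is stale after `set_state_dict`, which is why the extra `_synchronize_shared_weights()` call is added.

```python
import paddle

# Two tensors that are meant to stay identical (e.g. embedding and output head).
emb = paddle.nn.Embedding(10, 4)
shared_copy = paddle.create_parameter(shape=[10, 4], dtype="float32")

# The checkpoint stores only the embedding copy.
emb.set_state_dict({"weight": paddle.randn([10, 4])})

# `shared_copy` was never touched, so the two "shared" tensors now disagree;
# a post-load synchronization (what _synchronize_shared_weights() does for
# PipelineLayer) copies the loaded value back into every shared slot.
shared_copy.set_value(emb.weight)
assert paddle.allclose(shared_copy, emb.weight)
```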
1 change: 1 addition & 0 deletions paddlenlp/transformers/qwen2/__init__.py
@@ -15,4 +15,5 @@

from .configuration import *
from .modeling import *
from .modeling_pp import *
from .tokenizer import *
3 changes: 3 additions & 0 deletions paddlenlp/transformers/qwen2/configuration.py
@@ -150,6 +150,9 @@ def __init__(
self.eos_token_id = eos_token_id

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
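A quick sanity check of what the configuration change does (assuming `Qwen2Config` is importable from `paddlenlp.transformers`, as the rest of this PR implies; the token id below is an arbitrary illustrative value): the special-token ids and `tie_word_embeddings` are now forwarded to `PretrainedConfig.__init__`, so the base-class handling of these fields sees the values that were passed in.

```python
from paddlenlp.transformers import Qwen2Config  # assumed import path

config = Qwen2Config(pad_token_id=0, tie_word_embeddings=True)
print(config.pad_token_id)          # 0
print(config.tie_word_embeddings)   # True
```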
105 changes: 68 additions & 37 deletions paddlenlp/transformers/qwen2/modeling.py
@@ -67,6 +67,8 @@
"Qwen2PretrainedModel",
"Qwen2ForCausalLM",
"Qwen2PretrainingCriterion",
"Qwen2ForSequenceClassification",
"Qwen2ForTokenClassification",
]


@@ -112,7 +114,7 @@
return assignment_list


def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
def parallel_matmul(x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True):
is_fleet_init = True
tensor_parallel_degree = 1
try:
@@ -130,15 +132,15 @@
if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
# if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
logits = paddle.matmul(input_parallel, y, transpose_y=False)
logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)

Codecov: added line paddlenlp/transformers/qwen2/modeling.py#L135 was not covered by tests.

if tensor_parallel_output:
return logits

return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

else:
logits = paddle.matmul(x, y, transpose_y=False)
logits = paddle.matmul(x, y, transpose_y=transpose_y)
return logits


@@ -291,12 +293,10 @@
def forward(self, hidden_states):
if paddle.in_dynamic_mode():
with paddle.amp.auto_cast(False):
hidden_states = hidden_states.astype("float32")
variance = hidden_states.pow(2).mean(-1, keepdim=True)
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
else:
hidden_states = hidden_states.astype("float32")
variance = hidden_states.pow(2).mean(-1, keepdim=True)
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)

Codecov: added line paddlenlp/transformers/qwen2/modeling.py#L299 was not covered by tests.
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states

if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
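The RMSNorm tweak keeps the activations in their original dtype and only accumulates the squared mean in float32. A minimal sketch of the formula in plain float32 (not the full layer, which also applies the learned scale):

```python
import paddle

x = paddle.randn([2, 8], dtype="float32")
eps = 1e-6

# RMSNorm core: x / sqrt(mean(x^2) + eps); only the variance is computed in float32.
variance = x.astype("float32").pow(2).mean(-1, keepdim=True)
y = paddle.rsqrt(variance + eps) * x
```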
@@ -480,12 +480,8 @@

if config.tensor_parallel_degree > 1:
self.q_proj = ColumnParallelLinear(self.hidden_size, self.hidden_size, has_bias=True, gather_output=False)
self.k_proj = ColumnParallelLinear(
self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False
)
self.v_proj = ColumnParallelLinear(
self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False
)
self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip

Codecov: added lines paddlenlp/transformers/qwen2/modeling.py#L483-L484 were not covered by tests.
self.o_proj = RowParallelLinear(self.hidden_size, self.hidden_size, has_bias=False, input_is_parallel=True)
else:
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias_attr=True)
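The k_proj/v_proj lines above (reformatted onto one line each) encode the grouped-query shapes: only the query projection keeps the full head count, while key/value use `num_key_value_heads`. A small arithmetic sketch with illustrative numbers (not Qwen2's real configuration):

```python
hidden_size = 1024
num_attention_heads = 16
num_key_value_heads = 4                          # GQA: 4 query heads share each KV head
head_dim = hidden_size // num_attention_heads    # 64

q_out = num_attention_heads * head_dim           # 1024 -> q_proj: hidden_size -> hidden_size
kv_out = num_key_value_heads * head_dim          # 256  -> k_proj / v_proj: hidden_size -> 256
print(q_out, kv_out)
```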
@@ -512,8 +508,6 @@
"""Input shape: Batch x Time x Channel"""
# [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)

batch_size, seq_len, _ = hidden_states.shape

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
@@ -617,6 +611,7 @@
class Qwen2DecoderLayer(nn.Layer):
def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.self_attn = Qwen2Attention(config, layerwise_recompute)

@@ -757,7 +752,8 @@
for mapping in model_mappings:
mapping[0] = "model." + mapping[0]
mapping[1] = "qwen2." + mapping[1]
model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
if not config.tie_word_embeddings:
model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])

mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
return mappings
@@ -777,11 +773,14 @@
final_actions = {}

base_actions = {
"lm_head.weight": partial(fn, is_column=True),
# Row Linear
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
}
if config.tie_word_embeddings:
base_actions["lm_head.weight"] = partial(fn, is_column=False)

Codecov: added lines paddlenlp/transformers/qwen2/modeling.py#L780-L781 were not covered by tests.
else:
base_actions["lm_head.weight"] = partial(fn, is_column=True)

Codecov: added line paddlenlp/transformers/qwen2/modeling.py#L783 was not covered by tests.

if not config.vocab_size % config.tensor_parallel_degree == 0:
base_actions.pop("lm_head.weight")
@@ -985,12 +984,10 @@
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # fmt:skip
use_cache = use_cache if use_cache is not None else self.config.use_cache

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
@@ -1014,6 +1011,7 @@
cache_length = past_key_values[0][0].shape[1]
seq_length_with_past += cache_length
if inputs_embeds is None:
# [bs, seq_len, dim]
inputs_embeds = self.embed_tokens(input_ids)

if self.sequence_parallel:
@@ -1143,22 +1141,41 @@


class Qwen2LMHead(nn.Layer):
def __init__(self, config: Qwen2Config):
def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=False):
super(Qwen2LMHead, self).__init__()
self.config = config
if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
vocab_size = config.vocab_size // config.tensor_parallel_degree
else:
vocab_size = config.vocab_size

self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],
dtype=paddle.get_default_dtype(),
)
self.transpose_y = transpose_y
if transpose_y:
if embedding_weights is not None:
self.weight = embedding_weights
else:
self.weight = self.create_parameter(

shape=[vocab_size, config.hidden_size],
dtype=paddle.get_default_dtype(),
)
else:
if vocab_size != config.vocab_size:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(

shape=[config.hidden_size, vocab_size],
dtype=paddle.get_default_dtype(),
)
else:
self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],
dtype=paddle.get_default_dtype(),
)

# Must set distributed attr for Tensor Parallel !
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
if self.weight.is_distributed:
self.weight.split_axis = 1
# for tie_word_embeddings
self.weight.split_axis = 0 if self.transpose_y else 1

Codecov: added line paddlenlp/transformers/qwen2/modeling.py#L1178 was not covered by tests.

def forward(self, hidden_states, tensor_parallel_output=None):
if self.config.sequence_parallel:
@@ -1169,7 +1186,9 @@
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output

logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
logits = parallel_matmul(
hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
)
return logits
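A NumPy stand-in for the sharding bookkeeping the new head has to track (illustrative shapes, not the real PaddleNLP split helpers): a tied head holds the `[vocab, hidden]` embedding table and is sharded over rows (`split_axis = 0`), while a dedicated head holds `[hidden, vocab]` and is sharded over columns (`split_axis = 1`); this also matches the `is_column=False` / `is_column=True` choice in `_get_tensor_parallel_mappings` above.

```python
import numpy as np

vocab_size, hidden_size, tp_degree = 8, 4, 2

tied_weight = np.zeros([vocab_size, hidden_size])     # reused embedding table
untied_weight = np.zeros([hidden_size, vocab_size])   # dedicated lm_head matrix

tied_shards = np.split(tied_weight, tp_degree, axis=0)      # split_axis = 0 -> each shard [4, 4]
untied_shards = np.split(untied_weight, tp_degree, axis=1)  # split_axis = 1 -> each shard [4, 4]
print(tied_shards[0].shape, untied_shards[0].shape)
```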


@@ -1180,7 +1199,11 @@
def __init__(self, config: Qwen2Config):
super().__init__(config)
self.qwen2 = Qwen2Model(config)
self.lm_head = Qwen2LMHead(config)
if config.tie_word_embeddings:
self.lm_head = Qwen2LMHead(config, embedding_weights=self.qwen2.embed_tokens.weight, transpose_y=True)
self.tie_weights()
else:
self.lm_head = Qwen2LMHead(config)
self.criterion = Qwen2PretrainingCriterion(config)
self.vocab_size = config.vocab_size
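If the construction above works as intended, the tied head and the embedding share one tensor rather than holding two copies. A hedged smoke test (the constructor argument names are assumed to mirror the Hugging Face-style Qwen2 config used elsewhere in this file; the sizes are deliberately tiny):

```python
from paddlenlp.transformers import Qwen2Config, Qwen2ForCausalLM  # assumed exports

config = Qwen2Config(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    tie_word_embeddings=True,
)
model = Qwen2ForCausalLM(config)
# One shared tensor: the LM head's weight is the embedding table itself.
assert model.lm_head.weight is model.qwen2.embed_tokens.weight
```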

@@ -1250,10 +1273,18 @@
model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1)

if not is_encoder_decoder and "attention_mask" in model_kwargs:
# TODO: support attention mask for other models
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = paddle.concat(
[attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1
)
if len(attention_mask.shape) == 2:
model_kwargs["attention_mask"] = paddle.concat(

[attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)],
axis=-1,
)
elif len(attention_mask.shape) == 4:
model_kwargs["attention_mask"] = paddle.concat(
[attention_mask, paddle.ones([*attention_mask.shape[:3], 1], dtype=attention_mask.dtype)],
axis=-1,
)[:, :, -1:, :]

return model_kwargs
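A toy run of the new 4-D branch (shapes only; illustrative sizes): during generation the mask gains one key/value column for the token that was just produced and is then cut down to the last query row, the only row needed for the next decoding step.

```python
import paddle

bsz, num_heads, q_len, kv_len = 2, 1, 5, 5
attention_mask = paddle.ones([bsz, num_heads, q_len, kv_len], dtype="float32")

attention_mask = paddle.concat(
    [attention_mask, paddle.ones([*attention_mask.shape[:3], 1], dtype=attention_mask.dtype)],
    axis=-1,
)[:, :, -1:, :]
print(attention_mask.shape)  # [2, 1, 1, 6]
```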

@@ -1347,7 +1378,7 @@
super().__init__(config)
self.num_labels = config.num_labels
self.qwen2 = Qwen2Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias_attr=False)

def get_input_embeddings(self):
return self.qwen2.embed_tokens
@@ -1402,17 +1433,17 @@
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = paddle.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = paddle.equal(input_ids, self.config.pad_token_id).astype("int32").argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
sequence_lengths = sequence_lengths
else:
sequence_lengths = -1

pooled_logits = logits[paddle.arange(batch_size, device=logits.device), sequence_lengths]
# pooled_logits = logits[paddle.arange(batch_size), sequence_lengths]
pooled_logits = logits.gather_nd(paddle.stack([paddle.arange(logits.shape[0]), sequence_lengths], axis=-1))

loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
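On the `Qwen2ForSequenceClassification` changes just above: the torch-style `device=` indexing was replaced, and the pooling is now done with `gather_nd`, picking each sample's logits at the index of its last non-pad token. A toy example with illustrative sizes:

```python
import paddle

logits = paddle.arange(2 * 4 * 3, dtype="float32").reshape([2, 4, 3])  # [batch, seq, num_labels]
sequence_lengths = paddle.to_tensor([3, 1])  # last non-pad position per sample

index = paddle.stack([paddle.arange(logits.shape[0]), sequence_lengths], axis=-1)  # [[0, 3], [1, 1]]
pooled_logits = logits.gather_nd(index)
print(pooled_logits.shape)  # [2, 3]: one logit vector per sample
```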