diff --git a/llm/data.py b/llm/data.py
index 5d44c72c8abd..767bd5a88a29 100644
--- a/llm/data.py
+++ b/llm/data.py
@@ -44,11 +44,11 @@ def get_convert_example(model):
     if base_model_prefix == "chatglm":
         return convert_example_chatglm
-    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral"]:
+    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma"]:
         return convert_example_common
     else:
         raise ValueError(
-            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral"
+            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma"
         )
diff --git a/llm/gemma/README.md b/llm/gemma/README.md
new file mode 100644
index 000000000000..040e68f00132
--- /dev/null
+++ b/llm/gemma/README.md
@@ -0,0 +1,18 @@
+# Gemma
+
+## 1. Model Introduction
+
+[Gemma](https://blog.google/technology/developers/gemma-open-models/), developed by Google DeepMind and other teams across Google, is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.
+
+**Supported model weights:**
+
+| Model              |
+| ------------------ |
+| google/gemma-7b    |
+| google/gemma-7b-it |
+| google/gemma-2b    |
+| google/gemma-2b-it |
+
+## 2. Fine-tuning
+
+Please refer to the [LLM end-to-end toolchain introduction](../README.md).
diff --git a/llm/gemma/sft_argument.json b/llm/gemma/sft_argument.json
new file mode 100644
index 000000000000..45a483d7e52a
--- /dev/null
+++ b/llm/gemma/sft_argument.json
@@ -0,0 +1,30 @@
+{
+    "model_name_or_path": "google/gemma-2b",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/gemma_sft_ckpts",
+    "per_device_train_batch_size": 2,
+    "gradient_accumulation_steps": 1,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-05,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 512,
+    "max_length": 1024,
+    "fp16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 2,
+    "zero_padding": false,
+    "use_flash_attention": false
+}
\ No newline at end of file
diff --git a/llm/gemma/sft_argument_7b.json b/llm/gemma/sft_argument_7b.json
new file mode 100644
index 000000000000..16eba55bed9e
--- /dev/null
+++ b/llm/gemma/sft_argument_7b.json
@@ -0,0 +1,32 @@
+{
+    "model_name_or_path": "google/gemma-7b",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/gemma_sft_ckpts",
+    "per_device_train_batch_size": 8,
+    "gradient_accumulation_steps": 1,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 1,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-06,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 512,
+    "max_length": 1024,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "do_predict": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 8,
+    "pipeline_parallel_degree": 1,
+    "zero_padding": false,
+    "use_flash_attention": false
+}
\ No newline at end of file
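The SFT configs above plug into the shared fine-tuning entry point referenced in the README. For a quick sanity check of the new weights themselves, a minimal loading-and-generation sketch could look like the following; it assumes the Gemma Auto-class registrations added later in this diff are installed, and the exact `generate()` arguments may differ between PaddleNLP versions.

```python
# Minimal loading/generation sketch, assuming the Gemma Auto-class registrations in this diff
# are available; generate() arguments vary slightly across PaddleNLP versions.
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", dtype="bfloat16")

inputs = tokenizer("Why is the sky blue?", return_tensors="pd")
output_ids, _ = model.generate(**inputs, max_length=64, decode_strategy="greedy_search")
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```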
diff --git a/llm/gemma/sft_argument_7b_sharding.json b/llm/gemma/sft_argument_7b_sharding.json
new file mode 100644
index 000000000000..ca04affdb243
--- /dev/null
+++ b/llm/gemma/sft_argument_7b_sharding.json
@@ -0,0 +1,33 @@
+{
+    "model_name_or_path": "google/gemma-7b",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/gemma_sft_ckpts",
+    "per_device_train_batch_size": 1,
+    "gradient_accumulation_steps": 1,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 1,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-06,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 1024,
+    "max_length": 2048,
+    "fp16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "do_predict": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "sharding_parallel_degree": 8,
+    "sharding": "stage3",
+    "pipeline_parallel_degree": 1,
+    "zero_padding": false,
+    "use_flash_attention": false
+}
\ No newline at end of file
diff --git a/llm/gemma/sft_argument_sharding.json b/llm/gemma/sft_argument_sharding.json
new file mode 100644
index 000000000000..d462645e2235
--- /dev/null
+++ b/llm/gemma/sft_argument_sharding.json
@@ -0,0 +1,31 @@
+{
+    "model_name_or_path": "google/gemma-2b",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/gemma_sft_ckpts",
+    "per_device_train_batch_size": 1,
+    "gradient_accumulation_steps": 1,
+    "per_device_eval_batch_size": 1,
+    "eval_accumulation_steps": 1,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-05,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 512,
+    "max_length": 1024,
+    "fp16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "sharding_parallel_degree": 2,
+    "sharding": "stage3",
+    "zero_padding": false,
+    "use_flash_attention": false
+}
\ No newline at end of file
diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py
index f58b3d837191..35da0ca03195 100644
--- a/paddlenlp/transformers/__init__.py
+++ b/paddlenlp/transformers/__init__.py
@@ -200,6 +200,7 @@
 from .gau_alpha.modeling import *
 from .gau_alpha.tokenizer import *
 from .gau_alpha.configuration import *
+from .gemma import *
 from .roformerv2.modeling import *
 from .roformerv2.tokenizer import *
 from .roformerv2.configuration import *
diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py
index 9e291f83aa56..ed23b6758ca2 100644
--- a/paddlenlp/transformers/auto/modeling.py
+++ b/paddlenlp/transformers/auto/modeling.py
@@ -118,6 +118,7 @@
         ("Bloom", "bloom"),
         ("QWen", "qwen"),
         ("Mixtral", "mixtral"),
+        ("Gemma", "gemma"),
     ]
 )
diff --git a/paddlenlp/transformers/auto/tokenizer.py b/paddlenlp/transformers/auto/tokenizer.py
index 451468741ea1..1e6e1215fe5d 100644
--- a/paddlenlp/transformers/auto/tokenizer.py
+++ b/paddlenlp/transformers/auto/tokenizer.py
@@ -97,6 +97,7 @@
         ("BloomTokenizer", "bloom"),
         ("SpeechT5Tokenizer", "speecht5"),
         ("QWenTokenizer", "qwen"),
+        ("GemmaTokenizer", "gemma"),
     ]
 )
diff --git a/paddlenlp/transformers/gemma/__init__.py b/paddlenlp/transformers/gemma/__init__.py
new file mode 100644
index 000000000000..af3692458d8a
--- /dev/null
+++ b/paddlenlp/transformers/gemma/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 PaddlePaddle Authors.
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration import * +from .modeling import * +from .modeling_pp import * +from .tokenizer import * diff --git a/paddlenlp/transformers/gemma/configuration.py b/paddlenlp/transformers/gemma/configuration.py new file mode 100644 index 000000000000..efcbb4694fed --- /dev/null +++ b/paddlenlp/transformers/gemma/configuration.py @@ -0,0 +1,195 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Gemma model configuration""" + +from paddlenlp.transformers.configuration_utils import PretrainedConfig + +__all__ = [ + "GEMMA_PRETRAINED_INIT_CONFIGURATION", + "GemmaConfig", + "GEMMA_PRETRAINED_RESOURCE_FILES_MAP", +] + +GEMMA_PRETRAINED_INIT_CONFIGURATION = { + "google/gemma-2b": { + "architectures": ["GemmaForCausalLM"], + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 8192, + "model_type": "gemma", + "num_attention_heads": 8, + "num_key_value_heads": 1, + "num_hidden_layers": 28, + "rms_norm_eps": 1e-06, + "vocab_size": 256000, + "bos_token_id": 2, + "eos_token_id": 1, + "pad_token_id": 0, + "use_cache": True, + "use_recompute": False, + "use_flash_attention": False, + }, +} + + +GEMMA_PRETRAINED_RESOURCE_FILES_MAP = { + "model_state": { + "google/gemma-2b": "https://bj.bcebos.com/paddlenlp/models/community/google/gemma-2b/model.safetensors", + "google/gemma-2b-it": "https://bj.bcebos.com/paddlenlp/models/community/google/gemma-2b-it/model.safetensors", + }, +} + + +class GemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~GemmaModel`]. It is used to instantiate a gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~GemmaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. 
+        intermediate_size (`int`, *optional*, defaults to 24576):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 28):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie the input and output word embeddings.
+        use_fused_rope (`bool`, *optional*, defaults to `False`):
+            Whether to enable fused rotary position embedding.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
+            `num_attention_heads`.
+ Example: + ```python + >>> from paddlenlp.transformer import GemmaModel, GemmaModel + + >>> # Initializing a Gemma gemma-7b style configuration + >>> configuration = GemmaModel() + + >>> # Initializing a model from the gemma-7b style configuration + >>> model = GemmaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "gemma" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256000, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu", + max_position_embeddings=8192, + seq_length=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + use_recompute=False, + recompute_granularity="full", + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + tensor_parallel_output=True, + sequence_parallel=False, + fuse_sequence_parallel_allreduce=False, + use_fused_rope=False, + fuse_attention_qkv=False, + fuse_attention_ffn=False, + alibi=False, + pp_recompute_interval=1, + no_recompute_layers=None, + use_flash_attention=False, + use_fused_rms_norm=False, + virtual_pp_degree=1, + rope_scaling_factor=1.0, + rope_scaling_type=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.seq_length = seq_length + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_recompute = use_recompute + self.recompute_granularity = recompute_granularity + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.tensor_parallel_output = tensor_parallel_output + self.sequence_parallel = sequence_parallel + self.fuse_sequence_parallel_allreduce = fuse_sequence_parallel_allreduce + self.use_fused_rope = use_fused_rope + self.fuse_attention_qkv = fuse_attention_qkv + self.fuse_attention_ffn = fuse_attention_ffn + self.alibi = alibi + self.pp_recompute_interval = pp_recompute_interval + self.no_recompute_layers = no_recompute_layers + self.use_flash_attention = use_flash_attention + self.use_fused_rms_norm = use_fused_rms_norm + self.virtual_pp_degree = virtual_pp_degree + self.rope_scaling_factor = rope_scaling_factor + self.rope_scaling_type = rope_scaling_type + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + tensor_parallel_output=tensor_parallel_output, + **kwargs, + ) + + @property + def rope(self): + return not self.alibi diff --git a/paddlenlp/transformers/gemma/modeling.py b/paddlenlp/transformers/gemma/modeling.py new file mode 100644 index 000000000000..c0e3debd9e81 --- /dev/null +++ b/paddlenlp/transformers/gemma/modeling.py @@ -0,0 +1,1534 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from functools import partial +from typing import List, Optional, Tuple + +import paddle +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.utils import recompute +from paddle.utils import try_import + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + + +from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + GatherOp, + RowSequenceParallelLinear, + ScatterOp, + mark_as_sequence_parallel_parameter, +) + +from paddlenlp.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddlenlp.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model + +from ..segment_parallel_utils import ReshardLayer +from .configuration import ( + GEMMA_PRETRAINED_INIT_CONFIGURATION, + GEMMA_PRETRAINED_RESOURCE_FILES_MAP, + GemmaConfig, +) + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + + +def _get_interleave(n): + def _get_interleave_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return _get_interleave_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + _get_interleave_power_of_2(closest_power_of_2) + + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +def rms_norm_fused(x_in, w, eps): + fused_ln = try_import("fused_ln") + return fused_ln.fused_rms_norm(x_in, w, eps)[0] + + +def assign_kv_heads(num_kv_heads: int, num_gpus: int): + # Initialize the assignment list + """ + Assign kv heads to different GPUs in the Tensor Parallel Setup + + Examples: + assign_kv_heads(num_kv_heads=1, num_gpus=2): [[0], [0]] + assign_kv_heads(num_kv_heads=2, num_gpus=2): [[0], [1]] + assign_kv_heads(num_kv_heads=4, num_gpus=2): [[0,1], [2,3]] + assign_kv_heads(num_kv_heads=1, num_gpus=4): [[0],[0],[0],[0]] + assign_kv_heads(num_kv_heads=2, num_gpus=4): [[0],[0],[1],[1]] + assign_kv_heads(num_kv_heads=4, num_gpus=4): [[0],[1],[2],[3]] + """ + assignment_list = [[] for _ in range(num_gpus)] + # Case 1: more heads than cards + if num_kv_heads > num_gpus: + num_heads_per_card = num_kv_heads // num_gpus + for i in range(num_gpus): + for j in range(num_heads_per_card): + assignment_list[i].append(i * num_heads_per_card + j) + # Case 2: more cards than heads. each card get only 1 head. 
+ else: + num_card_per_heads = num_gpus // num_kv_heads + for i in range(num_kv_heads): + for j in range(num_card_per_heads): + assignment_list[i * num_card_per_heads + j].append(i) + return assignment_list + + +def build_alibi_tensor( + bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1 +) -> Tensor: + attention_mask = bool_attention_mask.astype("float32") + batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1] + slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32") + alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand( + [num_heads, -1, -1] + ) + alibi = alibi.reshape(shape=(1, num_heads, 1, seq_length)).expand([batch_size, -1, -1, -1]) + return paddle.cast(alibi, dtype) + + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + """ + This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=1). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + + hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) + return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim]) + + +def parallel_matmul(x: Tensor, y, tensor_parallel_output=True, transpose_y=False): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except: + is_fleet_init = False + + if paddle.in_dynamic_mode(): + y_is_distributed = y.is_distributed + else: + y_is_distributed = tensor_parallel_degree > 1 + + if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + + else: + logits = paddle.matmul(x, y, transpose_y=transpose_y) + return logits + + +def scaled_dot_product_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi=None, + sequence_parallel=False, + reshard_layer=None, + attn_dropout_prob=0.0, + trainer_mode=False, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, _, _ = value_states.shape + + if config.use_flash_attention and flash_attention: + # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] + # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] + + version = paddle.version.full_version + if version != "0.0.0" and version <= "2.5.2": + if alibi is not None: + raise ValueError("Flash Attention doesn't support alibi") + attn_output, attn_weights = flash_attention( + query_states, + key_states, + 
value_states, + causal=True, + dropout=attn_dropout_prob, + return_softmax=output_attentions, + ) + else: + if alibi is not None: + alibi = alibi.reshape([bsz, num_heads, 1, -1]) + attention_mask = attention_mask.cast(alibi.dtype) + alibi + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + ) + attn_weights = None + + if reshard_layer is not None: + # attn_output shape: [bs, seqlen, num_head/sep, head_dim] + attn_output = reshard_layer( + attn_output, + split_axis=1, + concat_axis=2, + ) + # attn_output shape: [bs, seqlen/sep, num_head, head_dim] + assert ( + config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0 + ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}" + q_len = q_len // config.sep_parallel_degree + num_heads = num_heads * config.sep_parallel_degree + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + else: + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) + # merge with the next tranpose + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + # matmul and devide by sqrt(head_dim) + attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])) + # then add alibi bias + if alibi is not None: + alibi = alibi.reshape([bsz, num_heads, 1, -1]) + attn_weights = attn_weights + alibi + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + # In sep mode, the attenion mask should be created in the runtime. 
+ if reshard_layer is not None: + attention_mask = None + + # NOTE: we only call get_triangle_upper_mask under PP setup + # FIXME ZHUI when we use pipeline parallel, the attention_mask can be None + # we just make it triangle_upper_mask + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + + attn_weights = attn_weights + attention_mask + if not paddle.in_dynamic_mode(): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + else: + with paddle.amp.auto_cast(False): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + attn_weights = F.dropout(attn_weights, attn_dropout_prob, training=trainer_mode) + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + + if reshard_layer is not None: + attn_output = reshard_layer( + attn_output, + split_axis=1, + concat_axis=2, + ) + q_len = q_len // config.sep_parallel_degree + num_heads = num_heads * config.sep_parallel_degree + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + + +def is_casual_mask(attention_mask): + """ + Upper triangular of attention_mask equals to attention_mask is casual + """ + return (paddle.triu(attention_mask) == attention_mask).all().item() + + +def _make_causal_mask(input_ids_shape, past_key_values_length): + """ + Make causal mask used for self-attention + """ + batch_size, target_length = input_ids_shape # target_length: seq_len + + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + + # [bs, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + +def _expand_2d_mask(mask, dtype, tgt_length): + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. 
+ """ + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + mask = mask[:, None, None, :].astype("bool") + mask.stop_gradient = True + expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length]) + + return expanded_mask + + +class GemmaRMSNorm(nn.Layer): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.weight) + + def _norm(self, x): + return x * paddle.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) + + def forward(self, x): + if self.config.use_fused_rms_norm: + return rms_norm_fused(x, self.weight + 1, self.variance_epsilon) + + output = self._norm(x.astype(paddle.float32)).astype(x.dtype) + return output * (self.weight + 1) + + +class GemmaRotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + t = paddle.arange(seq_len, dtype="float32") + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + emb = paddle.concat([freqs, freqs], axis=-1) + return (emb.cos()[None, :, None, :].cast(dtype=x.dtype), emb.sin()[None, :, None, :].cast(dtype=x.dtype)) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + + if position_ids is None: + # Note: Only for ForCausalLMPipe model pretraining + cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + else: + cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class GemmaMLP(nn.Layer): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.tensor_parallel_degree = config.tensor_parallel_degree + + if config.sequence_parallel: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear + else: + ColumnParallelLinear = mpu.ColumnParallelLinear + RowParallelLinear = mpu.RowParallelLinear + + if config.tensor_parallel_degree > 1: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + self.gate_proj = 
nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + + def forward(self, x): + # GeGLU + out = self.down_proj(F.gelu(self.gate_proj(x)) * self.up_proj(x)) + return out + + +class GemmaAttention(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GemmaConfig, layerwise_recompute: bool = False): + super().__init__() + + self.config = config + self.attention_dropout = config.attention_dropout # add + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + + self.max_position_embeddings = config.max_position_embeddings + self.seq_length = config.seq_length + self.rope_theta = config.rope_theta + self.sequence_parallel = config.sequence_parallel + + self.kv_indices = None + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + + if self.num_key_value_heads % config.tensor_parallel_degree == 0: + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + else: + self.kv_indices = paddle.to_tensor( + assign_kv_heads(self.num_key_value_heads, config.tensor_parallel_degree)[ + config.tensor_parallel_rank + ] + ) + + self.use_fused_rope = config.use_fused_rope + if self.use_fused_rope: + if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: + warnings.warn( + "Enable fuse rope in the config, but fuse rope is not available. " + "Will disable fuse rope. Try using latest gpu version of Paddle." 
+ ) + self.use_fused_rope = False + + if config.sequence_parallel: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear + else: + ColumnParallelLinear = mpu.ColumnParallelLinear + RowParallelLinear = mpu.RowParallelLinear + + if config.tensor_parallel_degree > 1: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.config.num_attention_heads * self.head_dim, + has_bias=config.attention_bias, + gather_output=False, + ) + if self.kv_indices is None: + # to revise shape + self.k_proj = ColumnParallelLinear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + has_bias=config.attention_bias, + gather_output=False, + ) + self.v_proj = ColumnParallelLinear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + has_bias=config.attention_bias, + gather_output=False, + ) + else: + self.k_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + + else: + self.q_proj = nn.Linear( + self.hidden_size, + self.config.num_attention_heads * self.head_dim, + bias_attr=False, + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + + if config.tensor_parallel_degree > 1: + self.o_proj = RowParallelLinear( + self.config.num_attention_heads * self.head_dim, + self.hidden_size, + has_bias=False, + input_is_parallel=True, + ) + else: + self.o_proj = nn.Linear( + self.config.num_attention_heads * self.head_dim, + self.hidden_size, + bias_attr=False, + ) + self.rotary_emb = GemmaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + self.reshard_layer = None + if config.sep_parallel_degree > 1: + assert self.num_key_value_heads % config.sep_parallel_degree == 0 + assert self.num_heads % config.sep_parallel_degree == 0 + self.reshard_layer = ReshardLayer() + + self.config = config + + def forward( + self, + hidden_states, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + alibi: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + if self.reshard_layer is not None: + if self.sequence_parallel: + assert self.seq_length % self.config.sep_parallel_degree == 0 + query_states = paddle.reshape( + query_states, + [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim], + ) + key_states = paddle.reshape( + key_states, + [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim], + ) + value_states = paddle.reshape( + value_states, + [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim], + ) + query_states = self.reshard_layer( + query_states, + split_axis=2, + concat_axis=1, + ) + 
key_states = self.reshard_layer( + key_states, + split_axis=2, + concat_axis=1, + ) + value_states = self.reshard_layer( + value_states, + split_axis=2, + concat_axis=1, + ) + query_states = paddle.reshape( + query_states, [0, self.seq_length, -1, self.head_dim] + ) # [bs, seq_len, num_head/k, head_dim], k is sep degree + key_states = paddle.reshape(key_states, [0, self.seq_length, -1, self.head_dim]) + value_states = paddle.reshape(value_states, [0, self.seq_length, -1, self.head_dim]) + else: + if self.sequence_parallel: + target_query_shape = [-1, self.seq_length, self.num_heads, self.head_dim] + target_key_value_shape = [-1, self.seq_length, self.num_key_value_heads, self.head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + query_states = query_states.reshape(shape=target_query_shape) + key_states = key_states.reshape(shape=target_key_value_shape) + value_states = value_states.reshape(shape=target_key_value_shape) + + kv_seq_len = key_states.shape[-3] + + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + + if self.config.rope: + if self.reshard_layer is not None: + batch_size, seq_length, _, _ = query_states.shape + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + if self.use_fused_rope: + assert past_key_value is None, "fuse rotary not support cache kv for now" + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + paddle_version = float(paddle.__version__[:3]) + if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and ( + self.num_heads != self.num_key_value_heads + ): + query_states, _, _ = fused_rotary_position_embedding( + query_states, + None, + None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + key_states, _, _ = fused_rotary_position_embedding( + key_states, + None, + None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + query_states, key_states, _ = fused_rotary_position_embedding( + query_states, + key_states, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + + past_key_value = (key_states, value_states) if use_cache else None + + if self.kv_indices is not None: + key_states = paddle.index_select(key_states, self.kv_indices, axis=2) + value_states = paddle.index_select(value_states, self.kv_indices, axis=2) + key_states = paddle.broadcast_to(key_states, query_states.shape) + value_states = paddle.broadcast_to(value_states, query_states.shape) + else: + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + scaled_dot_product_attention, + query_states, + 
self.config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi, + self.sequence_parallel, + reshard_layer=self.reshard_layer, + use_reentrant=self.config.recompute_use_reentrant, + attn_dropout_prob=self.attention_dropout, + trainer_mode=self.training, + ) + else: + outputs = scaled_dot_product_attention( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi, + self.sequence_parallel, + reshard_layer=self.reshard_layer, + attn_dropout_prob=self.attention_dropout, + trainer_mode=self.training, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class GemmaDecoderLayer(nn.Layer): + def __init__(self, config, layerwise_recompute: bool = False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.self_attn = GemmaAttention(config, layerwise_recompute) + self.mlp = GemmaMLP(config) + self.input_layernorm = GemmaRMSNorm(config) + self.post_attention_layernorm = GemmaRMSNorm(config) + self.sequence_parallel = config.sequence_parallel + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + alibi: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `cache` key value states are returned and can be used to speed up decoding + (see `cache`). 
+ cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + alibi, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.self_attn( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + alibi, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + # remove empty tuple for pipeline parallel + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class GemmaPretrainedModel(PretrainedModel): + config_class = GemmaConfig + base_model_prefix = "gemma" + pretrained_init_configuration = GEMMA_PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = GEMMA_PRETRAINED_RESOURCE_FILES_MAP + _keys_to_ignore_on_load_unexpected = [] + _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"] + + @classmethod + def _get_name_mappings(cls, config: GemmaConfig) -> List[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + init_name_mappings(mappings=model_mappings) + # base-model prefix "GemmaModel" + if "GemmaModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "gemma." 
+ mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: GemmaConfig, is_split=True): + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + # Column Linear + "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn: + base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial( + fn, is_column=True, is_naive_2fuse=True + ) + else: + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + def _init_weights(self, layer): + """Initialization hook""" + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.ColumnParallelLinear, + mpu.RowParallelLinear, + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. 
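The tensor-parallel mapping above only records whether each weight is sharded by column (output dimension) or by row (input dimension); a toy sketch of those two split directions using plain `paddle.split` (not the actual `split_or_merge_func` helper) for a 2-way setup:

```python
# Toy illustration of the is_column flag: column-parallel layers shard the output
# dimension of the weight, row-parallel layers shard the input dimension.
import paddle

weight = paddle.arange(4 * 6, dtype="float32").reshape([4, 6])  # [in_features, out_features]

column_shards = paddle.split(weight, num_or_sections=2, axis=1)  # q_proj / gate_proj style
row_shards = paddle.split(weight, num_or_sections=2, axis=0)     # o_proj / down_proj style

print(column_shards[0].shape)  # [4, 3] -> each rank keeps half of the outputs
print(row_shards[0].shape)     # [2, 6] -> each rank keeps half of the inputs
```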
+ if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.gemma.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.gemma.config.initializer_range, + shape=layer.weight.shape, + ) + ) + # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530 + # sublayer is init first + # scale RowParallelLinear weight + with paddle.no_grad(): + if isinstance(layer, GemmaMLP): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.down_proj.weight.scale_(factor) + if isinstance(layer, GemmaAttention): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.o_proj.weight.scale_(factor) + + +@register_base_model +class GemmaModel(GemmaPretrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`] + Args: + config: GemmaConfig + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.sequence_parallel = config.sequence_parallel + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding( + self.vocab_size, + self.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + self.embed_tokens.weight.is_distributed = True + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + self.embed_tokens.weight.is_distributed = False + + self.layers = nn.LayerList( + [GemmaDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)] + ) + self.norm = GemmaRMSNorm(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length + ) + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask(input_shape, 
past_key_values_length=past_key_values_length) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + return expanded_attn_mask + + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + past_key_value: Tensor, + use_cache: bool, + alibi=None, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi, + use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + **kwargs, + ): + if self.sequence_parallel and use_cache: + raise ValueError("We currently only support sequence parallel without cache.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if self.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = paddle.shape(past_key_values[0][0])[1] + seq_length_with_past += cache_length + + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + + if attention_mask is None: + # [bs, seq_len] + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + if self.config.alibi: + alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) + if self.config.tensor_parallel_degree > 1: + block_size = self.config.num_attention_heads // self.config.tensor_parallel_degree + alibi = alibi[ + :, + self.config.tensor_parallel_rank + * block_size : 
(self.config.tensor_parallel_rank + 1) + * block_size, + ] + alibi = alibi.reshape([batch_size * block_size, 1, seq_length_with_past]) + else: + alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length_with_past]) + else: + alibi = None + + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + is_casual = is_casual_mask(attention_mask) + if is_casual and alibi is None: + attention_mask = None + + # embed positions + hidden_states = inputs_embeds + + # normalized + hidden_states = hidden_states * (self.config.hidden_size**0.5) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + ) + + +class GemmaPretrainingCriterion(nn.Layer): + """ + Criterion for gemma. Copied From Llama + It calculates the final loss. 
+ """ + + def __init__(self, config): + + super().__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + with paddle.amp.auto_cast(False): + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + + if self.config.sep_parallel_degree > 1: + _hcg = fleet.get_hybrid_communicate_group() + masked_lm_loss = ConcatSePMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group()) + # skip ignore_index which loss == 0 + masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] + loss = paddle.mean(masked_lm_loss) + + return loss + + +class ConcatSePMaskedLoss(PyLayer): + @staticmethod + def forward(ctx, inp, axis, group): + inputs = [] + paddle.distributed.all_gather(inputs, inp, group=group) + with paddle.no_grad(): + cat = paddle.concat(inputs, axis=axis) + ctx.args_axis = axis + ctx.args_group = group + return cat + + @staticmethod + def backward(ctx, grad): + axis = ctx.args_axis + group = ctx.args_group + with paddle.no_grad(): + grads = paddle.split(grad, paddle.distributed.get_world_size(group), axis=axis) + grad = grads[paddle.distributed.get_rank(group)] + return grad + + +class GemmaLMHead(nn.Layer): + def __init__(self, config: GemmaConfig): + super().__init__() + self.config = config + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + self.weight = self.create_parameter( + shape=[vocab_size, config.hidden_size] if config.tie_word_embeddings else [config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + + # Must set distributed attr for Tensor Parallel ! 
+ self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + if self.weight.is_distributed: + self.weight.split_axis = 1 + + def forward(self, hidden_states, tensor_parallel_output=None): + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + seq_length = self.config.seq_length + if self.config.sep_parallel_degree > 1: + assert seq_length % self.config.sep_parallel_degree == 0 + seq_length = seq_length // self.config.sep_parallel_degree + hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size]) + + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + + logits = parallel_matmul( + hidden_states, + self.weight, + tensor_parallel_output=tensor_parallel_output, + transpose_y=self.config.tie_word_embeddings, + ) + return logits + + +class GemmaForCausalLM(GemmaPretrainedModel): + enable_to_static_method = True + + def __init__(self, config): + super().__init__(config) + self.config = config + self.lm_head = self.lm_head = GemmaLMHead(config) + self.gemma = GemmaModel(config) + self.criterion = GemmaPretrainingCriterion(config) + + self.tie_weights() + + def get_input_embeddings(self): + return self.gemma.embed_tokens + + def get_output_embeddings(self): + return self.lm_head + + def set_input_embeddings(self, value): + self.gemma.embed_tokens = value + + def set_decoder(self, decoder): + self.gemma = decoder + + def get_decoder(self): + return self.gemma + + def prepare_inputs_for_generation( + self, input_ids, use_cache=False, past_key_values=None, inputs_embeds=None, **kwargs + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + def _get_model_inputs_spec(self, dtype: str): + return { + "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + } + + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + 
[attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1 + ) + + return model_kwargs + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.gemma( + input_ids, # [bs, seq_len] + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] # [bs, seq_len, dim] + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + + logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + + loss = None + if labels is not None: + loss = self.criterion(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/paddlenlp/transformers/gemma/modeling_pp.py b/paddlenlp/transformers/gemma/modeling_pp.py new file mode 100644 index 000000000000..11528605ff97 --- /dev/null +++ b/paddlenlp/transformers/gemma/modeling_pp.py @@ -0,0 +1,314 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
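+# Pipeline-parallel flavour of Gemma: the model is flattened into a sequence of LayerDesc
+# stages (shared embedding, decoder layers, final RMSNorm, shared LM head) so that
+# fleet's PipelineLayer can distribute them across pipeline stages.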
+ +import paddle +import paddle.distributed.fleet as fleet +import paddle.nn as nn +from paddle.distributed.fleet.meta_parallel import ( + LayerDesc, + PipelineLayer, + SharedLayerDesc, +) +from paddle.distributed.fleet.utils import recompute + +from paddlenlp.transformers.model_utils import PipelinePretrainedModel + +from .modeling import ( + GemmaConfig, + GemmaDecoderLayer, + GemmaLMHead, + GemmaModel, + GemmaPretrainedModel, + GemmaPretrainingCriterion, + GemmaRMSNorm, + build_alibi_tensor, +) + + +def __repr__(self): + return self.layer_func.__name__ + + +# hack LayerDesc for showing to much config +LayerDesc.__repr__ = __repr__ + +__all__ = [ + "GemmaForCausalLMPipe", +] + + +def parse_args(args): + if isinstance(args, tuple): + if len(args) == 4: + hidden_states, attention_mask, position_ids, alibi = args + if len(args) == 3: + hidden_states, attention_mask, position_ids = args + alibi = None + elif len(args) == 2: + hidden_states, attention_mask = args + position_ids = None + alibi = None + else: + hidden_states = args + attention_mask, position_ids, alibi = None, None, None + + if position_ids is not None: + position_ids.stop_gradient = True + + if attention_mask is not None: + attention_mask.stop_gradient = True + + if alibi is not None: + alibi.stop_gradient = True + + return hidden_states, attention_mask, position_ids, alibi + + +def return_args(hidden_states, attention_mask=None, position_ids=None, alibi=None): + ret = (hidden_states,) + + if attention_mask is not None: + ret += (attention_mask.clone(),) + if position_ids is not None: + ret += (position_ids.clone(),) + if alibi is not None: + ret += (alibi.clone(),) + + if len(ret) == 1: + ret = ret[0] + + return ret + + +def get_attr(layer, name): + if getattr(layer, name, None) is not None: + return getattr(layer, name, None) + else: + return get_attr(layer._layer, name) + + +class GemmaEmbeddingPipe(nn.Layer): + """Extends GemmaEmbeddings to forward attention_mask through the pipeline.""" + + def __init__(self, config): + super(GemmaEmbeddingPipe, self).__init__() + self.config = config + self.sequence_parallel = config.sequence_parallel + self.hidden_size = config.hidden_size + if config.tensor_parallel_degree > 1: + self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + @property + def embedding_weight(self): + return get_attr(self.embed_tokens, "weight") + + def forward(self, args): + """_summary_ + + Args: + input (_type_): _description_ + + Returns: + _type_: _description_ + """ + input_ids, attention_mask, position_ids, alibi = parse_args(args) + input_embeds = self.embed_tokens(input_ids) + if self.sequence_parallel: + from paddlenlp.transformers import ScatterOp + + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = input_embeds.shape + input_embeds = paddle.reshape_(input_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + input_embeds = ScatterOp.apply(input_embeds) + + batch_size, seq_length = input_ids.shape + alibi = None + if self.config.alibi: + # embed positions + mask = ( + attention_mask + if attention_mask is not None + else paddle.ones((batch_size, seq_length), dtype=paddle.bool) + ) + alibi = build_alibi_tensor(mask, self.config.num_attention_heads, dtype=input_embeds.dtype) + + if 
self.config.tensor_parallel_degree > 1: + block_size = self.config.num_attention_heads // self.config.tensor_parallel_degree + alibi = alibi[ + :, + self.config.tensor_parallel_rank + * block_size : (self.config.tensor_parallel_rank + 1) + * block_size, + ] + alibi = alibi.reshape([batch_size * block_size, 1, seq_length]) + else: + alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length]) + alibi.stop_gradient = True + + if attention_mask is not None: + attention_mask = GemmaModel._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), 0, input_embeds.dtype + ) + attention_mask.stop_gradient = True + + if self.config.alibi and attention_mask is None: + attention_mask = GemmaModel._prepare_decoder_attention_mask( + None, (batch_size, seq_length), 0, input_embeds.dtype + ) + attention_mask.stop_gradient = True + + hidden_states = input_embeds * (self.config.hidden_size**0.5) + return return_args(hidden_states, attention_mask, position_ids, alibi) + + +class GemmaDecoderLayerPipe(GemmaDecoderLayer): + def forward(self, args): + hidden_states, attention_mask, position_ids, alibi = parse_args(args) + # we can't distinguish + # hidden_states, attention_mask, position_ids or + # hidden_states, attention_mask, alibi + if self.config.alibi and alibi is None and position_ids is not None: + alibi = position_ids + position_ids = None + + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + if attention_mask is not None or alibi is not None: + hidden_states = recompute( + super().forward, hidden_states, attention_mask=attention_mask, alibi=alibi, use_reentrant=False + ) + else: + # for pretrain + hidden_states = recompute( + super().forward, hidden_states, use_reentrant=self.config.recompute_use_reentrant + ) + else: + hidden_states = super().forward(hidden_states, attention_mask=attention_mask, alibi=alibi) + + return return_args(hidden_states, attention_mask, position_ids, alibi) + + +class GemmaRMSNormPipe(nn.Layer): + def __init__(self, config): + super().__init__() + self.norm = GemmaRMSNorm(config) + + def forward(self, args): + hidden_states, attention_mask, position_ids, alibi = parse_args(args) + return self.norm(hidden_states) + + +class GemmaLMHeadPipe(GemmaLMHead): + def __init__(self, config): + super(GemmaLMHeadPipe, self).__init__(config) + + @property + def embedding_weight(self): + return get_attr(self, "weight") + + +class GemmaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): + """GemmaForPretraining adapted for pipeline parallelism. + + The largest change is flattening the GemmaModel class so we can express it as a + sequence of layers including embedding, transformer layers, and output. + """ + + config_class = GemmaConfig + + _get_tensor_parallel_mappings = GemmaPretrainedModel._get_tensor_parallel_mappings + _init_weights = GemmaPretrainedModel._init_weights + _keys_to_ignore_on_load_unexpected = GemmaPretrainedModel._keys_to_ignore_on_load_unexpected + + # DONOT Add base_model_prefix !!!! 
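+    # The embedding and the LM head are both registered through SharedLayerDesc with the same
+    # key, so the embedding weight is tied and shared across the first and last pipeline stages;
+    # each decoder layer becomes one LayerDesc and is segmented over ranks via seg_method.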
+ + def __init__(self, config): + self.config = config + + self.use_recompute = self.config.use_recompute + self.recompute_granularity = self.config.recompute_granularity + self.pp_recompute_interval = self.config.pp_recompute_interval + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + if self.recompute_granularity == "full": + assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support" + + virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + + def get_hcg(): + return fleet.get_hybrid_communicate_group() + + hcg = get_hcg() + tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1) + tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0) + + # TODO: fix tensor_parallel_degree rewrite in here + config.tensor_parallel_degree = tensor_parallel_degree + config.tensor_parallel_rank = tensor_parallel_rank + + self.add_sequential_layer( + SharedLayerDesc( + key="gemma_weigt_share", + layer_func=GemmaEmbeddingPipe, + shared_weight_attr="embedding_weight", + config=config, + ), + "gemma", + ) + for i in range(config.num_hidden_layers): + self.add_sequential_layer( + LayerDesc(GemmaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers), + f"gemma.layers.{i}", + ) + + self.add_sequential_layer(LayerDesc(GemmaRMSNormPipe, config=config), "gemma") + self.add_sequential_layer( + SharedLayerDesc( + key="gemma_weigt_share", + layer_func=GemmaLMHeadPipe, + shared_weight_attr="embedding_weight", + config=config, + ), + "lm_head", + ) + + recompute_interval = 0 + + seg_method = "layer:GemmaDecoderLayer" + if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0: + seg_method = "uniform" + + PipelineLayer.__init__( + self, + layers=self.get_sequential_layers(), + loss_fn=GemmaPretrainingCriterion(config), + topology=get_hcg().topology(), + seg_method=seg_method, + recompute_interval=recompute_interval, + recompute_ctx={ + "mp_group": get_hcg().get_model_parallel_group(), + "offload": False, + "partition": False, + }, + num_virtual_pipeline_stages=virtual_pp_degree, + ) + self.apply(self._init_weights) + # DON'T init PipelinePretrainedModel + # PipelinePretrainedModel.__init__(self.super(), config=config) diff --git a/paddlenlp/transformers/gemma/tokenizer.py b/paddlenlp/transformers/gemma/tokenizer.py new file mode 100644 index 000000000000..bf0804bbe236 --- /dev/null +++ b/paddlenlp/transformers/gemma/tokenizer.py @@ -0,0 +1,360 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import sentencepiece as spm + +from ...utils.log import logger +from .. 
import PretrainedTokenizer +from ..tokenizer_utils_base import ( + AddedToken, + BatchEncoding, + EncodedInput, + PaddingStrategy, +) + +__all__ = ["GemmaTokenizer"] + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +SPIECE_UNDERLINE = "▁" + + +class GemmaTokenizer(PretrainedTokenizer): + model_input_names = ["input_ids", "attention_mask"] + resource_files_names = VOCAB_FILES_NAMES + pretrained_resource_files_map = { + "vocab_file": { + "google/gemma-7b": "https://bj.bcebos.com/paddlenlp/models/community/google/gemma-7b/tokenizer.model", + "google/gemma-2b": "https://bj.bcebos.com/paddlenlp/models/community/google/gemma-2b/tokenizer.model", + }, + } + + pretrained_init_configuration = { + "google/gemma-7b": {}, + } + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + use_default_system_prompt=False, + spaces_between_special_tokens=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, normalized=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False) if isinstance(pad_token, str) else pad_token + + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + # Copied from transformers.models.llama.tokenizer_llama.LlamaTokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.vocab_size + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. The Gemma tokenizer never adds a prefix space. 
+ """ + return self.sp_model.encode(text, out_type=str) + + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = False, + **kwargs, + ) -> str: + sub_texts = [] + current_sub_text = [] + for ids in token_ids: + if skip_special_tokens and ids in self.all_special_ids: + continue + if ids in self.added_tokens_decoder: + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + cur_id = self.added_tokens_decoder[ids] + if isinstance(cur_id, AddedToken): + sub_texts.append(cur_id.content) + elif isinstance(cur_id, str): + sub_texts.append(cur_id) + current_sub_text = [] + elif ids in self.all_special_ids: + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + sub_texts.append(self._convert_id_to_token(ids)) + current_sub_text = [] + else: + current_sub_text.append(ids) + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + + if spaces_between_special_tokens: + sub_texts = " ".join(sub_texts) + else: + sub_texts = "".join(sub_texts) + + return sub_texts + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.added_tokens_encoder: + out_string += self.sp_model.decode(current_sub_tokens) + token + current_sub_tokens = [] + elif token in self.all_special_tokens: + out_string += self.sp_model.decode(current_sub_tokens) + token + else: + current_sub_tokens.append(token) + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. 
+ """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
+ """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + For Zero Padding, Copied from llama + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + + # attention_mask shape [1,seq_len,seq_len] + if "attention_mask" in encoded_inputs and len(np.shape(encoded_inputs["attention_mask"])) > 2: + attention_mask = encoded_inputs["attention_mask"] + encoded_inputs.pop("attention_mask") + else: + attention_mask = None + + required_input = encoded_inputs[self.model_input_names[0]] + encoded_inputs = super()._pad( + encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask + ) + if attention_mask is not None and len(np.shape(attention_mask)) > 2: + encoded_inputs["attention_mask"] = attention_mask + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + if needs_to_be_padded: + difference = max_length - len(required_input) + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = np.pad( + encoded_inputs["attention_mask"], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode="constant", + constant_values=0, + ) + return encoded_inputs diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index 1ddd7e1c2913..edc1bb2d3439 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -2202,10 +2202,18 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): dtype == "float16" or dtype == "bfloat16" ) - if is_sharded: - loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] - else: + if state_dict is not None: loaded_state_dict_keys = [k for k in state_dict.keys()] + # will only support load paddle.Tensor to model. 
+ for k in list(state_dict.keys()): + if not isinstance(state_dict[k], paddle.Tensor): + with device_guard(): + state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True) + else: + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + loaded_state_dict_keys = [k for k in state_dict.keys()] if low_cpu_mem_usage: # or use_keep_in_fp32_modules: state_dict = None diff --git a/paddlenlp/transformers/tokenizer_utils_base.py b/paddlenlp/transformers/tokenizer_utils_base.py index eeb99117a6d3..77cf08a1b5d2 100644 --- a/paddlenlp/transformers/tokenizer_utils_base.py +++ b/paddlenlp/transformers/tokenizer_utils_base.py @@ -56,6 +56,7 @@ class AddedToken: lstrip: bool = False rstrip: bool = False normalized: bool = True + special: bool = True def __getstate__(self): return self.__dict__ diff --git a/tests/transformers/gemma/__init__.py b/tests/transformers/gemma/__init__.py new file mode 100644 index 000000000000..fd05a9208165 --- /dev/null +++ b/tests/transformers/gemma/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/transformers/gemma/test_modeling.py b/tests/transformers/gemma/test_modeling.py new file mode 100644 index 000000000000..d187d92696df --- /dev/null +++ b/tests/transformers/gemma/test_modeling.py @@ -0,0 +1,378 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
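+# Unit tests for Gemma modeling: a deliberately small GemmaConfig is run through the common
+# model/generation test mixins, plus Gemma-specific checks for attention-mask shapes, cached
+# decoding, position_ids handling and the GQA/MQA head configurations.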
+from __future__ import annotations + +import unittest + +import paddle + +from paddlenlp.transformers import GemmaConfig, GemmaForCausalLM, GemmaModel +from tests.transformers.test_configuration_common import ConfigTester +from tests.transformers.test_generation_utils import GenerationTesterMixin +from tests.transformers.test_modeling_common import ( + ModelTesterMixin, + ids_tensor, + random_attention_mask, +) + + +class GemmaModelTester: + def __init__( + self, + parent, + vocab_size=256000, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=16, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + is_training=True, + use_cache=False, + bos_token_id=2, + eos_token_id=1, + pad_token_id=0, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.0, + attention_dropout=0.0, + attention_softmax_in_fp32=True, + pretraining_tp=1, + dtype="float16", + slow_but_exact=False, + batch_size: int = 2, + seq_length: int = 10, + type_sequence_label_size=2, + activation_function="gelu", + num_labels=3, + num_choices=4, + scope=None, + dropout=0.00, + use_input_mask: bool = False, + use_labels: bool = False, + return_dict=False, + ): + self.parent: GemmaModelTest = parent + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.is_training = is_training + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.pretraining_tp = pretraining_tp + self.dtype = dtype + self.slow_but_exact = slow_but_exact + + self.batch_size = batch_size + self.seq_length = seq_length + self.type_sequence_label_size = type_sequence_label_size + self.activation_function = activation_function + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.dropout = dropout + + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.return_dict = return_dict + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, dtype=paddle.int64) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self) -> GemmaConfig: + return GemmaConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + layer_norm_epsilon=self.layer_norm_epsilon, + initializer_range=self.initializer_range, + use_cache=self.use_cache, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, + 
hidden_dropout=self.hidden_dropout, + attention_dropout=self.attention_dropout, + attention_softmax_in_fp32=self.attention_softmax_in_fp32, + pretraining_tp=self.pretraining_tp, + dtype=self.dtype, + slow_but_exact=self.slow_but_exact, + activation_function=self.activation_function, + ) + + def create_and_check_model( + self, config: GemmaConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = GemmaModel(config) + model.eval() + result = model(input_ids) + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size]) + + def create_and_check_model_attention_mask( + self, config: GemmaConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = GemmaModel(config) + model.eval() + attn_mask_2d = random_attention_mask([self.batch_size, self.seq_length]) + result_2d = model(input_ids, attention_mask=attn_mask_2d)[0] + batch, seq_length = input_ids.shape + causal_mask = paddle.tril(paddle.ones((batch, seq_length, seq_length), dtype=attn_mask_2d.dtype)) + attn_mask_3d = causal_mask & attn_mask_2d.unsqueeze(-1) + result_3d = model(input_ids, attention_mask=attn_mask_3d)[0] + attn_mask_4d = attn_mask_3d.unsqueeze(1) + result_4d = model(input_ids, attention_mask=attn_mask_4d)[0] + result_no_attention_mask = model(input_ids, attention_mask=None)[0] + # Assert non-padding tokens have the same logits with different attention_mask shape + self.parent.assertTrue((result_2d[attn_mask_2d] == result_3d[attn_mask_2d]).all()) + self.parent.assertTrue((result_2d[attn_mask_2d] == result_4d[attn_mask_2d]).all()) + self.parent.assertTrue((result_2d[attn_mask_2d] == result_no_attention_mask[attn_mask_2d]).all()) + + def create_and_check_model_past_large_inputs( + self, + config: GemmaConfig, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = GemmaModel(config) + model.eval() + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True, return_dict=self.return_dict) + past_key_values = outputs.past_key_values if self.return_dict else outputs[2] + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), self.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = paddle.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = paddle.concat([input_mask, next_mask], axis=-1) + + outputs = model( + next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True, return_dict=self.return_dict + ) + + output_from_no_past = outputs[2][0] + + outputs = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + return_dict=self.return_dict, + ) + + output_from_past = outputs[2][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(paddle.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + 
token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args): + model = GemmaForCausalLM(config) + model.eval() + + result = model( + input_ids, + use_cache=True, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertIsInstance(result[0].item(), float) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length, self.vocab_size]) + else: + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size]) + + def check_model_position_ids(self, config, input_ids, input_mask, *args): + model = GemmaForCausalLM(config) + model.eval() + + result_no_position_id = model( + input_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + batch_size, seq_len = input_ids.shape + position_ids = paddle.arange(seq_len).expand((batch_size, seq_len)) + result_position_id = model( + input_ids, + position_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertTrue((result_position_id[1] == result_no_position_id[1]).all()) + else: + self.parent.assertTrue((result_position_id[0] == result_no_position_id[0]).all()) + + def check_model_position_ids_alibi(self, config, input_ids, input_mask, *args): + config.alibi = True + model = GemmaForCausalLM(config) + model.eval() + + result_no_position_id = model( + input_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + batch_size, seq_len = input_ids.shape + position_ids = paddle.arange(seq_len).expand((batch_size, seq_len)) + result_position_id = model( + input_ids, + position_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertTrue((result_position_id[1] == result_no_position_id[1]).all()) + else: + self.parent.assertTrue((result_position_id[0] == result_no_position_id[0]).all()) + + def create_and_check_gqa_model(self, config, input_ids, input_mask, *args): + model = GemmaForCausalLM(config) + config.num_key_value_heads = 8 # gqa + config.use_fused_rope = True + model.eval() + + result = model( + input_ids, + use_cache=True, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertIsInstance(result[0].item(), float) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length, self.vocab_size]) + else: + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size]) + + def create_and_check_mqa_model(self, config, input_ids, input_mask, *args): + model = GemmaForCausalLM(config) + config.num_key_value_heads = 1 # mqa for gemma-2b + config.use_fused_rope = True + model.eval() + result = model(input_ids) + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size]) + + +class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + base_model_class = GemmaModel + return_dict = False + use_labels = False + + all_model_classes = (GemmaModel, GemmaForCausalLM) + all_generative_model_classes = {GemmaForCausalLM: (GemmaModel, "gemma")} + + def setUp(self): + super().setUp() + + 
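+        # GemmaModelTester uses a tiny config (hidden_size=64, two layers) so these tests stay fast.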
self.model_tester = GemmaModelTester(self) + self.config_tester = ConfigTester(self, config_class=GemmaConfig, vocab_size=256, hidden_size=24) + + def _get_input_ids_and_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + input_ids = inputs_dict[self.input_name] + attention_mask = paddle.ones_like(input_ids, dtype=paddle.int64) + + max_batch_size = 2 + sequence_length = input_ids.shape[-1] // 2 + input_ids = input_ids[:max_batch_size, :sequence_length] + attention_mask = attention_mask[:max_batch_size, :sequence_length] + max_length = 3 + + return config, input_ids, attention_mask, max_length + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_attention_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_attention_mask(*config_and_inputs) + + def test_model_position_ids(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_model_position_ids(*config_and_inputs) + self.model_tester.check_model_position_ids_alibi(*config_and_inputs) + + def test_generate_without_input_ids(self): + # this requires 4-D attention mask logic, which is not supported yet + pass + + def test_gemma_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_lm_head_model(*config_and_inputs) + + def test_gemma_gqa_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gqa_model(*config_and_inputs) + + def test_gemma_mqa_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mqa_model(*config_and_inputs) + + def test_model_name_list(self): + # no need for this + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transformers/gemma/test_tokenizer.py b/tests/transformers/gemma/test_tokenizer.py new file mode 100644 index 000000000000..e8527c40ee4b --- /dev/null +++ b/tests/transformers/gemma/test_tokenizer.py @@ -0,0 +1,225 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
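+# Tests for the sentencepiece-based GemmaTokenizer: tokenization round trips, adding regular
+# and special tokens, padding/truncation behaviour and bos-token handling, reusing the common
+# TokenizerTesterMixin where applicable.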
+ +import unittest + +from paddlenlp.transformers.gemma.tokenizer import GemmaTokenizer +from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer + +from ..test_tokenizer_common import TokenizerTesterMixin + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + + +class GemmaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = GemmaTokenizer + test_decode_token = True + + def get_tokenizer(self, **kwargs) -> PretrainedTokenizer: + tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", **kwargs) + return tokenizer + + def get_input_output_texts(self, tokenizer): + input_text = "lower newer" + output_text = "lower newer" + return input_text, output_text + + def test_full_tokenizer(self): + tokenizer = self.get_tokenizer() + text = "lower newer" + bpe_tokens = ["lower", "▁newer"] + tokens = tokenizer.tokenize(text, add_prefix_space=True) + self.assertListEqual(tokens, bpe_tokens) + input_tokens = tokens + [tokenizer.unk_token] + input_bpe_tokens = [15964, 36649, 3] + self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + + def test_pretokenized_inputs(self, *args, **kwargs): + pass + + def test_tokenizers_common_ids_setters(self, *args, **kwargs): + pass + + def test_mask_output(self): + pass + + def test_offsets_mapping(self): + pass + + def test_offsets_mapping_with_unk(self): + pass + + def test_special_tokens_mask(self): + pass + + def test_special_tokens_mask_input_pairs(self): + pass + + def test_padding_side_in_kwargs(self): + tokenizer = self.get_tokenizer(padding_side="left") + self.assertEqual(tokenizer.padding_side, "left") + + tokenizer = self.get_tokenizer(padding_side="right") + self.assertEqual(tokenizer.padding_side, "right") + + def test_truncation_side_in_kwargs(self): + tokenizer = self.get_tokenizer(truncation_side="left") + self.assertEqual(tokenizer.truncation_side, "left") + + tokenizer = self.get_tokenizer(truncation_side="right") + self.assertEqual(tokenizer.truncation_side, "right") + + def test_add_tokens(self): + tokenizer = self.get_tokenizer() + + vocab_size = len(tokenizer) + self.assertEqual(tokenizer.add_tokens(""), 0) + self.assertEqual(tokenizer.add_tokens("testoken"), 1) + self.assertEqual(tokenizer.add_tokens(["testoken1", "testtoken2"]), 2) + self.assertEqual(len(tokenizer), vocab_size + 3) + + self.assertEqual(tokenizer.add_special_tokens({}), 0) + self.assertRaises(AssertionError, tokenizer.add_special_tokens, {"additional_special_tokens": ""}) + self.assertEqual(tokenizer.add_special_tokens({"additional_special_tokens": [""]}), 1) + self.assertEqual( + tokenizer.add_special_tokens({"additional_special_tokens": ["", ""]}), 2 + ) + self.assertIn("", tokenizer.special_tokens_map["additional_special_tokens"]) + self.assertIsInstance(tokenizer.special_tokens_map["additional_special_tokens"], list) + self.assertGreaterEqual(len(tokenizer.special_tokens_map["additional_special_tokens"]), 2) + + self.assertEqual(len(tokenizer), vocab_size + 6) + + def test_add_tokens_tokenizer(self): + tokenizer = self.get_tokenizer() + + vocab_size = tokenizer.vocab_size + all_size = len(tokenizer) + + self.assertNotEqual(vocab_size, 0) + + new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"] + added_toks = tokenizer.add_tokens(new_toks) + vocab_size_2 = tokenizer.vocab_size + all_size_2 = len(tokenizer) + + self.assertNotEqual(vocab_size_2, 0) + self.assertEqual(vocab_size, vocab_size_2) + self.assertEqual(added_toks, len(new_toks)) + self.assertEqual(all_size_2, all_size + len(new_toks)) + + 
tokens = tokenizer.encode( + "aaaaa bbbbbb low cccccccccdddddddd l", return_token_type_ids=None, add_special_tokens=False + )["input_ids"] + self.assertGreaterEqual(len(tokens), 4) + self.assertGreater(tokens[0], tokenizer.vocab_size - 1) + self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + + def test_consecutive_unk_string(self): + tokenizer = self.get_tokenizer(add_bos_token=False) + + tokens = [tokenizer.unk_token for _ in range(2)] + string = tokenizer.convert_tokens_to_string(tokens) + encoding = tokenizer( + text=string, + runcation=True, + return_offsets_mapping=True, + ) + self.assertEqual(len(encoding["input_ids"]), 2) + self.assertEqual(len(encoding["offset_mapping"]), 2) + + def test_padding_if_pad_token_set_slow(self): + tokenizer = self.get_tokenizer() + + # Simple input + s = "This is a simple input" + s2 = ["This is a simple input looooooooong", "This is a simple input"] + p = ("This is a simple input", "This is a pair") + + pad_token_id = tokenizer.pad_token_id + + out_s = tokenizer(s, padding="max_length", max_length=30, return_tensors="np", return_attention_mask=True) + out_s2 = tokenizer(s2, padding=True, truncate=True, return_tensors="np", return_attention_mask=True) + out_p = tokenizer(*p, padding="max_length", max_length=60, return_tensors="np", return_attention_mask=True) + + # s + # test single string max_length padding + + self.assertEqual(out_s["input_ids"].shape[-1], 30) + self.assertTrue(pad_token_id in out_s["input_ids"]) + self.assertTrue(0 in out_s["attention_mask"]) + + # s2 + # test automatic padding + self.assertEqual(out_s2["input_ids"].shape[-1], 9) + # long slice doesn't have padding + self.assertFalse(pad_token_id in out_s2["input_ids"][0]) + self.assertFalse(0 in out_s2["attention_mask"][0]) + # short slice does have padding + self.assertTrue(pad_token_id in out_s2["input_ids"][1]) + self.assertTrue(0 in out_s2["attention_mask"][1]) + + # p + # test single pair max_length padding + self.assertEqual(out_p["input_ids"].shape[-1], 60) + self.assertTrue(pad_token_id in out_p["input_ids"]) + self.assertTrue(0 in out_p["attention_mask"]) + + def test_add_bos_token_slow(self): + tokenizer = self.get_tokenizer() + + s = "This is a simple input" + s2 = ["This is a simple input 1", "This is a simple input 2"] + + bos_token_id = tokenizer.bos_token_id + + out_s = tokenizer(s, add_special_tokens=True) + out_s2 = tokenizer(s2, add_special_tokens=True) + + self.assertEqual(out_s.input_ids[0], bos_token_id) + self.assertTrue(all(o[0] == bos_token_id for o in out_s2["input_ids"])) + + def test_pretrained_model_lists(self): + # No max_model_input_sizes + self.assertGreaterEqual(len(self.tokenizer_class.pretrained_resource_files_map), 1) + self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_resource_files_map.values())[0]), 1) + + def test_add_special_tokens(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + input_text, ids = "A", [235280] + + special_token = "[SPECIAL_TOKEN]" + + tokenizer.add_special_tokens({"cls_token": special_token}) + encoded_special_token = tokenizer.encode( + special_token, return_token_type_ids=None, add_special_tokens=False + )["input_ids"] + self.assertEqual(len(encoded_special_token), 1) + + text = tokenizer.decode(ids + encoded_special_token, clean_up_tokenization_spaces=False) + encoded = tokenizer.encode(text, return_token_type_ids=None, add_special_tokens=False)["input_ids"] + + input_encoded = 
tokenizer.encode(input_text, return_token_type_ids=None, add_special_tokens=False)[ + "input_ids" + ] + special_token_id = tokenizer.encode( + special_token, return_token_type_ids=None, add_special_tokens=False + )["input_ids"] + self.assertEqual(encoded, input_encoded + special_token_id) + decoded = tokenizer.decode(encoded, skip_special_tokens=True) + self.assertTrue(special_token not in decoded) diff --git a/tests/transformers/test_tensor_parallel.py b/tests/transformers/test_tensor_parallel.py index e3f02d20e242..def064a225f3 100644 --- a/tests/transformers/test_tensor_parallel.py +++ b/tests/transformers/test_tensor_parallel.py @@ -118,6 +118,15 @@ def _test_bloom(): common_test_merge(model, BloomForCausalLM) +def _test_gemma(): + from paddlenlp.transformers import GemmaConfig, GemmaForCausalLM + + config = GemmaConfig() + config = prepare_config(config) + model = GemmaForCausalLM.from_config(config) + common_test_merge(model, GemmaForCausalLM) + + # _test_llama() # _test_chatglm() # _test_bloom() @@ -129,3 +138,4 @@ def test_model_load_merge(self): _test_llama() _test_chatglm() _test_bloom() + _test_gemma()
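As a quick reference, the snippet below sketches how the pieces added in this change are expected to fit together. It is illustrative only: it assumes the google/gemma-2b weights and tokenizer.model referenced above are downloadable, and it goes through the generic AutoTokenizer/AutoModelForCausalLM registries extended in auto/tokenizer.py and auto/modeling.py rather than any Gemma-specific entry point. The exact return format of generate can vary between PaddleNLP releases.

import paddle
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

# Resolve GemmaTokenizer / GemmaForCausalLM through the Auto registries extended in this change.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", dtype="float16")
model.eval()

inputs = tokenizer("Why is the sky blue?", return_tensors="pd")
with paddle.no_grad():
    # PaddleNLP generate conventionally returns (generated_ids, scores) for the new tokens only.
    ids, _ = model.generate(**inputs, max_length=64)
print(tokenizer.decode(ids[0].tolist(), skip_special_tokens=True))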