diff --git a/csrc/generation/flash_attn_bwd.cc b/csrc/generation/flash_attn_bwd.cc new file mode 100644 index 000000000000..3acd55cbdbe9 --- /dev/null +++ b/csrc/generation/flash_attn_bwd.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" +#include <iostream> +#include <vector> + +using paddle::Tensor; + +namespace paddle { +namespace experimental { + +PADDLE_API void flash_attn_grad(const Tensor& q, + const Tensor& k, + const Tensor& v, + const Tensor& out, + const Tensor& softmax_lse, + const Tensor& seed_offset, + const paddle::optional<paddle::Tensor> &attn_mask, + const Tensor& out_grad, + float dropout, + bool causal, Tensor* q_grad, Tensor* k_grad, Tensor* v_grad); + +} +} // namespace paddle + + + +std::vector<paddle::Tensor> SRFlashAttnBwd(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &out, + const Tensor &softmax_lse, + const Tensor &seed_offset, + const paddle::optional<paddle::Tensor> &attn_mask, + const Tensor &out_grad, + float dropout, + bool causal); + + +std::vector<paddle::Tensor> SRFlashAttnBwd(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &out, + const Tensor &softmax_lse, + const Tensor &seed_offset, + const paddle::optional<paddle::Tensor> &attn_mask, + const Tensor &out_grad, + float dropout, + bool causal){ + std::vector<paddle::Tensor> res(3); + paddle::experimental::flash_attn_grad(q, k, v, out, softmax_lse, seed_offset, attn_mask, + out_grad, dropout, causal, &res[0], &res[1], + &res[2]); + return res; +} + + + +std::vector<paddle::DataType> SRFlashAttnBwdDtype(paddle::DataType q_dtype, + paddle::DataType k_dtype, + paddle::DataType v_dtype) { + return {q_dtype, k_dtype, v_dtype}; + +} + + +std::vector<std::vector<int64_t>> SRFlashAttnBwdInferShape( + std::vector<int64_t> q_shape, std::vector<int64_t> k_shape, + std::vector<int64_t> v_shape) { + return {q_shape, k_shape, v_shape}; +} + + +PD_BUILD_OP(flash_attn_bwd) + .Inputs({"q", "k", "v", "out", "softmax_lse", "seed_offset", "attn_mask", "out_grad"}) + .Outputs({"q_grad", "k_grad", "v_grad"}) + .Attrs({"dropout: float", "causal: bool"}) + .SetKernelFn(PD_KERNEL(SRFlashAttnBwd)) + .SetInferShapeFn(PD_INFER_SHAPE(SRFlashAttnBwdInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SRFlashAttnBwdDtype)); diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index 0b25ef3eac98..dc0ba9895027 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -77,6 +77,7 @@ def get_gencode_flags(): "./generation/step.cu", "./generation/quant_int8.cu", "./generation/dequant_int8.cu", + "./generation/flash_attn_bwd.cc", ], extra_compile_args={ "cxx": ["-O3"], diff --git a/docs/trainer.md b/docs/trainer.md index a1dde0af4f94..55df827bb3d7 100644 --- a/docs/trainer.md +++ b/docs/trainer.md @@ -576,7 +576,15 @@ Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并 following config is support: enable_allreduce_avg_in_gradinent_scale, it replace `allreduce_sum + scale` pattern with `allreduce_avg` when scale gradient in data_parallel, which improve the performance. ONLY supported for auto mode now.
gradient_sync_after_accumulate, move gradient sync operations from backward into optimizer step when gradient accumulate enabling, which reduce the sync times to improve performance, but will increase the memory usage. ONLY supported for auto mode now. - + --context_parallel_degree + 上下文并行是将训练数据在序列维度进行切分的并行方法。 + 该方法使用Ring FlashAttention来保障切分后Attention结果的正确性。通过环状通信和迭代更新来得到完整的注意力分数。 + 默认值-1, 表示不启用上下文并行, + (`int`, 可选, 默认为 `-1`) + (注: 该方法需要修改模型结构, 目前支持LLAMA) + (注: 该方法对通信开销较大, 建议只有在序列长度超长时, 如1024k, 时才使用) + Context parallelism is a parallel method that segments training data in the sequence dimension. + This method uses Ring FlashAttention to ensure the correctness of the Attention result after segmentation. The complete attention score is obtained through ring communication and iterative updates. --recompute 是否使用重计算训练。可以节省显存。 重新计算前向过程以获取梯度,减少中间变量显存. diff --git a/llm/llama/run_trainer_tp2cp2.sh b/llm/llama/run_trainer_tp2cp2.sh new file mode 100644 index 000000000000..1a684191deea --- /dev/null +++ b/llm/llama/run_trainer_tp2cp2.sh @@ -0,0 +1,84 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set -x +unset CUDA_VISIBLE_DEVICES + +rm -rf log +rm -rf output + +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +# export FLAGS_embedding_deterministic=1 +# export FLAGS_cudnn_deterministic=1 +# export FLAGS_flash_attn_version=v1 +# export USE_FAST_LN=0 + + +max_seq_length=1024 + +max_steps=1000 +log_dir=seq_${max_seq_length}_log +echo "log_dir:${log_dir}" +rm -rf $log_dir + +export PYTHONPATH=../../:$PYTHONPATH +python -u -m paddle.distributed.launch \ + --gpus "3,4,5,7" \ + --log_dir "./$log_dir" \ + run_pretrain.py \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir "./output" \ + --split 949,50,1 \ + --max_seq_length $max_seq_length \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --per_device_eval_batch_size 4 \ + --bf16 \ + --fp16_opt_level "O2" \ + --use_flash_attention 1 \ + --virtual_pp_degree 1 \ + --pp_recompute_interval 1 \ + --learning_rate 0.00001 \ + --min_learning_rate 0.000001 \ + --max_steps $max_steps \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --max_grad_norm 1.0 \ + --logging_steps 1 \ + --dataloader_num_workers 1 \ + --eval_steps 1001 \ + --disable_tqdm true \ + --continue_training 0 \ + --do_train \ + --device "gpu" \ + --enable_linear_fused_grad_add false \ + --recompute_use_reentrant true \ + --data_cache "./data_cache" \ + --pipeline_parallel_degree 1 \ + --context_parallel_degree 2 \ + --tensor_parallel_degree 2 \ + --sequence_parallel false \ + --skip_profile_timer true \ + --amp_master_grad \ + --report_to "visualdl" \ + --logging_dir "./visualdl_log" \ + --save_steps 2000000 \ diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index e58888772a5d..0f0b3122baae 100644 --- 
a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -485,11 +485,15 @@ def main(): config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob config.sep_parallel_degree = training_args.sep_parallel_degree + config.context_parallel_degree = training_args.context_parallel_degree if config.sequence_parallel: assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel." assert ( config.num_attention_heads % config.sep_parallel_degree == 0 ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}" + assert ( + config.seq_length % config.context_parallel_degree == 0 + ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}" if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: try: diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 01b902478622..7d5cc4a5ffc1 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -81,6 +81,7 @@ from ..quantization.quantization_linear import QuantizationLinear except: QuantizationLinear = None +from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance from ..transformers.model_utils import ( PretrainedModel, _add_variant, @@ -763,6 +764,8 @@ def train( trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size if self.args.sep_parallel_degree > 0: trainable_numel = trainable_numel // self.args.sep_parallel_degree + if self.args.context_parallel_degree > 0: + trainable_numel = trainable_numel // self.args.context_parallel_degree # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited # so, the trainable numel is a little bigger than real. 
logger.debug(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") @@ -897,6 +900,8 @@ def _inner_training_loop( for step, inputs in enumerate(epoch_iterator): if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1: inputs = split_inputs_sequence_dim(inputs) + if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1: + inputs = split_inputs_sequence_dim_load_balance(inputs) self.timers and self.timers("read-data").stop() os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) @@ -1760,6 +1765,7 @@ def _wrap_model(self, model, training=True): in_sharding_parallel_mode = self.sharding is not None in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1 in_sep_parallel_mode = self.args.sep_parallel_degree > 1 + in_cp_parallel_mode = self.args.context_parallel_degree > 1 # Multi-gpu training if ( @@ -1770,6 +1776,7 @@ def _wrap_model(self, model, training=True): or in_sharding_parallel_mode or in_tensor_parallel_mode or in_sep_parallel_mode + or in_cp_parallel_mode ) ): model = paddle.DataParallel(model) @@ -1897,7 +1904,7 @@ def get_expected_keys(inputs, keys): if ( not in_pipeline_parallel_mode and not in_sharding_parallel_mode - and (in_tensor_parallel_mode or in_sep_parallel_mode) + and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode) ): if self.args.amp_master_grad: mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 7b792ad34e7a..2aa77bdeefbb 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -230,6 +230,10 @@ class TrainingArguments: The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. ) + context_parallel_degree (`int`, *optional*, defaults to `-1`)( + Context parallelism is a parallel method that segments training data in the sequence dimension. + This method uses Ring FlashAttention to ensure the correctness of the Attention result after segmentation. The complete attention score is obtained through ring communication and iterative updates. + ) data_parallel_config (`str`, *optional*)( Some additional configs which affect data parallel performance, we provide some option to config it. following config is support: @@ -583,6 +587,15 @@ class TrainingArguments: ) }, ) + context_parallel_degree: int = field( + default=-1, + metadata={ + "help": ( + "The paddle context parallel strategy. It can reduce the GPU memory of activation to 1/cp, and it is orthogonal to " + "data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. " + ) + }, + ) data_parallel_config: str = field( default="", metadata={ @@ -918,16 +931,24 @@ def __post_init__(self): if world_size > 1: tensor_parallel_degree = max(self.tensor_parallel_degree, 1) sep_parallel_degree = max(self.sep_parallel_degree, 1) + context_parallel_degree = max(self.context_parallel_degree, 1) pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) assert ( world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0 ), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}." 
+ assert not ( + sep_parallel_degree > 1 and context_parallel_degree > 1 + ), f"sep parallel and context parallel cannot be used together, sep_parallel_degree:{sep_parallel_degree}, context_parallel_degree:{context_parallel_degree}." + if self.sharding_parallel_degree == -1: if len(self.sharding) > 0: self.sharding_parallel_degree = world_size // ( - tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree + tensor_parallel_degree + * sep_parallel_degree + * context_parallel_degree + * pipeline_parallel_degree ) sharding_parallel_degree = max(self.sharding_parallel_degree, 1) @@ -936,7 +957,11 @@ def __post_init__(self): self.sharding = [] self.data_parallel_degree = world_size // ( - sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree + sharding_parallel_degree + * tensor_parallel_degree + * sep_parallel_degree + * context_parallel_degree + * pipeline_parallel_degree ) if ( @@ -944,12 +969,14 @@ def __post_init__(self): or tensor_parallel_degree > 1 or pipeline_parallel_degree > 1 or self.sep_parallel_degree > 1 + or self.context_parallel_degree > 1 ): self.use_hybrid_parallel = True self.sharding_parallel_degree = sharding_parallel_degree self.tensor_parallel_degree = tensor_parallel_degree self.pipeline_parallel_degree = pipeline_parallel_degree self.sep_parallel_degree = sep_parallel_degree + self.context_parallel_degree = context_parallel_degree if not self.use_hybrid_parallel: self.sharding = [] @@ -957,6 +984,7 @@ def __post_init__(self): self.tensor_parallel_degree = -1 self.pipeline_parallel_degree = -1 self.sep_parallel_degree = -1 + self.context_parallel_degree = -1 if self.hybrid_parallel_topo_order is None: self.hybrid_parallel_topo_order = "pp_first" @@ -1157,7 +1185,9 @@ def is_segment_parallel_supported(): "mp_degree": self.tensor_parallel_degree, "pp_degree": self.pipeline_parallel_degree, "sharding_degree": self.sharding_parallel_degree, - "sep_degree": self.sep_parallel_degree, + "sep_degree": self.sep_parallel_degree + if self.sep_parallel_degree > 1 + else self.context_parallel_degree, "order": order, } else: @@ -1241,6 +1271,7 @@ def is_segment_parallel_supported(): elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.sep_parallel_degree = max(self.sep_parallel_degree, 1) + self.context_parallel_degree = max(self.context_parallel_degree, 1) self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) assert ( @@ -1250,7 +1281,10 @@ def is_segment_parallel_supported(): if self.sharding_parallel_degree == -1: if len(self.sharding) > 0: self.sharding_parallel_degree = world_size // ( - self.tensor_parallel_degree * self.sep_parallel_degree * self.pipeline_parallel_degree + self.tensor_parallel_degree + * self.sep_parallel_degree + * self.context_parallel_degree + * self.pipeline_parallel_degree ) self.sharding_parallel_degree = max(self.sharding_parallel_degree, 1) @@ -1262,6 +1296,7 @@ def is_segment_parallel_supported(): self.sharding_parallel_degree * self.tensor_parallel_degree * self.sep_parallel_degree + * self.context_parallel_degree * self.pipeline_parallel_degree ) diff --git a/paddlenlp/transformers/configuration_utils.py b/paddlenlp/transformers/configuration_utils.py index 4bda24695a48..093ea32e3bf6 100644 --- a/paddlenlp/transformers/configuration_utils.py +++ b/paddlenlp/transformers/configuration_utils.py @@ -465,8 +465,9 @@ def __init__(self, **kwargs): # Parameters for tensor parallel self.tensor_parallel_degree = 
kwargs.pop("tensor_parallel_degree", -1) self.tensor_parallel_rank = kwargs.pop("tensor_parallel_rank", 0) - # Parameters for sep + # Parameters for sep and cp self.sep_parallel_degree = kwargs.pop("sep_parallel_degree", -1) + self.context_parallel_degree = kwargs.pop("context_parallel_degree", -1) # If set to True, this option is used with fleet.meta_parallel.ParallelCrossEntropy # to calculate cross-entropy loss for parallel model. self.tensor_parallel_output = kwargs.pop("tensor_parallel_output", False) diff --git a/paddlenlp/transformers/context_parallel_utils.py b/paddlenlp/transformers/context_parallel_utils.py new file mode 100644 index 000000000000..7f8a69352764 --- /dev/null +++ b/paddlenlp/transformers/context_parallel_utils.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import paddle +from paddle.distributed.fleet import fleet + + +def split_inputs_sequence_dim_load_balance(inputs, rank=None, degree=None): + if degree is None and rank is None: + _hcg = fleet.get_hybrid_communicate_group() + degree = _hcg.get_sep_parallel_world_size() + rank = _hcg.get_sep_parallel_rank() + assert isinstance(degree, int) and isinstance( + rank, int + ), f"degree:{type(degree)} and rank:{type(rank)} must be int" + if degree <= 1: + return inputs + + def do_split_sequence_dim_load_balance(data, rank, degree): + if data is None: + return None + assert isinstance(data, paddle.Tensor), f"data should be paddle.Tensor, but is type:{type(data)}" + assert len(data.shape) == 2, f"data dims should be 2, but shaped: {data.shape}" + sliced_datas = paddle.split(data, num_or_sections=degree * 2, axis=-1) + sliced_data0, sliced_data1 = sliced_datas[rank], sliced_datas[degree * 2 - 1 - rank] + return paddle.concat([sliced_data0, sliced_data1], axis=-1) + + if isinstance(inputs, paddle.Tensor): + return do_split_sequence_dim_load_balance(inputs, rank, degree) + elif isinstance(inputs, dict): + res = {} + for k, tensor in inputs.items(): + res[k] = do_split_sequence_dim_load_balance(tensor, rank, degree) + elif isinstance(inputs, list): + res = [] + for tensor in inputs: + res.append(do_split_sequence_dim_load_balance(tensor, rank, degree)) + else: + raise ValueError(f"the inputs should be a list or a dict, but is type: {type(inputs)}") + return res diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 6009a80911d5..182663bdbc73 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -51,12 +51,26 @@ def swiglu(x, y=None): except: flash_attention = None +from paddlenlp.transformers.ring_flash_attention import RingFlashAttention -def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb): + +def fusion_rope( + query_states, + key_states, + value_states, + hidden_states, + position_ids, + past_key_value, + rotary_emb, + context_parallel_degree=-1, +): if get_env_device() != "gcu": assert past_key_value is None, "fuse rotary not support cache kv for now" batch_size, seq_length, num_heads, head_dim = query_states.shape _, kv_seq_len, num_key_value_heads, _ = key_states.shape + if context_parallel_degree > 1: + assert get_env_device() == "gpu", "context parallel only support cuda device for now" + kv_seq_len *= context_parallel_degree if get_env_device() != "gcu": cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) if get_env_device() == "npu": @@ -142,6 +156,8 @@ def fusion_flash_attention( if version != "0.0.0" and version <= "2.5.2": if alibi is not None: raise ValueError("Flash Attention doesn't support alibi") + if config.context_parallel_degree > 1: + raise ValueError(f"Context parallel is not implemented in version {version}") attn_output, attn_weights = flash_attention( query_states, key_states, @@ -154,6 +170,8 @@ def fusion_flash_attention( alibi = alibi.reshape([bsz, num_heads, 1, -1]) attention_mask = attention_mask.cast(alibi.dtype) + alibi if get_env_device() == "npu": + if config.context_parallel_degree > 1: + raise ValueError("Context parallel is not implemented for npu") attn_output = core.eager._run_custom_op( "flash_attention_npu", query_states, @@ -168,6 +186,8 @@ def fusion_flash_attention( npu_is_casual, )[0] elif get_env_device() == "gcu": + if config.context_parallel_degree > 1: + raise ValueError("Context 
parallel is not implemented for gcu") attn_output = core.eager._run_custom_op( "fused_sdp_flash_attention_gcu", query_states, @@ -179,13 +199,22 @@ def fusion_flash_attention( True, )[0] else: - attn_output = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - is_causal=attention_mask is None, - ) + if config.context_parallel_degree > 1: + attn_output = RingFlashAttention.apply( + query_states, + key_states, + value_states, + attn_mask=None, + is_causal=True, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + ) attn_weights = None if reshard_layer is not None: diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 366f7ff3c083..cdc1a2abe845 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -233,6 +233,9 @@ def scaled_dot_product_attention( # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] else: + if config.context_parallel_degree > 1: + raise ValueError("Context parallel requires `use_flash_attention=True`") + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] query_states = paddle.transpose(query_states, [0, 2, 1, 3]) # merge with the next tranpose @@ -932,6 +935,17 @@ def forward( if self.reshard_layer is not None: batch_size, seq_length, _, _ = query_states.shape position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + if self.config.context_parallel_degree > 1: + batch_size, seq_length, _, _ = query_states.shape + group = fleet.get_hybrid_communicate_group().get_sep_parallel_group() + chunk_size = seq_length // 2 + chunk_num = group.nranks * 2 + rank = group.rank + first_chunk_ids = paddle.arange(rank * chunk_size, (rank + 1) * chunk_size, dtype="int64") + second_chunk_ids = paddle.arange( + (chunk_num - rank - 1) * chunk_size, (chunk_num - rank) * chunk_size, dtype="int64" + ) + position_ids = paddle.concat([first_chunk_ids, second_chunk_ids]).expand((batch_size, seq_length)) if self.use_fused_rope: query_states, key_states = fusion_ops.fusion_rope( query_states, @@ -941,9 +955,12 @@ def forward( position_ids, past_key_value, self.rotary_emb, + self.config.context_parallel_degree, ) else: + if self.config.context_parallel_degree > 1: + kv_seq_len *= self.config.context_parallel_degree if self.config.use_long_sequence_strategies: cos, sin = self.rotary_emb(seq_len=kv_seq_len) cos = cos[None, :, None, :] @@ -1512,6 +1529,8 @@ def forward( # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) + if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi): + raise NotImplementedError("Ring FlashAttention dosen't support attention_mask or alibi") # embed positions if attention_mask is None: # [bs, seq_len] @@ -1655,9 +1674,9 @@ def forward(self, prediction_scores, masked_lm_labels): with paddle.amp.auto_cast(False): masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) - if self.config.sep_parallel_degree > 1: + if self.config.sep_parallel_degree > 1 or self.config.context_parallel_degree > 1: _hcg = fleet.get_hybrid_communicate_group() - masked_lm_loss = ConcatSePMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group()) + masked_lm_loss = ConcatMaskedLoss.apply(masked_lm_loss, axis=1, 
group=_hcg.get_sep_parallel_group()) # skip ignore_index which loss == 0 # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] # loss = paddle.mean(masked_lm_loss) @@ -1673,7 +1692,7 @@ def forward(self, prediction_scores, masked_lm_labels): return loss -class ConcatSePMaskedLoss(PyLayer): +class ConcatMaskedLoss(PyLayer): @staticmethod def forward(ctx, inp, axis, group): inputs = [] @@ -1728,6 +1747,9 @@ def forward(self, hidden_states, tensor_parallel_output=None): if self.config.sep_parallel_degree > 1: assert seq_length % self.config.sep_parallel_degree == 0 seq_length = seq_length // self.config.sep_parallel_degree + if self.config.context_parallel_degree > 1: + assert seq_length % self.config.context_parallel_degree == 0 + seq_length = seq_length // self.config.context_parallel_degree hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size]) if tensor_parallel_output is None: diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py new file mode 100644 index 000000000000..3ff4d9def8d8 --- /dev/null +++ b/paddlenlp/transformers/ring_flash_attention.py @@ -0,0 +1,386 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# paddlenlp/transformers/ring_attention.py + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import _C_ops +from paddle.autograd.py_layer import PyLayer + +try: + from paddlenlp_ops import flash_attn_bwd +except (ImportError, ModuleNotFoundError): + from paddlenlp.utils.log import logger + + logger.warning( + "if you run ring_flash_attention.py, please ensure you install " + "the paddlenlp_ops by following the instructions " + "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md" + ) + + +class RingCommunicator: + def __init__(self, group, local_key, local_value): + self._k_buffer = [paddle.zeros_like(local_key) for _ in range(2)] + self._v_buffer = [paddle.zeros_like(local_value) for _ in range(2)] + + self._k_buffer[0] = local_key.clone() + self._v_buffer[0] = local_value.clone() + + self._next_buffer_idx = 0 + + self.group = group + self.group_rank = group.rank + self.send_rank = self.group.ranks[(self.group_rank + 1) % self.group.world_size] + self.recv_rank = self.group.ranks[(self.group_rank - 1) % self.group.world_size] + + self._reqs = [] + + def wait(self): + # TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。 + paddle.device.synchronize() + + def add_to_buffers(self, key, value): + if key.shape != self._k_buffer[self._next_buffer_idx].shape: + k_buffer_chunk = paddle.slice( + self._k_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[key.shape[1]] + ) + v_buffer_chunk = paddle.slice( + self._v_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[value.shape[1]] + ) + k_buffer_chunk += key + v_buffer_chunk += value + else: + self._k_buffer[self._next_buffer_idx] += key + self._v_buffer[self._next_buffer_idx] += value + + def get_buffers(self): + return self._k_buffer[self._next_buffer_idx], self._v_buffer[self._next_buffer_idx] + + def send_recv(self): + send_k_op = dist.P2POp(dist.isend, self._k_buffer[self._next_buffer_idx], self.send_rank, self.group) + send_v_op = dist.P2POp(dist.isend, self._v_buffer[self._next_buffer_idx], self.send_rank, self.group) + recv_k_op = dist.P2POp(dist.irecv, self._k_buffer[(self._next_buffer_idx + 1) % 2], self.recv_rank, self.group) + recv_v_op = dist.P2POp(dist.irecv, self._v_buffer[(self._next_buffer_idx + 1) % 2], self.recv_rank, self.group) + + self._next_buffer_idx = (self._next_buffer_idx + 1) % 2 + + ops = [send_k_op, send_v_op, recv_k_op, recv_v_op] + + self._reqs = dist.batch_isend_irecv(ops) + + +def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False): + if old_out is None and old_lse is None: + return block_out.to("float32"), block_lse.to("float32") + + if second_chunk_only: + second_chunk_out_ = paddle.slice(old_out, axes=[1], starts=[old_out.shape[1] // 2], ends=[old_out.shape[1]]) + second_chunk_lse_ = paddle.slice(old_lse, axes=[1], starts=[old_lse.shape[1] // 2], ends=[old_lse.shape[1]]) + second_chunk_out, second_chunk_lse = update_out_and_lse( + second_chunk_out_, second_chunk_lse_, block_out, block_lse + ) + paddle.assign(second_chunk_out, second_chunk_out_) + paddle.assign(second_chunk_lse, second_chunk_lse_) + return old_out, old_lse + else: + block_out, block_lse = block_out.to("float32"), block_lse.to("float32") + with paddle.amp.auto_cast(enable=False, dtype="bfloat16"): + lse = old_lse - F.log_sigmoid(old_lse - block_lse) + return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), lse + + +def get_chunk_id(rank, cp_size): + return rank, (2 * cp_size - 1 
- rank) + + +def concat_masks(attn_masks_list, rank, cp_size): + assert len(attn_masks_list) == 2 * cp_size + first_chunk_id, second_chunk_id = get_chunk_id(rank, cp_size) + return paddle.concat([attn_masks_list[first_chunk_id], attn_masks_list[second_chunk_id]], axis=3) + + +def balanced_ring_flash_attention_fwd_func( + group, + local_query, + local_key, + local_value, + fixed_seed_offset=None, + attn_mask=None, + dropout=0.0, + is_causal=False, + training=True, +): + cp_size = group.world_size + rank = group.rank + + comm_buffer = RingCommunicator(group, local_key, local_value) + local_q_seq_len = local_query.shape[1] + + out, lse, k_cache, v_cache = None, None, dict(), dict() + + if attn_mask is not None: + attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) + if is_causal: + local_query_second_chunk = paddle.slice( + local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] + ) + for step in range(cp_size): + block_k, block_v = comm_buffer.get_buffers() + + if step != cp_size - 1: + comm_buffer.send_recv() + + if not is_causal: + # out [bs, seq, nhead, headdim] + # lse [bs, nhead, seq] + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, + block_k, + block_v, + fixed_seed_offset, + None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), + dropout, + False, + False, + not training, + "", + ) + block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + # block_k and block_v is from rank (group.rank - step) % cp_size + if step == 0: + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, "" + ) + block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step > rank: + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query_second_chunk, + block_k, + block_v, + fixed_seed_offset, + None, + dropout, + False, + False, + not training, + "", + ) + block_lse = paddle.slice(block_lse, axes=[1], starts=[0], ends=[local_q_seq_len // 2]) + block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + out, lse = update_out_and_lse(out, lse, block_out, block_lse, True) + else: + block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[local_q_seq_len // 2]) + block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[local_q_seq_len // 2]) + block_out, _, block_lse, _ = _C_ops.flash_attn( + local_query, + block_k, + block_v, + fixed_seed_offset, + None, + dropout, + False, + False, + not training, + "", + ) + block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + k_cache[step] = block_k + v_cache[step] = block_v + + # TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。 + paddle.device.synchronize() + + out = out.to(local_query.dtype) + lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1]) + return out, lse, k_cache, v_cache + + +def balanced_ring_flash_attention_bwd_func( + group, + k_cache, + v_cache, + out_grad, + local_query, + local_key, + local_value, + local_out, + lse, + fixed_seed_offset, + attn_mask, + dropout=0.0, + is_causal=False, +): + cp_size = group.world_size + rank = group.rank + local_q_seq_len = local_query.shape[1] + + query_grad_buffer = 
paddle.zeros_like(local_query) + key_grad_buffer = paddle.zeros_like(local_key) + value_grad_buffer = paddle.zeros_like(local_value) + + kv_comm_buffer = RingCommunicator(group, local_key, local_value) + grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer) + + if is_causal: + local_query_second_chunk = paddle.slice( + local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] + ) + local_out_second_chunk = paddle.slice( + local_out, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] + ) + lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]) + out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]) + query_grad_buffer_second_chunk = paddle.slice( + query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] + ) + + if attn_mask is not None: + attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) + + for step in range(cp_size): + block_k, block_v = kv_comm_buffer.get_buffers() + + if step != cp_size - 1: + kv_comm_buffer.send_recv() + + if not is_causal: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, + block_k, + block_v, + local_out, + lse, + fixed_seed_offset, + None if attn_mask is None else concat_masks(attn_masks_list, (group.rank - step) % cp_size, cp_size), + out_grad, + dropout, + False, + ) + query_grad_buffer += block_q_grad + else: + if step == 0: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, block_k, block_v, local_out, lse, fixed_seed_offset, None, out_grad, dropout, True + ) + query_grad_buffer += block_q_grad + elif step > rank: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query_second_chunk, + block_k, + block_v, + local_out_second_chunk, + lse_second_chunk, + fixed_seed_offset, + None, + out_grad_second_chunk, + dropout, + False, + ) + query_grad_buffer_second_chunk += block_q_grad + else: + block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( + local_query, + k_cache[step], + v_cache[step], + local_out, + lse, + fixed_seed_offset, + None, + out_grad, + dropout, + False, + ) + query_grad_buffer += block_q_grad + + # TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。 + paddle.device.synchronize() + + grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad) + grad_comm_buffer.send_recv() + + grad_comm_buffer.wait() + key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers() + + dtype = local_query.dtype + return query_grad_buffer.to(dtype), key_grad_buffer.to(dtype), value_grad_buffer.to(dtype) + + +class RingFlashAttention(PyLayer): + @staticmethod + def forward( + ctx, + query, + key, + value, + group=None, + fixed_seed_offset=None, + attn_mask=None, + dropout=0.0, + is_causal=False, + training=True, + ): + if dropout > 0.0: + raise NotImplementedError("Dropout is not supported in ring attention yet.") + if group is None: + group = dist.fleet.get_hybrid_communicate_group().get_sep_parallel_group() + if attn_mask is not None: + is_causal = False + + out, lse, k_cache, v_cache = balanced_ring_flash_attention_fwd_func( + group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training + ) + ctx.save_for_backward(query, key, value, out, lse, attn_mask, k_cache, v_cache) + ctx.group = group + ctx.fixed_seed_offset = fixed_seed_offset + ctx.dropout = dropout + ctx.is_causal = is_causal + return out + + @staticmethod + def backward(ctx, 
out_grad): + query, key, value, out, lse, attn_mask, k_cache, v_cache = ctx.saved_tensor() + group = ctx.group + fixed_seed_offset = ctx.fixed_seed_offset + dropout = ctx.dropout + is_causal = ctx.is_causal + + if fixed_seed_offset is None: + fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64) + + query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func( + group, + k_cache, + v_cache, + out_grad, + query, + key, + value, + out, + lse, + fixed_seed_offset, + attn_mask, + dropout, + is_causal, + ) + if attn_mask is not None and not attn_mask.stop_gradient: + return query_grad, key_grad, value_grad, None + else: + return query_grad, key_grad, value_grad diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index 32cfec4b59de..e19a42f8a756 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -1111,5 +1111,16 @@ else echo "only one gpu:${cudaid1} is set, skip test" fi +} +ring_flash_attention(){ +cd ${nlp_dir} +echo "test ring_flash_attention, cudaid1:${cudaid1}, cudaid2:${cudaid2}" +if [[ ${cudaid1} != ${cudaid2} ]]; then + time (python -m paddle.distributed.launch tests/transformers/test_ring_flash_attention.py >${log_path}/ring_flash_attention) >>${log_path}/ring_flash_attention 2>&1 + print_info $? ring_flash_attention +else + echo "only one gpu:${cudaid1} is set, skip test" +fi + } $1 diff --git a/scripts/regression/run_ci.sh b/scripts/regression/run_ci.sh index 74d0b1957af8..0f7f6fdf5ab0 100644 --- a/scripts/regression/run_ci.sh +++ b/scripts/regression/run_ci.sh @@ -33,7 +33,7 @@ all_P0case_dic=(["waybill_ie"]=3 ["msra_ner"]=15 ["glue"]=2 ["bert"]=2 ["skep"]= ["ernie-ctm"]=5 ["distilbert"]=5 ["transformer"]=5 ["pet"]=5 ["efl"]=5 ["p-tuning"]=5 ["ernie-doc"]=20 ["transformer-xl"]=5 \ ["question_matching"]=5 ["ernie-csc"]=5 ["nptag"]=5 ["ernie-m"]=5 ["taskflow"]=5 ["clue"]=5 ["textcnn"]=5 \ ["fast_generation"]=10 ["ernie-3.0"]=5 ["ernie-layout"]=5 ["uie"]=5 ["ernie-health"]=5 ["llm"]=5 \ -["ernie"]=2 ["ernie_m"]=5 ["ernie_layout"]=5 ["ernie_csc"]=5 ["ernie_ctm"]=5 ["ernie_doc"]=20 ["ernie_health"]=5 ["segment_parallel_utils"]=5) +["ernie"]=2 ["ernie_m"]=5 ["ernie_layout"]=5 ["ernie_csc"]=5 ["ernie_ctm"]=5 ["ernie_doc"]=20 ["ernie_health"]=5 ["segment_parallel_utils"]=5 ["ring_flash_attention"]=5) #################################### # Insatll paddlepaddle-gpu install_paddle(){ diff --git a/tests/transformers/test_ring_flash_attention.py b/tests/transformers/test_ring_flash_attention.py new file mode 100644 index 000000000000..134d2f9c011a --- /dev/null +++ b/tests/transformers/test_ring_flash_attention.py @@ -0,0 +1,124 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import unittest + +import numpy as np +import paddle +from paddle.nn.functional.flash_attention import scaled_dot_product_attention + +from paddlenlp.transformers.ring_flash_attention import RingFlashAttention, get_chunk_id + + +class TestRingFlashAttention(unittest.TestCase): + def setUp(self): + paddle.distributed.init_parallel_env() + self.group = paddle.distributed.new_group(range(paddle.distributed.get_world_size()), backend="nccl") + self.degree = self.group.world_size + self.rank = self.group.rank + + seed = 42 + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + self.test_id = 0 + + def generate_full_data(self, batch_size, seq_len, num_head, head_dim): + query = paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.bfloat16) + key = paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.bfloat16) + value = paddle.randn([batch_size, seq_len, num_head, head_dim], dtype=paddle.bfloat16) + + query.stop_gradient = False + key.stop_gradient = False + value.stop_gradient = False + + return query, key, value + + def split_belanced_data(self, input): + sliced_datas = paddle.split(input, num_or_sections=self.degree * 2, axis=1) + sliced_data0, sliced_data1 = sliced_datas[self.rank], sliced_datas[self.degree * 2 - 1 - self.rank] + return paddle.concat([sliced_data0, sliced_data1], axis=1).detach() + + def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, use_mask): + if self.degree < 2: + return + query, key, value = self.generate_full_data(bsz, seq_len_per_device * self.degree, head_num, head_dim) + + local_query = self.split_belanced_data(query) + local_key = self.split_belanced_data(key) + local_value = self.split_belanced_data(value) + + local_query.stop_gradient = False + local_key.stop_gradient = False + local_value.stop_gradient = False + + if use_mask: + mask_shape = (bsz, 1, query.shape[1], query.shape[1]) + mask = np.random.random(mask_shape) + attn_mask = paddle.to_tensor(mask, place=query.place, dtype=query.dtype) + attn_mask = paddle.ones(mask_shape).to(query.dtype) + attn_mask_list = paddle.split(attn_mask, axis=2, num_or_sections=self.degree * 2) + first_chunk_id, second_chunk_id = get_chunk_id(self.rank, self.degree) + local_attn_mask = paddle.concat([attn_mask_list[first_chunk_id], attn_mask_list[second_chunk_id]], axis=2) + else: + attn_mask = None + local_attn_mask = None + + with paddle.amp.auto_cast(enable=True, dtype="bfloat16"): + local_out = RingFlashAttention.apply( + local_query, local_key, local_value, self.group, is_causal=is_causal, attn_mask=local_attn_mask + ) + ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=attn_mask) + + local_out.mean().backward() + ref_out.mean().backward() + + ref_local_query_grad = self.split_belanced_data(query.grad) + ref_local_key_grad = self.split_belanced_data(key.grad) + ref_local_value_grad = self.split_belanced_data(value.grad) + + ref_local_out = self.split_belanced_data(ref_out) + + rtol = 1e-04 + atol = 5e-03 + np.testing.assert_allclose( + local_out.to("float32").numpy(), ref_local_out.to("float32").numpy(), rtol=rtol, atol=atol + ) + np.testing.assert_allclose( + local_query.grad.to("float32").numpy(), ref_local_query_grad.to("float32").numpy(), rtol=rtol, atol=atol + ) + np.testing.assert_allclose( + local_key.grad.to("float32").numpy(), ref_local_key_grad.to("float32").numpy(), rtol=rtol, atol=atol + ) + np.testing.assert_allclose( + local_value.grad.to("float32").numpy(), 
ref_local_value_grad.to("float32").numpy(), rtol=rtol, atol=atol + ) + + print(f"Test {self.test_id} passed!") + self.test_id += 1 + + def test_normal_flash_attention(self): + self.single_test(2, 1024, 2, 128, False, False) + + def test_masked_flash_attention(self): + self.single_test(2, 1024, 2, 128, False, True) + + def test_causal_flash_attention(self): + self.single_test(2, 1024, 2, 128, True, False) + + +if __name__ == "__main__": + unittest.main()
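
For reference, the load-balanced ("zig-zag") split implemented by `split_inputs_sequence_dim_load_balance` and `get_chunk_id` in this patch assigns each context-parallel rank the r-th and (2*degree-1-r)-th sequence chunks, so that causal-attention work is balanced across ranks. Below is a minimal plain-Python sketch of that chunk-to-rank assignment; it is illustrative only and not part of the diff, and `load_balanced_chunks` is a hypothetical helper (no Paddle required).

```python
# Illustrative sketch (not part of the patch): reproduces the chunk-to-rank
# assignment used by split_inputs_sequence_dim_load_balance / get_chunk_id.
# `load_balanced_chunks` is a hypothetical helper; plain Python, no Paddle needed.

def load_balanced_chunks(seq, degree):
    """Split `seq` into 2 * degree chunks; rank r keeps chunks r and 2*degree-1-r."""
    chunk_len = len(seq) // (degree * 2)
    chunks = [seq[i * chunk_len:(i + 1) * chunk_len] for i in range(degree * 2)]
    return {rank: chunks[rank] + chunks[degree * 2 - 1 - rank] for rank in range(degree)}

if __name__ == "__main__":
    tokens = list(range(16))  # toy sequence of 16 token positions
    print(load_balanced_chunks(tokens, degree=4))
    # {0: [0, 1, 14, 15], 1: [2, 3, 12, 13], 2: [4, 5, 10, 11], 3: [6, 7, 8, 9]}
    # Pairing chunk r with chunk 2*degree-1-r balances causal-attention cost:
    # each rank holds one "early" and one "late" chunk of the sequence.
```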