diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 019a6b140c97..71e96c5b0d3e 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -406,13 +406,18 @@ def _rescale_out_lse(out, block_out, lse, block_lse):
 class RingAttention(torch.autograd.Function):
     """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
     (https://arxiv.org/abs/2310.01889).
-    For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
-    For portable integration with more models, we don't follow the spirit of "block-wise FNN" in the original paper,
-    which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
-    implemented in Jax and not optimized).
-    We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
+    For load-balancing, we adopt the "zigzag" dataloading scheme from ring-flash-attention.
+    We also adopt the double ring topology from LoongTrain to fully utilize available
     NICs on each node, by computing attention within a inner ring first and then sending all KVs to the next
     ring at once.
+    Our implementation references code from
+    - ring-flash-attention: https://github.com/zhuzilin/ring-flash-attention/tree/main
+    - Megatron Context Parallel: https://github.com/NVIDIA/TransformerEngine/pull/726
+    References:
+    - Ring Attention with Blockwise Transformers for Near-Infinite Context
+      https://arxiv.org/abs/2310.01889
+    - LoongTrain: Efficient Training of Long-Sequence LLMs with Head-Context Parallelism
+      https://arxiv.org/abs/2406.18485
     """

     # Globle cache to avoid recomputation for same-lengthed sequences
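
For context on the two schemes named in the updated docstring, below is a minimal, self-contained sketch of (1) the "zigzag" load-balancing split and (2) double-ring rank grouping. The helper names are hypothetical illustrations only, not the ColossalAI, ring-flash-attention, or LoongTrain APIs.

```python
import torch

def zigzag_split(seq: torch.Tensor, world_size: int, rank: int, dim: int = 1) -> torch.Tensor:
    # Zigzag split (as used in ring-flash-attention): cut the sequence into
    # 2 * world_size chunks; rank i keeps chunks i and (2 * world_size - 1 - i),
    # so under a causal mask every rank holds one "cheap" early chunk and one
    # "expensive" late chunk, roughly balancing attention FLOPs across ranks.
    chunks = seq.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=dim)

def double_ring_groups(world_size: int, inner_ring_size: int):
    # Double ring topology (in the spirit of LoongTrain): inner rings are the ranks
    # that run ring attention among themselves first (e.g. GPUs within one node);
    # inter-ring groups connect the k-th member of each inner ring, which then
    # sends all of its KVs to the next ring at once.
    inner = [list(range(i, i + inner_ring_size)) for i in range(0, world_size, inner_ring_size)]
    inter = [list(range(k, world_size, inner_ring_size)) for k in range(inner_ring_size)]
    return inner, inter

if __name__ == "__main__":
    seq = torch.arange(16).unsqueeze(0)  # toy sequence of 16 token ids, shape (1, 16)
    for rank in range(4):
        print(f"rank {rank} holds tokens", zigzag_split(seq, world_size=4, rank=rank).tolist())
    # e.g. 8 GPUs as 2 nodes x 4 GPUs: inner rings within a node, 4 inter-node groups
    print(double_ring_groups(world_size=8, inner_ring_size=4))
```

The sketch only shows how data and ranks would be partitioned; the actual implementation in attn.py additionally overlaps the inner-ring attention compute with the inter-ring KV communication.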