
[ROCM] enable aiter fused moe kernel for llama4 bf16 checkpoints #16674


Merged
4 commits merged into vllm-project:main on Apr 17, 2025

Conversation

@sijiac (Contributor) commented Apr 15, 2025

This PR enables the AITER fused MoE kernel for bf16 checkpoints to improve performance.

NOTE: The kernel does not support torch.compile() at the moment, so eager mode is used for both the baseline and the AITER path in the benchmarks below. The issue is tracked in ROCm/aiter#244.
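
For context (not part of the PR diff), the AITER MoE path is gated by the `VLLM_ROCM_USE_AITER` and `VLLM_ROCM_USE_AITER_MOE` environment variables set in the benchmark commands below; a minimal sketch of that gating, with the helper name chosen here purely for illustration, looks roughly like:

```python
# Illustrative sketch only -- not the code added by this PR.
# The env flag names match those set in the benchmark commands in this description.
import vllm.envs as envs


def rocm_aiter_moe_enabled() -> bool:  # hypothetical helper name
    # Both the global AITER switch and the MoE-specific switch must be on
    # for the AITER fused-MoE kernel to be selected on ROCm.
    return bool(envs.VLLM_ROCM_USE_AITER) and bool(envs.VLLM_ROCM_USE_AITER_MOE)
```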

16E bf16 benchmark

VLLM_USE_V1=1 VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_ROCM_FP8_PADDING=0 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MOE=1 VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE=0 VLLM_ROCM_USE_AITER_RMSNORM=0 VLLM_ROCM_USE_AITER_LINEAR=0 SAFETENSORS_FAST_GPU=1 vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct --disable-log-requests -tp 8 --max-num-seqs 64 --max-model-len 8192 --compilation-config 0 --enforce-eager

python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Llama-4-Scout-17B-16E-Instruct --dataset-name random --random-input-len 1000 --random-output-len 1000 --max-concurrency 32 --num-prompts 160

baseline
============ Serving Benchmark Result ============
Successful requests:                     160
Benchmark duration (s):                  97.64
Total input tokens:                      160000
Total generated tokens:                  41939
Request throughput (req/s):              1.64
Output token throughput (tok/s):         429.51
Total Token throughput (tok/s):          2068.10
---------------Time to First Token----------------
Mean TTFT (ms):                          370.74
Median TTFT (ms):                        111.65
P99 TTFT (ms):                           1977.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          49.73
Median TPOT (ms):                        48.56
P99 TPOT (ms):                           66.96
---------------Inter-token Latency----------------
Mean ITL (ms):                           49.20
Median ITL (ms):                         45.67
P99 ITL (ms):                            85.23
==================================================

w/ the change
============ Serving Benchmark Result ============
Successful requests:                     160
Benchmark duration (s):                  77.91
Total input tokens:                      160000
Total generated tokens:                  46026
Request throughput (req/s):              2.05
Output token throughput (tok/s):         590.72
Total Token throughput (tok/s):          2644.26
---------------Time to First Token----------------
Mean TTFT (ms):                          226.57
Median TTFT (ms):                        95.78
P99 TTFT (ms):                           1244.74
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          39.30
Median TPOT (ms):                        38.46
P99 TPOT (ms):                           49.95
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.96
Median ITL (ms):                         35.44
P99 ITL (ms):                            76.00
==================================================

128E bf16 benchmark

VLLM_USE_V1=1 VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_ROCM_FP8_PADDING=0 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MOE=1 VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE=0 VLLM_ROCM_USE_AITER_RMSNORM=0 VLLM_ROCM_USE_AITER_LINEAR=0 SAFETENSORS_FAST_GPU=1 vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct --disable-log-requests -tp 8 --max-num-seqs 64 --max-model-len 8192 --compilation-config 0 --enforce-eager

python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Llama-4-Maverick-17B-128E-Instruct --dataset-name random --random-input-len 1000 --random-output-len 1000 --max-concurrency 32 --num-prompts 160

baseline
============ Serving Benchmark Result ============
Successful requests:                     160
Benchmark duration (s):                  155.83
Total input tokens:                      160000
Total generated tokens:                  132925
Request throughput (req/s):              1.03
Output token throughput (tok/s):         853.02
Total Token throughput (tok/s):          1879.78
---------------Time to First Token----------------
Mean TTFT (ms):                          83.09
Median TTFT (ms):                        75.57
P99 TTFT (ms):                           144.62
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          32.55
Median TPOT (ms):                        32.54
P99 TPOT (ms):                           33.18
---------------Inter-token Latency----------------
Mean ITL (ms):                           32.55
Median ITL (ms):                         32.10
P99 ITL (ms):                            39.80
==================================================

w/ the change
============ Serving Benchmark Result ============
Successful requests:                     160
Benchmark duration (s):                  129.66
Total input tokens:                      160000
Total generated tokens:                  132622
Request throughput (req/s):              1.23
Output token throughput (tok/s):         1022.84
Total Token throughput (tok/s):          2256.84
---------------Time to First Token----------------
Mean TTFT (ms):                          76.96
Median TTFT (ms):                        68.83
P99 TTFT (ms):                           131.34
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          27.16
Median TPOT (ms):                        27.13
P99 TPOT (ms):                           28.05
---------------Inter-token Latency----------------
Mean ITL (ms):                           27.14
Median ITL (ms):                         26.64
P99 ITL (ms):                            34.28
==================================================

Generation Test

16E

Prompt: 'The color of the sky is blue but sometimes it can also be', Generated text: " gray or brown due to pollution. The color of the sky is determined by the way that light scatters off of the particles in the Earth's atmosphere. The blue color of the sky is caused by a phenomenon called Rayleigh scattering, which is the scattering of light by small particles much smaller than the wavelength of the light. This scattering is more effective for shorter wavelengths of light, such as blue and violet, than for longer wavelengths, such as red and orange. As a result, the blue light is scattered in all directions and reaches our eyes from all parts of the sky, giving it its blue color. The color of the sky can also be affected by the presence of particles in the atmosphere, such as dust, water vapor, and pollutants. These particles can scatter light in different ways, which can change the apparent color of the sky. For example, during sunrise and sunset, the sky can take on hues of red and orange due to the scattering of light by atmospheric particles. Similarly, the presence of pollutants and dust in the atmosphere can cause the sky to appear gray or brown. So, while the color of the sky is typically blue, it can vary depending on the conditions in the atmosphere. \nThe best answer is blue."
Prompt: 'The capital of France is', Generated text: " Paris. Paris is known as the City of Light and is famous for its art, fashion, and culture. The Eiffel Tower, a iconic symbol of Paris, was built for the 1889 World's Fair and was originally intended to be a temporary structure. However, it has become a beloved landmark and a symbol of French culture.\n\nThe Louvre Museum, located in the heart of Paris, is one of the world's largest and most famous museums. It houses an impressive collection of art and artifacts from around the world, including the Mona Lisa. The museum's stunning glass pyramid entrance, designed by I.M. Pei, is a popular spot for tourists to take photos.\n\nIn addition to its rich history and cultural attractions, Paris is also known for its romantic atmosphere and beautiful architecture. The city's charming streets, picturesque bridges, and historic buildings make it a popular destination for couples and honeymooners.\n\nOverall, Paris is a must-visit destination for anyone interested in history, art, fashion, and culture. Its unique blend of traditional and modern attractions makes it a city that has something for everyone.\n\n### Key Facts about Paris:\n\n* Capital of France\n* Known as the City of Light\n* Famous for art, fashion, and culture\n* Home to the Eiffel Tower"
Prompt: 'What is batch inference?', Generated text: " \n Batch inference is a process in machine learning (ML) where a model is used to make predictions on a large dataset all at once, rather than one by one.  This approach is particularly useful when you need to process a big dataset and don't require real-time predictions. \n In batch inference, you typically:\n 1. Prepare your dataset: Collect and preprocess the data you want to make predictions on. \n 2. Load the model: Load a trained ML model into memory. \n 3. Make predictions: Use the model to generate predictions for the entire dataset in a single operation. \n 4. Post-process: Optionally, perform additional processing on the predictions, such as filtering or aggregating them.\n\n Batch inference offers several advantages, including:\n - **Efficiency**: Processing large datasets in batches can be more efficient than making individual predictions, especially when working with limited computational resources. \n - **Scalability**: Batch inference allows you to handle large datasets that might be too big to process one by one. \n - **Cost-effectiveness**: In cloud-based ML services, batch inference can be more cost-effective than real-time prediction services, which often charge per request.\n\n However, batch inference also has some limitations:\n - **Latency**: Since predictions are made in batches, there might"

128E

Prompt: 'The color of the sky is blue but sometimes it can also be', Generated text: ' red, orange or grey. What is the reason behind this? - My Science School\nThe color of the sky is blue but sometimes it can also be red, orange or grey. What is the reason behind this?\nThe color of the sky is blue because the white light coming from the Sun is scattered by the molecules of gases present in the atmosphere. The blue color has the shortest wavelength and is scattered the most. That is why the sky appears blue to us most of the time. \nThe color of the sky changes depending on the atmospheric conditions. If the atmosphere is filled with dust and other particles, the light gets scattered in different ways, changing the color of the sky. \nDuring sunrise and sunset, the sky often turns red or orange. This is because during these times, the sunlight has to travel through a thicker layer of atmosphere to reach our eyes. The shorter blue wavelengths are scattered away, leaving mainly the longer red and orange wavelengths to reach our eyes, giving the sky its reddish hue. \nOn cloudy or foggy days, the sky can appear grey. This is because the clouds or fog reflect and scatter the sunlight in all directions, making the sky appear grey or white. \nIn summary, the color of the sky is influenced by the scattering of'
Prompt: 'The capital of France is', Generated text: ' Paris. Paris is the most populous city in France, with an estimated population of over 2.1 million people within its administrative limits. The city is a global center for art, fashion, cuisine, and culture. \nParis is home to many famous landmarks, including the Eiffel Tower, the Louvre Museum, and Notre Dame Cathedral. The city is also known for its romantic atmosphere and is a popular destination for tourists. \nThe history of Paris dates back to the 3rd century BC, when it was founded by the Celtic tribe known as the Parisii. The city has been an important center of politics, culture, and commerce for centuries, and has played a significant role in the history of Europe. \nToday, Paris is a modern and vibrant city, with a diverse economy and a high standard of living. The city is home to many international organizations, including the United Nations Educational, Scientific and Cultural Organization (UNESCO) and the International Chamber of Commerce. \nParis is also known for its cuisine, which is considered to be among the best in the world. The city is home to many Michelin-starred restaurants, and is famous for its pastries, cheeses, and wines. \nIn addition to its cultural and culinary attractions, Paris is also'
Prompt: 'What is batch inference?', Generated text: ' - Inferless Inferless home page light logo dark logo Search or ask...\nWhat is batch inference?\nWhat is batch inference?\nBatch inference is a technique used in machine learning (ML) to process multiple inputs or data points simultaneously, rather than one at a time. This approach is particularly useful when you have a large dataset and need to make predictions or inferences on it. Instead of sending each data point through the model individually, you group them into batches and process the entire batch at once.\nHow batch inference works\n0. Data Preparation: The data you want to make predictions on is prepared and organized into batches. Each batch contains a fixed number of data points (e.g., images, text, numerical data).\n1. Model Inference: The ML model processes the entire batch of data in one go. The model performs the necessary computations on all data points in the batch simultaneously.\n2. Output Generation: The model generates predictions or inferences for all data points in the batch. The output is typically a tensor or array containing the predictions for each data point.\nBenefits of batch inference\n0. Efficiency: Batch inference can significantly improve the efficiency of your ML workflows. By processing multiple data points at once, you can reduce the overhead associated with individual inference requests, such as data transfer and'


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@sijiac sijiac marked this pull request as ready for review April 15, 2025 18:54
@hongxiayang (Collaborator) commented:

Thanks for the PR. We will validate. @sijiac @houseroad

@hongxiayang (Collaborator) left a comment:

LGTM. We need to enable graph mode in order to achieve the full performance benefit compared with V1 graph mode without AITER.

@tjtanaa (Contributor) commented Apr 16, 2025

@sijiac @hongxiayang we will enable it for V1 after this PR is merged.

@tjtanaa (Contributor) commented Apr 16, 2025

As supplementary information for this PR:

The GSM8K lm_eval score with the AITER kernel for meta-llama/Llama-4-Scout-17B-16E-Instruct:

vllm (pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,max_model_len=30000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9181|±  |0.0076|
|     |       |strict-match    |     5|exact_match|↑  |0.9022|±  |0.0082|

The GSM8K lm_eval score with the AITER kernel for meta-llama/Llama-4-Maverick-17B-128E-Instruct:

vllm (pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct,tensor_parallel_size=8,max_model_len=30000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9242|±  |0.0073|
|     |       |strict-match    |     5|exact_match|↑  |0.9272|±  |0.0072|
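
The exact lm-eval invocation is not shown above; a hedged Python sketch of an equivalent run with the lm-evaluation-harness vLLM backend (settings copied from the result headers, API usage assumed from recent lm-eval releases) could look like:

```python
# Hypothetical reproduction sketch; the PR does not show the actual invocation.
# Settings mirror the result header: tp=8, max_model_len=30000, 5-shot GSM8K, batch_size=auto.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=("pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,"
                "tensor_parallel_size=8,max_model_len=30000,trust_remote_code=True"),
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size="auto",
)
print(results["results"]["gsm8k"])  # exact_match under the flexible-extract / strict-match filters
```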

@sijiac (Contributor, Author) commented Apr 16, 2025

cc @mgoin, could you take a look at the PR? Is it good to be merged?

@@ -39,6 +40,16 @@ def rocm_aiter_fused_experts(
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)

if apply_router_weight_on_input:
_, topk = topk_weights.shape

Review comment (Collaborator): can we assert topk_weights.dim() == 2?


hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
topk_ids = topk_ids.to(torch.int32)
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

Review comment (Collaborator): Does AITER require fp32 weight?


Reply (sijiac, Contributor Author): Yes, the passed-in topk_weights must be in fp32 dtype; otherwise, it will have numeric issues.
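
Putting the diff snippet and the review feedback together, a minimal sketch of the `apply_router_weight_on_input` handling (illustrative only; the function name is hypothetical and the merged diff may differ in details) would be:

```python
# Sketch combining the diff lines shown above with the review feedback; not the merged code verbatim.
import torch


def fold_router_weight_into_input(hidden_states: torch.Tensor,
                                  topk_weights: torch.Tensor,
                                  topk_ids: torch.Tensor):
    # Reviewer suggestion: be explicit that topk_weights has shape (num_tokens, topk).
    assert topk_weights.dim() == 2
    _, topk = topk_weights.shape  # topk is checked further in the real code (elided here)

    # Apply the routing weight to the activations up front...
    hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
    topk_ids = topk_ids.to(torch.int32)
    # ...and pass all-ones fp32 weights to the AITER kernel, which expects fp32
    # topk_weights to avoid numeric issues (see the reply above).
    topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
    return hidden_states, topk_weights, topk_ids
```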

@hongxiayang added the rocm (Related to AMD ROCm) and ready (ONLY add when PR is ready to merge/full CI is needed) labels on Apr 16, 2025
@houseroad (Collaborator) commented:

fix the linter?

@houseroad (Collaborator) left a comment:

Looks good.

@houseroad houseroad merged commit 92edf35 into vllm-project:main Apr 17, 2025
43 of 44 checks passed
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Labels
ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)
4 participants