Skip to content

[V1] Move usage stats to worker and start logging TPU hardware #16211


Merged
merged 58 commits into vllm-project:main on Apr 25, 2025

Conversation

dyli-google (Contributor) commented Apr 7, 2025:

This PR adds functionality to detect and report TPU hardware details when vLLM runs on a TPU platform. It uses torch_xla to gather the TPU count, type, and memory per device.

It also moves GPU/CPU usage reporting from LLM engine initialization to GPU/CPU worker initialization.

This addresses the requirement to track TPU usage for vLLM's data dashboards.
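
For reference, a minimal sketch of the detection logic described above (an illustration only, assuming torch_xla is installed and this runs on the TPU host; the exact return shape of xm.get_memory_info() varies across torch_xla releases):

import torch_xla.runtime as xr
from torch_xla.core import xla_model as xm

device = xm.xla_device()
tpu_count = xr.world_size()          # number of TPU devices participating in the run
tpu_type = xm.xla_device_hw(device)  # hardware kind, e.g. "TPU"
# Newer torch_xla releases report a per-device "bytes_limit" entry here.
tpu_memory_per_device = xm.get_memory_info(device)["bytes_limit"]
print(tpu_count, tpu_type, tpu_memory_per_device)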

github-actions bot commented Apr 7, 2025:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀


yarongmu-google (Contributor):

Thanks Daniel - LGTM. Have you tried to run this once locally? (Ack that a local run wouldn't be able to update any storage/dashboards, etc., but just to make sure we "attempted" to update these.)

@yaochengji for review too

mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Apr 7, 2025
Comment on lines 180 to 192
if current_platform.is_tpu():
    try:
        import torch_xla.runtime as xr
        from torch_xla.core import xla_model as xm
        self.tpu_count = xr.world_size()
        self.tpu_type = xm.xla_device_hw(xm.xla_device())
        self.tpu_memory_per_device = xm.get_memory_info().bytes_limit
    except ImportError:
        logging.warning(
            "torch_xla not found, skipping TPU usage statistics.")
        self.tpu_count = None
        self.tpu_type = None
        self.tpu_memory_per_device = None
Collaborator:

Please just set gpu_count/gpu_type/gpu_memory_per_device. We can perform the disambiguation in backend processing. We can also silence the import error.

Please paste the output from ~/.config/vllm/usage_stats.json for verification.
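
One way to silence the import error, sketched with the standard library (an illustration, not the exact code in this PR):

import contextlib

# If torch_xla is absent (e.g. on a non-TPU host), skip TPU detection quietly;
# the usage-stats fields simply stay None.
with contextlib.suppress(ImportError):
    import torch_xla  # noqa: F401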

Contributor Author:

Done.

Where is ~/.config/vllm/usage_stats.json?

I cannot find it inside the docker:
root@t1v-n-a747908a-w-0:/workspace/vllm# cat ~/.config/vllm/usage_stats.json
cat: /root/.config/vllm/usage_stats.json: No such file or directory

Contributor Author:

Sorry, even without Docker and building from source, I still cannot find ~/.config/vllm/usage_stats.json. Not sure how it gets created.

Contributor Author:

Ah, after I started the server using python -m vllm.entrypoints.api_server, I got the JSON file (previously I used the llm serve command).

It seems my code is not working: "gpu_count": null, "gpu_type": null, "gpu_memory_per_device": null

(myenv) dyli_google_com@t1v-n-b4c4da81-w-0:~/.config/vllm$ cat usage_stats.json
{"uuid": "d895fe5b-ff4c-42ca-a65e-deabc113a731", "provider": "GCP", "num_cpu": 180, "cpu_type": "AMD EPYC 9B14", "cpu_family_model_stepping": "25,17,1", "total_memory": 1521841610752, "architecture": "x86_64", "platform": "Linux-6.8.0-1015-gcp-x86_64-with-glibc2.35", "cuda_runtime": null, "gpu_count": null, "gpu_type": null, "gpu_memory_per_device": null, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": false, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.8.3", "context": "API_SERVER", "log_time": 1744157778107180032, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 1, "block_size": 16, "gpu_memory_utilization": 0.98, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": false, "enforce_eager": false, "disable_custom_all_reduce": true}

Collaborator:

xm.xla_device_hw(xm.xla_device()) is not null in my TPU VM.

Contributor Author:

Thanks.

  1. Did you start the server using python -m vllm.entrypoints.api_server or llm serve?
  2. Did you use Docker, or just build from source?

Collaborator:

I didn't test it in vLLM or Docker. I just used torch_xla directly in a plain Python environment.

Contributor Author:

It seems like my code doesn't work:

(vllm) dyli_google_com@t1v-n-b4c4da81-w-0:~/vllm$ cat ~/.config/vllm/usage_stats.json
{"uuid": "484b129d-3e9f-466d-86f0-e6f8088af5c1", "provider": "GCP", "num_cpu": 180, "cpu_type": "AMD EPYC 9B14", "cpu_family_model_stepping": "25,17,1", "total_memory": 1521841610752, "architecture": "x86_64", "platform": "Linux-6.8.0-1015-gcp-x86_64-with-glibc2.35", "cuda_runtime": null, "gpu_count": null, "gpu_type": null, "gpu_memory_per_device": null, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": false, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.6.6.dev1916+g6a4eea4ff", "context": "OPENAI_API_SERVER", "log_time": 1744244819493581056, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 1, "block_size": 16, "gpu_memory_utilization": 0.95, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": null, "enforce_eager": false, "disable_custom_all_reduce": true}

simon-mo (Collaborator) left a comment:

The previous comment is blocking.

from torch_xla.core import xla_model as xm
self.gpu_count = xr.world_size()
self.gpu_type = xm.xla_device_hw(xm.xla_device())
self.gpu_memory_per_device = xm.get_memory_info().bytes_limit
Collaborator:

xm.xla_device_hw(xm.xla_device()) returns TPU as the result.

Or do we want something like v6e, v5e?

Contributor Author:

@yarongmu-google @simon-mo What do you think? I believe TPU should be OK?

Collaborator:

The version number would be useful.

Contributor Author:

Thanks Simon.

@yaochengji Do we have ways to get the version number?

Collaborator:

You can use torch_xla.tpu.get_tpu_type()

Contributor Author:

Cool, thanks. I just updated the code to use torch_xla.tpu.get_tpu_type()
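
For reference, a sketch of what the updated lookup could look like (the get_tpu_type() call and its import path follow the suggestion above; the gpu_* names and the example values v6e / v5litepod come from the test output later in this thread; the shape of xm.get_memory_info()'s result is an assumption that varies by torch_xla release):

import contextlib

with contextlib.suppress(ImportError):
    import torch_xla
    import torch_xla.runtime as xr
    from torch_xla.core import xla_model as xm

    gpu_count = xr.world_size()
    # Per the suggestion above; expected to return e.g. "v6e" or "v5litepod".
    gpu_type = torch_xla.tpu.get_tpu_type()
    # Newer torch_xla releases expose a per-device "bytes_limit" entry.
    gpu_memory_per_device = xm.get_memory_info(xm.xla_device())["bytes_limit"]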

yaochengji (Collaborator):

BTW, could you change the title from [Hardware][Google] to [TPU]?

mgoin changed the title from "[Hardware][Google] Track TPU usages in vLLM's data dashboards" to "[TPU] Track TPU usages in vLLM's data dashboards" on Apr 9, 2025
simon-mo (Collaborator):

This will be ready to merge once example JSON is posted, ideally from two different TPU machines.

brittrock:

Please add a test for v5e and v6e, @dyli-google.

brittrock left a comment:

Tests are needed to ensure this doesn't fail silently, please.

dyli-google (Contributor, Author):

The code is ready for review. Here are the test results on TPU v6e, TPU v5e, and GPU H100.

1 v6e

{"uuid": "91dfc112-fa3e-4c05-8b43-109274f51bfe", "provider": "GCP", "num_cpu": 180, "cpu_type": "AMD EPYC 9B14", "cpu_family_model_stepping": "25,17,1", "total_memory": 1521841610752, "architecture": "x86_64", "platform": "Linux-6.8.0-1015-gcp-x86_64-with-glibc2.35", "cuda_runtime": null, "gpu_count": 1, "gpu_type": "v6e", "gpu_memory_per_device": 33550237696, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": true, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.6.6.dev2243+gbc284db56.d20250424", "context": "ENGINE_CONTEXT", "log_time": 1745534206212480000, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 1, "block_size": 16, "gpu_memory_utilization": 0.95, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": true, "enforce_eager": false, "disable_custom_all_reduce": true}

8 v6e

{"uuid": "71802820-41c1-4561-9a0d-28254a3ba12b", "provider": "GCP", "num_cpu": 180, "cpu_type": "AMD EPYC 9B14", "cpu_family_model_stepping": "25,17,1", "total_memory": 1521841610752, "architecture": "x86_64", "platform": "Linux-6.8.0-1015-gcp-x86_64-with-glibc2.35", "cuda_runtime": null, "gpu_count": 8, "gpu_type": "v6e", "gpu_memory_per_device": 33550237696, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": true, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.6.6.dev2243+gbc284db56.d20250424", "context": "ENGINE_CONTEXT", "log_time": 1745535086707463936, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 8, "block_size": 16, "gpu_memory_utilization": 0.95, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": true, "enforce_eager": false, "disable_custom_all_reduce": true}

1 H100

{"uuid": "0c93d21c-5dc9-4f4f-ad05-ace2ca5ab17e", "provider": "GCP", "num_cpu": 104, "cpu_type": "Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz", "cpu_family_model_stepping": "6,143,8", "total_memory": 988784611328, "architecture": "x86_64", "platform": "Linux-5.10.0-33-cloud-amd64-x86_64-with-glibc2.31", "cuda_runtime": "12.4", "gpu_count": 1, "gpu_type": "NVIDIA H100 80GB HBM3", "gpu_memory_per_device": 84929347584, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": true, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.6.6.dev2250+g98e7ae08f.d20250425", "context": "ENGINE_CONTEXT", "log_time": 1745551578482265856, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 1, "block_size": 16, "gpu_memory_utilization": 0.95, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": true, "enforce_eager": false, "disable_custom_all_reduce": false}

4 H100s

{"uuid": "a65f27fb-d952-49ca-9e53-2a168a38a0cf", "provider": "GCP", "num_cpu": 104, "cpu_type": "Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz", "cpu_family_model_stepping": "6,143,8", "total_memory": 988784611328, "architecture": "x86_64", "platform": "Linux-5.10.0-33-cloud-amd64-x86_64-with-glibc2.31", "cuda_runtime": "12.4", "gpu_count": 4, "gpu_type": "NVIDIA H100 80GB HBM3", "gpu_memory_per_device": 84929347584, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": true, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.6.6.dev2250+g98e7ae08f.d20250425", "context": "ENGINE_CONTEXT", "log_time": 1745552092957214976, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 4, "block_size": 16, "gpu_memory_utilization": 0.95, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": true, "enforce_eager": false, "disable_custom_all_reduce": false

1 v5e

{"uuid": "d5dee4c4-e110-4882-b7c6-9329f427eca9", "provider": "GCP", "num_cpu": 224, "cpu_type": "AMD EPYC 7B13", "cpu_family_model_stepping": "25,1,", "total_memory": 405677613056, "architecture": "x86_64", "platform": "Linux-6.5.0-1013-gcp-x86_64-with-glibc2.35", "cuda_runtime": null, "gpu_count": 1, "gpu_type": "v5litepod", "gpu_memory_per_device": 16909336576, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": true, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.6.6.dev2250+g98e7ae08f", "context": "ENGINE_CONTEXT", "log_time": 1745555635935796992, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 1, "block_size": 16, "gpu_memory_utilization": 0.95, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": true, "enforce_eager": false, "disable_custom_all_reduce": true}

8 v5e

{"uuid": "fe8d226c-db26-453c-8df9-2817582bbe6c", "provider": "GCP", "num_cpu": 224, "cpu_type": "AMD EPYC 7B13", "cpu_family_model_stepping": "25,1,", "total_memory": 405677613056, "architecture": "x86_64", "platform": "Linux-6.5.0-1013-gcp-x86_64-with-glibc2.35", "cuda_runtime": null, "gpu_count": 8, "gpu_type": "v5litepod", "gpu_memory_per_device": 16909336576, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": true, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.6.6.dev2250+g98e7ae08f", "context": "ENGINE_CONTEXT", "log_time": 1745556057888128000, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 8, "block_size": 16, "gpu_memory_utilization": 0.95, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": true, "enforce_eager": false, "disable_custom_all_reduce": true}

mgoin (Member) left a comment:

LGTM, thank you for the iteration. I updated the description and title to describe moving the check to the worker.

mgoin changed the title from "[TPU] Track TPU usages in vLLM's data dashboards" to "[V1] Move usage stats to worker and start logging TPU hardware" on Apr 25, 2025
mgoin (Member) commented Apr 25, 2025:

Tested locally as well

mgoin merged commit 48cb210 into vllm-project:main on Apr 25, 2025
45 checks passed
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Labels
ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1
Projects: None yet

6 participants