
Add support for more NVIDIA devices #160


Open
wants to merge 1 commit into main

Conversation

johnnynunez

@johnnynunez johnnynunez commented Apr 22, 2025

Support (CUDA compute capabilities):

- Jetson Orin: 8.7
- Jetson Thor: 10.1
- Blackwell B100/B200/GB200: 10.0
- Spark: 11.0
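
For reference, a minimal sketch (not part of this commit) of how these compute capabilities would map to the "smXY" arch strings used for dispatch in core.py, assuming the standard torch.cuda.get_device_capability() API:

    import torch

    def detect_arch(device: int = 0) -> str:
        # Map the compute capability reported by PyTorch to an "smXY" string,
        # e.g. (8, 7) -> "sm87" for Jetson Orin and (10, 0) -> "sm100" for B100/B200/GB200.
        major, minor = torch.cuda.get_device_capability(device)
        return f"sm{major}{minor}"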

@woct0rdho

Did you test which implementation each arch/device should be routed to in sageattn in core.py?

Maybe we should eventually implement something like autotune to do this.
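
As a rough sketch of what such an autotune could look like (illustrative only, not existing SageAttention API: the function name, cache key, and timing loop are all assumptions), each candidate kernel could be timed once per device/shape and the winner cached:

    import time

    import torch

    _AUTOTUNE_CACHE = {}  # (compute capability, head_dim) -> fastest callable

    def autotune_sageattn(q, k, v, candidates, warmup=3, iters=10, **kwargs):
        """Time each candidate kernel once for this device/shape, cache the
        winner, and dispatch to it. `candidates` maps a name to a callable
        with the same signature as sageattn (hypothetical helper)."""
        key = (torch.cuda.get_device_capability(q.device), q.shape[-1])
        if key not in _AUTOTUNE_CACHE:
            best_fn, best_ms = None, float("inf")
            for name, fn in candidates.items():
                try:
                    for _ in range(warmup):
                        fn(q, k, v, **kwargs)
                    torch.cuda.synchronize()
                    t0 = time.perf_counter()
                    for _ in range(iters):
                        fn(q, k, v, **kwargs)
                    torch.cuda.synchronize()
                    ms = (time.perf_counter() - t0) * 1000 / iters
                except Exception:
                    continue  # this kernel is not usable on this arch/shape
                if ms < best_ms:
                    best_fn, best_ms = fn, ms
            if best_fn is None:
                raise RuntimeError("no candidate kernel runs on this device")
            _AUTOTUNE_CACHE[key] = best_fn
        return _AUTOTUNE_CACHE[key](q, k, v, **kwargs)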

@johnnynunez
Author

johnnynunez commented Apr 22, 2025

> Did you test which implementation each arch/device should be routed to in sageattn in core.py?
>
> Maybe we should eventually implement something like autotune to do this.

I'm testing with Ada, Hopper (GH200), Jetson Orin, RTX 5090, and GB200.

@johnnynunez
Author

For Blackwell it's the same as the RTX 50 series with Triton 3.3.x.

@pftq

pftq commented May 17, 2025

Isn't there more needed to handle the B200? The commit seems to only get past the setup process. For example, the sm100 case for the B200 is not handled in core.py (it skips straight to sm120).

Line 135 in core.py

    elif arch == "sm90":
        return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
    elif arch == "sm120":
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.

Otherwise it throws this error:

  File "/workspace/ComfyUI/venv/lib/python3.11/site-packages/sageattention/core.py", line 138, in sageattn
    raise ValueError(f"Unsupported CUDA architecture: {arch}")
ValueError: Unsupported CUDA architecture: sm100

Copying one of the other cases doesn't seem to be enough:

    elif arch == "sm100":
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32")

Still results in:

  File "/workspace/ComfyUI/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 857, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/ComfyUI/venv/lib/python3.11/site-packages/sageattention/core.py", line 722, in sageattn_qk_int8_pv_fp8_cuda
    o = torch.empty(q.size(), dtype=dtype, device=q.device)
RuntimeError: CUDA error: no kernel image is available for execution on the device
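
The "no kernel image is available" error usually means the compiled extension was not built with gencode flags for this device's compute capability, so adding an sm100 branch in core.py alone won't help until the kernels are actually compiled for sm_100. A quick sanity check (this only inspects PyTorch's own build; the SageAttention .so gets its arch list from the flags used at its own build time):

    import torch

    # Confirm what the device reports and which arches PyTorch was built for.
    print("device capability:", torch.cuda.get_device_capability(0))  # expect (10, 0) on a B200
    print("PyTorch arch list:", torch.cuda.get_arch_list())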

@woct0rdho

woct0rdho commented May 17, 2025

@pftq OK, let's do it. Since you have a B200, you can test which of the implementations is fastest (maybe using the scripts at https://github.com/thu-ml/SageAttention/tree/main/bench) and route to it in the sageattn function.
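
Something like the following could serve as a quick hand-rolled comparison on the B200 (a sketch, independent of the bench scripts above; it assumes the default tensor_layout and argument defaults in core.py, and the fp16/Triton variants can be appended to the candidate list the same way):

    import torch
    from sageattention.core import (
        sageattn_qk_int8_pv_fp8_cuda,
        sageattn_qk_int8_pv_fp8_cuda_sm90,
    )

    def bench(fn, q, k, v, iters=50):
        # Warm up, then time with CUDA events so only GPU work is measured.
        for _ in range(5):
            fn(q, k, v, is_causal=False)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn(q, k, v, is_causal=False)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # average ms per call

    # (batch, heads, seq_len, head_dim) in the default "HND" layout
    q = torch.randn(2, 32, 4096, 128, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    for name, fn in [
        ("fp8_cuda", sageattn_qk_int8_pv_fp8_cuda),
        ("fp8_cuda_sm90", sageattn_qk_int8_pv_fp8_cuda_sm90),
    ]:
        try:
            print(name, f"{bench(fn, q, k, v):.3f} ms")
        except Exception as e:
            print(name, "failed:", e)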

@pftq

pftq commented May 17, 2025

See the bottom of my earlier reply: right now I'm just getting a CUDA error, so I'd need to get past that first.
