fix SDPA compatibility problem #170

Open · wants to merge 1 commit into base: main
11 changes: 11 additions & 0 deletions sageattention/core.py
@@ -122,6 +122,11 @@ def sageattn(
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
- All tensors must be on the same cuda device.
"""
# NOTE: Ensure the sageattn API is compatible with the official SDPA.
# For example, in SDPA, we pass 'scale' as the softmax scaling factor.
if sm_scale is None and kwargs.get("scale", None) is not None:
    assert isinstance(kwargs["scale"], float), "The scale must be a float."
    sm_scale = kwargs["scale"]

arch = get_cuda_arch_versions()[q.device.index]
if arch == "sm80":
@@ -352,6 +357,12 @@ def sageattn_varlen(
assert q.device == k.device == v.device, "All tensors must be on the same device."
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

# NOTE: Ensure the sageattn API is compatible with the official SDPA.
# For example, in SDPA, we pass 'scale' as the softmax scaling factor.
if sm_scale is None and kwargs.get("scale", None) is not None:
    assert isinstance(kwargs["scale"], float), "The scale must be a float."
    sm_scale = kwargs["scale"]

# FIXME(DefTruth): make sage attention compatible with distributed
# environments, for example, xDiT, which is launched by torchrun. Without this
# workaround, sage attention will run into an illegal memory access error after the first
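For context, here is a minimal usage sketch (not part of this PR) of what the forwarded `scale` keyword enables. It assumes the sageattn signature accepts extra keyword arguments as implied by the diff above, that tensor_layout="HND" matches SDPA's (batch, heads, seq_len, head_dim) convention, and a CUDA device with a PyTorch version whose SDPA supports the `scale` argument.

# Sketch only: names outside this diff (tensor_layout, is_causal) are taken
# from the SageAttention README and may differ in your version.
import torch
import torch.nn.functional as F
from sageattention import sageattn

q = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
scale = q.size(-1) ** -0.5  # SDPA's default softmax scaling factor, passed explicitly here

# Reference: official SDPA, where `scale` is the softmax scaling factor.
out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=False, scale=scale)

# With this patch, the same `scale` keyword is forwarded to `sm_scale`,
# so sageattn can be called with SDPA-style arguments.
out = sageattn(q, k, v, tensor_layout="HND", is_causal=False, scale=scale)

Since SageAttention quantizes internally, comparing out against out_ref should use a loose tolerance; the point of the patch is only that the `scale` keyword is honored the same way SDPA honors it.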