This repository was archived by the owner on Apr 24, 2025. It is now read-only.

Commit 75bc3c9

Fix SDPA in case attn_mask == None (#78)
1 parent 736afb5 commit 75bc3c9

File tree: 4 files changed (+65, -9 lines)


intel_npu_acceleration_library/backend/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,7 @@
 from .qlinear import QLinear
 from .tensor import Tensor
 from .factory import NNFactory
-from .sdpa import SDPA
+from .sdpa import SDPA, SimpleSDPA
 from .runtime import run_matmul, run_factory, clear_cache
 
 check_npu_and_driver_version()
@@ -27,6 +27,7 @@
     "QLinear",
     "Convolution",
     "SDPA",
+    "SimpleSDPA",
     "run_matmul",
     "run_factory",
     "clear_cache",

intel_npu_acceleration_library/backend/sdpa.py

Lines changed: 47 additions & 0 deletions
@@ -58,3 +58,50 @@ def run(
             np.ndarray: result
         """
         return super().run(query, key, value, mask)
+
+
+class SimpleSDPA(NNFactory):
+    """Implementation of a ScaledDotProductAttention NPU operation."""
+
+    def __init__(
+        self,
+        query_shapes: Tuple[int, int],
+        key_shapes: Tuple[int, int],
+        value_shapes: Tuple[int, int],
+        is_causal: bool = False,
+        profile: bool = False,
+        device: str = "NPU",
+    ):
+        """Initialize the SDPA.
+
+        Args:
+            query_shapes (Tuple[int, int]): shape of the query tensor
+            key_shapes (Tuple[int, int]): shape of the key tensor
+            value_shapes (Tuple[int, int]): shape of the value tensor
+            is_causal (bool, optional): If the SDPA mask is is_causal or not. Defaults to False.
+            profile (bool, optional): Enable/Disable profiling. Defaults to False.
+            device (str, optional): Target device, default to "NPU".
+        """
+        super().__init__(profile, device)
+
+        self.query = self.parameter(query_shapes)
+        self.key = self.parameter(key_shapes)
+        self.value = self.parameter(value_shapes)
+
+        _ = self.scaled_dot_product_attention_simple(  # type: ignore[attr-defined]
+            self.query, self.key, self.value, is_causal
+        )
+        self.compile()
+
+    def run(self, query: np.ndarray, key: np.ndarray, value: np.ndarray) -> np.ndarray:
+        """Run the scaled dot product attention kernel.
+
+        Args:
+            query (np.ndarray): sdpa query tensor
+            key (np.ndarray): sdpa key tensor
+            value (np.ndarray): sdpa value tensor
+
+        Returns:
+            np.ndarray: result
+        """
+        return super().run(query, key, value)
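
SimpleSDPA mirrors the existing SDPA backend class but drops the attention-mask input: it declares only query, key, and value parameters and builds its graph with scaled_dot_product_attention_simple. A minimal usage sketch follows; it assumes direct construction works as shown in the __init__ above, and the 4-D (batch, heads, sequence, head_dim) shape convention and float16 inputs are borrowed from the test file below rather than stated in this diff.

import numpy as np
from intel_npu_acceleration_library.backend import SimpleSDPA

# Illustrative shapes (batch=1, heads=8, sequence=64, head_dim=64); the annotation
# says Tuple[int, int], but the test constructs 4-D tensors of this form.
shape = (1, 8, 64, 64)

# Build and compile a mask-free SDPA kernel; causal masking is handled by the op itself.
sdpa = SimpleSDPA(shape, shape, shape, is_causal=True)

query = np.random.rand(*shape).astype(np.float16)
key = np.random.rand(*shape).astype(np.float16)
value = np.random.rand(*shape).astype(np.float16)

# run() takes only query/key/value -- there is no attention-mask argument.
result = sdpa.run(query, key, value)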

intel_npu_acceleration_library/functional/scaled_dot_product_attention.py

Lines changed: 7 additions & 3 deletions
@@ -2,7 +2,7 @@
 # Copyright © 2024 Intel Corporation
 # SPDX-License-Identifier: Apache 2.0
 #
-from intel_npu_acceleration_library.backend import run_factory, SDPA
+from intel_npu_acceleration_library.backend import run_factory, SDPA, SimpleSDPA
 from typing import Optional
 from functools import partial
 import torch
@@ -34,10 +34,14 @@ def scaled_dot_product_attention(
     Returns:
         torch.Tensor: _description_
     """
-    backend_cls = partial(SDPA, is_causal=is_causal)
     if dropout_p != 0:
         raise RuntimeError("dropout_p != 0 is not supported yet")
     if scale is not None:
         raise RuntimeError("scale != 0 is not supported yet")
 
-    return run_factory([query, key, value, attn_mask], [], backend_cls)
+    if attn_mask is None:
+        backend_cls = partial(SimpleSDPA, is_causal=is_causal)  # type: ignore
+        return run_factory([query, key, value], [], backend_cls)
+    else:
+        backend_cls = partial(SDPA, is_causal=is_causal)  # type: ignore
+        return run_factory([query, key, value, attn_mask], [], backend_cls)
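
After this change, the functional entry point picks the backend from the mask argument: attn_mask=None routes to the three-input SimpleSDPA kernel, while an explicit mask keeps the original four-input SDPA path. A hedged calling sketch; the tensor sizes are illustrative, and it assumes the functional package re-exports this function (otherwise import it from the scaled_dot_product_attention module directly).

import torch
from intel_npu_acceleration_library.functional import scaled_dot_product_attention

heads, sequence, dim = 8, 64, 512
query = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
key = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
value = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)

# No mask: dispatches to the new SimpleSDPA backend (three inputs).
out_no_mask = scaled_dot_product_attention(query, key, value, attn_mask=None)

# Additive mask: keeps the original SDPA backend (four inputs).
min_value = torch.finfo(torch.float16).min
mask = torch.triu(min_value * torch.ones(1, heads, sequence, sequence).to(torch.float16))
out_masked = scaled_dot_product_attention(query, key, value, attn_mask=mask)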

test/python/test_sdpa.py

Lines changed: 9 additions & 5 deletions
@@ -59,7 +59,8 @@ def test_sdpa(heads, sequence, dim, kv_cache, is_causal):
 @pytest.mark.parametrize("dim", [512, 1024])
 @pytest.mark.parametrize("kv_cache", [True, False])
 @pytest.mark.parametrize("is_causal", [False, True])
-def test_sdpa_runtime(heads, sequence, dim, kv_cache, is_causal):
+@pytest.mark.parametrize("use_mask", [False, True])
+def test_sdpa_runtime(heads, sequence, dim, kv_cache, is_causal, use_mask):
 
     min_value = torch.finfo(torch.float16).min
 
@@ -68,10 +69,13 @@ def test_sdpa_runtime(heads, sequence, dim, kv_cache, is_causal):
     )
     key = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
     value = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
-    mask = min_value * torch.ones(1, heads, 1 if kv_cache else sequence, sequence).to(
-        torch.float16
-    )
-    mask = torch.triu(mask)
+    if use_mask:
+        mask = min_value * torch.ones(
+            1, heads, 1 if kv_cache else sequence, sequence
+        ).to(torch.float16)
+        mask = torch.triu(mask)
+    else:
+        mask = None
 
     npu_result = scaled_dot_product_attention(
         query, key, value, mask, is_causal=is_causal
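
For reference, the additive mask this test builds is zero below the main diagonal and holds the most negative float16 value on and above it, so (assuming the backend adds the mask to the attention scores, as PyTorch's SDPA does) those positions receive essentially zero weight after softmax. A small standalone sketch of the same construction, with sizes shrunk for readability:

import torch

min_value = torch.finfo(torch.float16).min
# 4x4 scores for a single head: triu keeps the diagonal and above, zeroes the rest.
mask = torch.triu(min_value * torch.ones(1, 1, 4, 4).to(torch.float16))
print(mask)  # min_value entries suppress those positions; zeros leave them unmasked.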

0 commit comments

Comments
 (0)