
Commit 9a48c5b

yma11 authored and dbyoung18 committed
[XPU][Bugfix] minor fix for XPU (vllm-project#15591)
Signed-off-by: yan ma <[email protected]>
1 parent 93f1056 · commit 9a48c5b

2 files changed: +8 −6 lines


docs/source/getting_started/installation/gpu/xpu.inc.md

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ Currently, there are no pre-built XPU wheels.
 - Second, install Python packages for vLLM XPU backend building:
 
 ```console
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
 pip install --upgrade pip
 pip install -v -r requirements/xpu.txt
 ```
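For reference, after this change the documented XPU build steps read in full:

```console
git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install --upgrade pip
pip install -v -r requirements/xpu.txt
```

The clone and `cd` steps were previously implied but never stated; without them, `requirements/xpu.txt` does not resolve, since it is a relative path inside the repository.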

vllm/attention/backends/ipex_attn.py

Lines changed: 6 additions & 6 deletions
@@ -220,8 +220,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
+                layer._k_scale_float,
+                layer._v_scale_float,
             )
 
         if attn_metadata.is_prompt:
@@ -306,8 +306,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
+                    layer._k_scale_float,
+                    layer._v_scale_float,
                 )
             else:
                 # Run PagedAttention V2.
@@ -339,8 +339,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
+                    layer._k_scale_float,
+                    layer._v_scale_float,
                 )
 
         # Reshape the output tensor.
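The fix swaps the tensor-valued scale attributes for their float counterparts at every IPEX op call site. Below is a minimal sketch of the distinction, assuming (as the naming suggests) that vLLM's attention layer carries the KV-cache scales both as `torch.Tensor` attributes and as plain-float copies; the `DummyLayer` class is illustrative, not vLLM's actual implementation:

```python
import torch


class DummyLayer:
    """Illustrative stand-in for an attention layer holding KV-cache scales."""

    def __init__(self, k_scale: float = 1.0, v_scale: float = 1.0) -> None:
        # Tensor copies, for kernels that take torch.Tensor scale arguments.
        self._k_scale = torch.tensor(k_scale, dtype=torch.float32)
        self._v_scale = torch.tensor(v_scale, dtype=torch.float32)
        # Plain-float copies, for ops whose bindings expect a Python float;
        # the IPEX calls patched above would be in this category.
        self._k_scale_float = k_scale
        self._v_scale_float = v_scale


layer = DummyLayer()
assert isinstance(layer._k_scale, torch.Tensor)
assert isinstance(layer._k_scale_float, float)
```

Passing a tensor where a binding expects a plain float either fails or forces a per-call conversion such as `.item()`; switching to the pre-extracted `_k_scale_float` / `_v_scale_float` sidesteps both.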
