
Commit fb4a65c

[Hardware][Intel-Gaudi] Update hpu-extension and update bucketing system for HPU device (vllm-project#17186)

Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent 9f572e1 · commit fb4a65c

5 files changed: +74 −297 lines


requirements/hpu.txt

Lines changed: 1 addition & 1 deletion

@@ -9,4 +9,4 @@ numpy==1.26.4
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@62ad004
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624

vllm/attention/backends/hpu_attn.py

Lines changed: 4 additions & 14 deletions

@@ -4,7 +4,6 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################

-import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type

@@ -200,7 +199,8 @@ def forward(
         key_cache = None
         value_cache = None
         if attn_metadata.is_prompt and self.attn_type \
-                is not AttentionType.ENCODER_ONLY:
+                is not AttentionType.ENCODER_ONLY \
+                and attn_metadata.block_list is None:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None and isinstance(kv_cache, tuple):
@@ -248,24 +248,14 @@ def forward(
             # Decoding run.
             output = HPUPagedAttention.forward_decode(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_list=attn_metadata.block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
-                block_scales=attn_metadata.block_scales,
                 block_groups=attn_metadata.block_groups,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                matmul_av_op=self.matmul_av,
-                batch2block_matmul_op=self.batch2block_matmul,
-                block2batch_matmul_op=self.block2batch_matmul,
-                keys_fetch_func=self.k_cache.fetch_from_cache,
-                values_fetch_func=self.v_cache.fetch_from_cache)
+                **self.common_attention_args(attn_metadata.block_list,
+                                             key_cache, value_cache))
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)

-
     def common_attention_args(self,
                               block_list=None,
                               key_cache=None,
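The call-site change above replaces a long list of per-call keyword arguments with a single `**self.common_attention_args(...)` expansion. Only the helper's signature is visible in the diff, so the following standalone Python sketch is illustrative rather than the vLLM implementation: the class name, the stand-in `forward_decode`, and the exact set of packed keys are assumptions; the real helper also forwards the HPU matmul ops and cache-fetch callables that the diff removes from the call site.

# Illustrative sketch only (not vLLM source): shows the "pack shared kwargs
# into one helper" pattern that the hpu_attn.py call site now uses.
from typing import Any, Dict, Optional


class AttentionImplSketch:
    """Hypothetical stand-in for the HPU attention implementation class."""

    def __init__(self, scale: float):
        self.scale = scale

    def common_attention_args(self,
                              block_list: Optional[Any] = None,
                              key_cache: Optional[Any] = None,
                              value_cache: Optional[Any] = None
                              ) -> Dict[str, Any]:
        # Arguments shared by every attention kernel invocation are built
        # once here; call sites only add per-call metadata.
        return {
            "scale": self.scale,
            "block_list": block_list,
            "key_cache": key_cache,
            "value_cache": value_cache,
        }


def forward_decode(**kwargs: Any) -> Dict[str, Any]:
    # Stand-in for HPUPagedAttention.forward_decode; it simply echoes the
    # keyword arguments it receives.
    return kwargs


impl = AttentionImplSketch(scale=0.125)
# Mirrors the new call site: per-call metadata plus the shared arguments
# expanded from the helper.
output = forward_decode(query="q",
                        block_mapping="mapping",
                        block_bias="bias",
                        block_groups="groups",
                        **impl.common_attention_args("block_list",
                                                     "key_cache",
                                                     "value_cache"))
print(sorted(output))

The design choice is the usual one for this kind of refactor: arguments that never vary between call sites live in one place, so adding or removing a kernel argument (as the block_scales removal below does) touches the helper instead of every caller.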

vllm/attention/ops/hpu_paged_attn.py

Lines changed: 0 additions & 1 deletion

@@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata:
     block_usage: Optional[torch.Tensor]
     block_indices: Optional[torch.Tensor]
     block_offsets: Optional[torch.Tensor]
-    block_scales: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]

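For reference, a minimal sketch of the trimmed metadata container after this commit. The field names come from the hunk above; the real class in vLLM may carry additional fields that this hunk does not show, and the sketch class name is hypothetical.

# Minimal sketch of the metadata container after the block_scales removal;
# only the fields visible in the hunk above are listed.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class HPUPagedAttentionMetadataSketch:
    block_usage: Optional[torch.Tensor]
    block_indices: Optional[torch.Tensor]
    block_offsets: Optional[torch.Tensor]
    # block_scales is removed by this commit; the decode call site in
    # hpu_attn.py above no longer passes it either.
    block_groups: Optional[torch.Tensor]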
