@@ -4,7 +4,6 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################

-import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type

@@ -200,7 +199,8 @@ def forward(
         key_cache = None
         value_cache = None
         if attn_metadata.is_prompt and self.attn_type \
-            is not AttentionType.ENCODER_ONLY:
+            is not AttentionType.ENCODER_ONLY \
+            and attn_metadata.block_list is None:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None and isinstance(kv_cache, tuple):
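With the added condition, the unflatten reshape in this hunk runs only when attn_metadata.block_list is None. As a reference for what that reshape does, here is a minimal, self-contained torch sketch with made-up shapes (the real tensors come from the paged KV cache):

import torch

# Toy shapes for illustration only: 8 tokens, 4 heads, head_dim 16.
key = torch.randn(8, 4, 16)
# Pretend two cache blocks are being written, so block_size becomes 4.
block_indices = torch.tensor([0, 1])

# unflatten(0, (num_blocks, -1)) regroups the flat token dimension into
# (num_blocks, block_size, ...) to match the block layout of the cache.
key_blocked = key.unflatten(0, (block_indices.size(0), -1))
print(key_blocked.shape)  # torch.Size([2, 4, 4, 16])
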
@@ -248,24 +248,14 @@ def forward(
             # Decoding run.
             output = HPUPagedAttention.forward_decode(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_list=attn_metadata.block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
-                block_scales=attn_metadata.block_scales,
                 block_groups=attn_metadata.block_groups,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                matmul_av_op=self.matmul_av,
-                batch2block_matmul_op=self.batch2block_matmul,
-                block2batch_matmul_op=self.block2batch_matmul,
-                keys_fetch_func=self.k_cache.fetch_from_cache,
-                values_fetch_func=self.v_cache.fetch_from_cache)
+                **self.common_attention_args(attn_metadata.block_list,
+                                             key_cache, value_cache))
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)

-
     def common_attention_args(self,
                               block_list=None,
                               key_cache=None,
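The diff is truncated inside the new common_attention_args helper, whose result the decode call above now splats with **. A plausible sketch of the full helper, inferred from the keyword arguments deleted from forward_decode (everything past the three visible parameters is an assumption, not the committed code):

    def common_attention_args(self,
                              block_list=None,
                              key_cache=None,
                              value_cache=None):
        # Inferred sketch: bundle the arguments shared by the prompt and
        # decode attention paths so call sites can expand them with **.
        return {
            'scale': self.scale,
            'matmul_qk_op': self.matmul_qk,
            'matmul_av_op': self.matmul_av,
            'batch2block_matmul_op': self.batch2block_matmul,
            'block2batch_matmul_op': self.block2batch_matmul,
            'keys_fetch_func': self.k_cache.fetch_from_cache,
            'values_fetch_func': self.v_cache.fetch_from_cache,
            'block_list': block_list,
            'key_cache': key_cache,
            'value_cache': value_cache,
        }

Centralizing the shared kwargs this way keeps the prompt and decode call sites in sync: an argument added or dropped (block_scales, for example, is simply removed from the decode call) changes in one place instead of at every forward_* call.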