Commit 585c765

fix logits tests
Signed-off-by: Austin Liu <[email protected]>
1 parent 8791f16 · commit 585c765

File tree

2 files changed: +168 -168 lines changed

src/liger_kernel/transformers/model/llama.py

Lines changed: 9 additions & 14 deletions
@@ -8,7 +8,6 @@
 import torch.nn.functional as F
 
 from torch.nn import CrossEntropyLoss
-from torch.nn.attention.flex_attention import create_block_mask
 from torch.nn.attention.flex_attention import flex_attention
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
@@ -256,8 +255,6 @@ def lce_forward(
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py#L12
-
-
 def flex_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
@@ -279,26 +276,24 @@ def causal_mod(score, b, h, q_idx, kv_idx):
             score = score + causal_mask[b][0][q_idx][kv_idx]
         return score
 
-    # We only got `attention_mask` tensors, so we recreate `causal_mask` function as specific llama causal attention
-    # TODO: Consider other customized `attention_mask` in the future, e.g., shared prefix
-    def causal_mask_fn(b, h, q_idx, kv_idx):
-        return q_idx >= kv_idx
+    # def causal_mask_fn(b, h, q_idx, kv_idx):
+    #     return q_idx >= kv_idx
 
-    # To construct block attention mask that leverages sparsity.
-    sparse_causal_mask = create_block_mask(causal_mask_fn, None, None, query.shape[-2], query.shape[-2], device="cuda")
+    # TODO: Construct block attention mask that leverages sparsity
+    # sparse_causal_mask = create_block_mask(
+    #     causal_mask_fn, B=None, H=None, Q_LEN=query.shape[-2], KV_LEN=key.shape[-2], device=query.device, BLOCK_SIZE=1
+    # )
 
     attn_output, attention_weights = flex_attention(
         query,
         key,
         value,
         score_mod=causal_mod,
-        block_mask=sparse_causal_mask,
+        # block_mask=sparse_causal_mask,
         enable_gqa=True,
         scale=scaling,
-        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
-        # For simplification, we thus always return it as no additional computations are introduced.
         return_lse=True,
-        kernel_options={  # different harware might need different configs
+        kernel_options={
             "BLOCK_M": 32,
             "BLOCK_N": 32,
             "BLOCK_M1": 16,
@@ -307,7 +302,7 @@ def causal_mask_fn(b, h, q_idx, kv_idx):
             "BLOCK_N2": 16,
         },
     )
-    # lse is returned in float32
+
     attention_weights = attention_weights.to(value.dtype)
     attn_output = attn_output.transpose(1, 2).contiguous()
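
For readers unfamiliar with the FlexAttention API exercised in this diff, below is a minimal, self-contained sketch of the same call pattern: an additive causal bias applied through `score_mod` (as in `causal_mod` above) plus the block-sparse mask that this commit comments out. It assumes PyTorch >= 2.5 with a CUDA device; the sizes, the `causal_bias` tensor, and all variable names are illustrative placeholders, not values taken from the Liger kernel code.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative sizes: batch, heads, sequence length, head dim (placeholders).
B, H, S, D = 1, 8, 128, 64
device = "cuda"  # FlexAttention is exercised on GPU in these tests

query = torch.randn(B, H, S, D, device=device)
key = torch.randn(B, H, S, D, device=device)
value = torch.randn(B, H, S, D, device=device)

# Additive bias of shape (B, 1, S, S): 0 where attention is allowed, -inf above
# the diagonal, mimicking the `causal_mask` tensor indexed inside `causal_mod`.
causal_bias = torch.full((B, 1, S, S), float("-inf"), device=device).triu(1)

def causal_mod(score, b, h, q_idx, kv_idx):
    # score_mod: add the precomputed bias for this (batch, query, key) position.
    return score + causal_bias[b][0][q_idx][kv_idx]

def causal_mask_fn(b, h, q_idx, kv_idx):
    # mask_mod: keep only keys at or before the query position.
    return q_idx >= kv_idx

# Block mask that lets the kernel skip fully masked tiles; this is the piece the
# commit disables by commenting out `create_block_mask` / `block_mask=...`.
block_mask = create_block_mask(causal_mask_fn, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device)

attn_output, lse = flex_attention(
    query,
    key,
    value,
    score_mod=causal_mod,
    block_mask=block_mask,
    return_lse=True,  # lse is returned in float32
)
print(attn_output.shape, lse.shape)  # torch.Size([1, 8, 128, 64]) torch.Size([1, 8, 128])

As in the context lines of the diff, the output comes back in (B, H, S, D) layout, which is why `flex_attention_forward` transposes it to (B, S, H, D) and casts the returned lse to the value dtype afterwards.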