import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
- from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.flex_attention import flex_attention
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
@@ -256,8 +255,6 @@ def lce_forward(
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py#L12
-
-
def flex_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
@@ -279,26 +276,24 @@ def causal_mod(score, b, h, q_idx, kv_idx):
        score = score + causal_mask[b][0][q_idx][kv_idx]
        return score

-     # We only got `attention_mask` tensors, so we recreate `causal_mask` function as specific llama causal attention
-     # TODO: Consider other customized `attention_mask` in the future, e.g., shared prefix
-     def causal_mask_fn(b, h, q_idx, kv_idx):
-         return q_idx >= kv_idx
+     # def causal_mask_fn(b, h, q_idx, kv_idx):
+     #     return q_idx >= kv_idx

-     # To construct block attention mask that leverages sparsity.
-     sparse_causal_mask = create_block_mask(causal_mask_fn, None, None, query.shape[-2], query.shape[-2], device="cuda")
+     # TODO: Construct block attention mask that leverages sparsity
+     # sparse_causal_mask = create_block_mask(
+     #     causal_mask_fn, B=None, H=None, Q_LEN=query.shape[-2], KV_LEN=key.shape[-2], device=query.device, BLOCK_SIZE=1
+     # )

    attn_output, attention_weights = flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
-         block_mask=sparse_causal_mask,
+         # block_mask=sparse_causal_mask,
        enable_gqa=True,
        scale=scaling,
-         # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
-         # For simplification, we thus always return it as no additional computations are introduced.
        return_lse=True,
-         kernel_options={  # different harware might need different configs
+         kernel_options={
            "BLOCK_M": 32,
            "BLOCK_N": 32,
            "BLOCK_M1": 16,
@@ -307,7 +302,7 @@ def causal_mask_fn(b, h, q_idx, kv_idx):
            "BLOCK_N2": 16,
        },
    )
-     # lse is returned in float32
+
    attention_weights = attention_weights.to(value.dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()
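
For reference, a minimal sketch of the block-mask path that the diff above leaves as a TODO, assuming PyTorch >= 2.5 where torch.nn.attention.flex_attention exposes create_block_mask and flex_attention. The shapes, dtype, and CUDA device below are illustrative assumptions, not values taken from this change.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask_fn(b, h, q_idx, kv_idx):
    # Causal masking: a query position may attend only to itself and earlier positions.
    return q_idx >= kv_idx


# Hypothetical shapes, dtype, and device, for illustration only.
B, H, S, D = 1, 8, 128, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# Build the block mask once; flex_attention can then skip fully masked blocks
# instead of evaluating the mask on every (q_idx, kv_idx) pair.
block_mask = create_block_mask(causal_mask_fn, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

attn_output, lse = flex_attention(q, k, v, block_mask=block_mask, return_lse=True)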