2 files changed in src/transformers/models/gemma3 (+10, -6 lines).
First changed file:

@@ -361,13 +361,15 @@ def forward(
                 )
             else:
                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
+        if attention_mask is not None:
+            # backwards compatibility
+            attention_mask = attention_mask.to(query_states)
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
             key_states,
             value_states,
-            attention_mask.to(query_states),
+            attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
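
The hunk above moves the dtype cast behind a None check: the old inline attention_mask.to(query_states) raises AttributeError whenever the mask is None, which the guarded version tolerates. Below is a minimal standalone sketch of the same pattern in plain PyTorch; the names are illustrative, not the actual Gemma3 attention module.

import torch

def prepare_mask(attention_mask, query_states):
    # Cast the mask to the queries' dtype/device only when a mask is present;
    # calling .to() on None would raise AttributeError.
    if attention_mask is not None:
        attention_mask = attention_mask.to(query_states)
    return attention_mask

q = torch.randn(1, 8, 4, 16, dtype=torch.float16)
mask = torch.zeros(1, 1, 4, 4)      # float32 by default
print(prepare_mask(mask, q).dtype)  # torch.float16, matching the queries
print(prepare_mask(None, q))        # None is passed through, no crash
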
@@ -1360,7 +1362,7 @@ def forward(
             **lm_kwargs,
         )

-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
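
This hunk reads the inner model's output positionally instead of by attribute. A plausible motivation (an assumption here, not stated in the diff) is that the returned object does not always expose a .logits attribute, for example when the submodule returns a plain tuple, while index 0 works for both shapes of output. A tiny sketch with made-up names:

from collections import namedtuple

# Illustrative stand-in for a structured model output; not a transformers class.
CausalOutput = namedtuple("CausalOutput", ["logits", "past_key_values"])

def read_logits(outputs):
    # Positional access works for structured outputs and plain tuples alike;
    # attribute access (.logits) only works for the former.
    return outputs[0]

structured = CausalOutput(logits=[[0.1, 0.9]], past_key_values=None)
as_tuple = ([[0.1, 0.9]], None)
assert read_logits(structured) == read_logits(as_tuple)
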
Second changed file (same two changes at different locations):

@@ -418,13 +418,15 @@ def forward(
                 )
             else:
                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
+        if attention_mask is not None:
+            # backwards compatibility
+            attention_mask = attention_mask.to(query_states)
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
             key_states,
             value_states,
-            attention_mask.to(query_states),
+            attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
@@ -974,7 +976,7 @@ def forward(
             **lm_kwargs,
         )

-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues