@@ -243,6 +243,7 @@ defmodule Bumblebee.Text.Generation do
       prepare_inputs_fun = fn inputs, params ->
         encoder_outputs = encoder_predict_fun.(params, inputs)

+        padded_batch_item? = padded_batch_item?(encoder_input(inputs))
         batch_size = Nx.axis_size(encoder_input(inputs), 0)

         inputs = Map.put(inputs, "encoder_hidden_state", encoder_outputs.hidden_state)
@@ -254,18 +255,19 @@ defmodule Bumblebee.Text.Generation do

         max_length = max_length_fun.(1)
         inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length)
-        {inputs, inputs["decoder_input_ids"], max_length}
+        {inputs, inputs["decoder_input_ids"], padded_batch_item?, max_length}
       end

       update_inputs_fun = &update_decoder_inputs("decoder_", &1, &2, &3)

       {prepare_inputs_fun, update_inputs_fun}
     else
       prepare_inputs_fun = fn inputs, _params ->
+        padded_batch_item? = padded_batch_item?(inputs["input_ids"])
         sequence_length = Nx.axis_size(inputs["input_ids"], 1)
         max_length = max_length_fun.(sequence_length)
         inputs = prepare_decoder_inputs(inputs, "", spec, model, max_length)
-        {inputs, inputs["input_ids"], max_length}
+        {inputs, inputs["input_ids"], padded_batch_item?, max_length}
       end

       update_inputs_fun = &update_decoder_inputs("", &1, &2, &3)
@@ -283,6 +285,13 @@ defmodule Bumblebee.Text.Generation do
     inputs["input_ids"] || inputs["input_features"] || inputs["pixel_values"]
   end

+  defp padded_batch_item?(input) do
+    [_ | non_batch_axes] = Nx.axes(input)
+    # We check each batch item if it is full of zeros, in which
+    # case we assume it's padding, not an actual input.
+    input |> Nx.equal(0) |> Nx.all(axes: non_batch_axes)
+  end
+
   defp prepare_decoder_inputs(inputs, prefix, spec, model, max_length) do
     input_ids = inputs[prefix <> "input_ids"]
     attention_mask = inputs[prefix <> "attention_mask"] || Nx.broadcast(1, input_ids)
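
A minimal eager-mode sketch of what `padded_batch_item?` computes, with hypothetical token values (not taken from this diff):

    # A batch of two sequences: a real one and an all-zero (padded) one.
    input_ids = Nx.tensor([[101, 2023, 102], [0, 0, 0]])

    # Reduce over every axis except the batch axis.
    [_ | non_batch_axes] = Nx.axes(input_ids)

    input_ids
    |> Nx.equal(0)
    |> Nx.all(axes: non_batch_axes)
    #=> u8[2] [0, 1], i.e. only the second batch item is flagged as padding
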
@@ -396,7 +405,8 @@ defmodule Bumblebee.Text.Generation do
       ) do
     {seed, inputs} = pop_seed(inputs)

-    {decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)
+    {decoder_inputs, decoder_input_ids, padded_batch_item?, max_length} =
+      prepare_inputs_fun.(inputs, params)

     length = Nx.axis_size(decoder_input_ids, 1)

@@ -414,6 +424,7 @@ defmodule Bumblebee.Text.Generation do
         greedy(
           decoder_inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           logits_processor_fun,
@@ -425,6 +436,7 @@ defmodule Bumblebee.Text.Generation do
         contrastive(
           decoder_inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           logits_processor_fun,
@@ -440,6 +452,7 @@ defmodule Bumblebee.Text.Generation do
         sampling(
           decoder_inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           seed,
@@ -469,6 +482,7 @@ defmodule Bumblebee.Text.Generation do
   defnp greedy(
           inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           logits_processor_fun,
@@ -479,7 +493,7 @@ defmodule Bumblebee.Text.Generation do
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]

-    state = init_sequences(decoder_input_ids, max_length, pad_token_id)
+    state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)

     # The loop works with inputs of length 1, so if the initial input
     # is longer, we make the initial pass outside
@@ -519,15 +533,17 @@ defmodule Bumblebee.Text.Generation do
     state
   end

-  defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do
+  defnp init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) do
     {batch_size, length} = Nx.shape(decoder_input_ids)

     sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
     sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)

     # For each sequence, we keep track of its final length, where 0
-    # means that it has not been finished yet
-    finished_length = Nx.broadcast(0, {batch_size})
+    # means that it has not been finished yet. If there are padding
+    # batch inputs, we immediately mark them as finished, otherwise
+    # they could produce arbitrary tokens until we reach max length.
+    finished_length = Nx.select(padded_batch_item?, 1, 0)

     %{
       sequences: sequences,
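
A rough eager-mode equivalent of the updated `init_sequences` setup, again with made-up values, showing why a padded item never produces real output:

    pad_token_id = 0
    max_length = 5
    decoder_input_ids = Nx.tensor([[1], [0]])
    padded_batch_item? = Nx.tensor([0, 1], type: :u8)

    # Sequences start as all pad tokens, with the prompt copied in.
    sequences =
      pad_token_id
      |> Nx.broadcast({2, max_length})
      |> Nx.put_slice([0, 0], decoder_input_ids)

    # Real items start at 0 ("not finished yet"); the padded item is
    # marked finished at length 1, so its row stays all pad tokens.
    finished_length = Nx.select(padded_batch_item?, 1, 0)
    #=> [0, 1]
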
@@ -631,6 +647,7 @@ defmodule Bumblebee.Text.Generation do
   defnp contrastive(
           inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           logits_processor_fun,
@@ -644,7 +661,7 @@ defmodule Bumblebee.Text.Generation do
     top_k = opts[:top_k]
     penalty_alpha = opts[:penalty_alpha]

-    state = init_sequences(decoder_input_ids, max_length, pad_token_id)
+    state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)

     # Step (1)
     # Initial pass to obtain hidden state and expand inputs to top-k
@@ -796,6 +813,7 @@ defmodule Bumblebee.Text.Generation do
   defnp sampling(
           inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           seed,
@@ -807,7 +825,7 @@ defmodule Bumblebee.Text.Generation do
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]

-    state = init_sequences(decoder_input_ids, max_length, pad_token_id)
+    state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)

     prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()
