Skip to content

Commit 7617ff5

Browse files
Fix padded batch items lengthening generation time (#419)
1 parent dd59194 commit 7617ff5

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

lib/bumblebee/text/generation.ex

Lines changed: 27 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -243,6 +243,7 @@ defmodule Bumblebee.Text.Generation do
243243
prepare_inputs_fun = fn inputs, params ->
244244
encoder_outputs = encoder_predict_fun.(params, inputs)
245245

246+
padded_batch_item? = padded_batch_item?(encoder_input(inputs))
246247
batch_size = Nx.axis_size(encoder_input(inputs), 0)
247248

248249
inputs = Map.put(inputs, "encoder_hidden_state", encoder_outputs.hidden_state)
@@ -254,18 +255,19 @@ defmodule Bumblebee.Text.Generation do
254255

255256
max_length = max_length_fun.(1)
256257
inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length)
257-
{inputs, inputs["decoder_input_ids"], max_length}
258+
{inputs, inputs["decoder_input_ids"], padded_batch_item?, max_length}
258259
end
259260

260261
update_inputs_fun = &update_decoder_inputs("decoder_", &1, &2, &3)
261262

262263
{prepare_inputs_fun, update_inputs_fun}
263264
else
264265
prepare_inputs_fun = fn inputs, _params ->
266+
padded_batch_item? = padded_batch_item?(inputs["input_ids"])
265267
sequence_length = Nx.axis_size(inputs["input_ids"], 1)
266268
max_length = max_length_fun.(sequence_length)
267269
inputs = prepare_decoder_inputs(inputs, "", spec, model, max_length)
268-
{inputs, inputs["input_ids"], max_length}
270+
{inputs, inputs["input_ids"], padded_batch_item?, max_length}
269271
end
270272

271273
update_inputs_fun = &update_decoder_inputs("", &1, &2, &3)
@@ -283,6 +285,13 @@ defmodule Bumblebee.Text.Generation do
283285
inputs["input_ids"] || inputs["input_features"] || inputs["pixel_values"]
284286
end
285287

288+
defp padded_batch_item?(input) do
289+
[_ | non_batch_axes] = Nx.axes(input)
290+
# We check whether each batch item is full of zeros, in which
291+
# case we assume it's padding, not an actual input.
292+
input |> Nx.equal(0) |> Nx.all(axes: non_batch_axes)
293+
end
294+
286295
defp prepare_decoder_inputs(inputs, prefix, spec, model, max_length) do
287296
input_ids = inputs[prefix <> "input_ids"]
288297
attention_mask = inputs[prefix <> "attention_mask"] || Nx.broadcast(1, input_ids)
@@ -396,7 +405,8 @@ defmodule Bumblebee.Text.Generation do
396405
) do
397406
{seed, inputs} = pop_seed(inputs)
398407

399-
{decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)
408+
{decoder_inputs, decoder_input_ids, padded_batch_item?, max_length} =
409+
prepare_inputs_fun.(inputs, params)
400410

401411
length = Nx.axis_size(decoder_input_ids, 1)
402412

@@ -414,6 +424,7 @@ defmodule Bumblebee.Text.Generation do
414424
greedy(
415425
decoder_inputs,
416426
decoder_input_ids,
427+
padded_batch_item?,
417428
predict_fun,
418429
params,
419430
logits_processor_fun,
@@ -425,6 +436,7 @@ defmodule Bumblebee.Text.Generation do
425436
contrastive(
426437
decoder_inputs,
427438
decoder_input_ids,
439+
padded_batch_item?,
428440
predict_fun,
429441
params,
430442
logits_processor_fun,
@@ -440,6 +452,7 @@ defmodule Bumblebee.Text.Generation do
440452
sampling(
441453
decoder_inputs,
442454
decoder_input_ids,
455+
padded_batch_item?,
443456
predict_fun,
444457
params,
445458
seed,
@@ -469,6 +482,7 @@ defmodule Bumblebee.Text.Generation do
469482
defnp greedy(
470483
inputs,
471484
decoder_input_ids,
485+
padded_batch_item?,
472486
predict_fun,
473487
params,
474488
logits_processor_fun,
@@ -479,7 +493,7 @@ defmodule Bumblebee.Text.Generation do
479493
pad_token_id = opts[:pad_token_id]
480494
eos_token_id = opts[:eos_token_id]
481495

482-
state = init_sequences(decoder_input_ids, max_length, pad_token_id)
496+
state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)
483497

484498
# The loop works with inputs of length 1, so if the initial input
485499
# is longer, we make the initial pass outside
@@ -519,15 +533,17 @@ defmodule Bumblebee.Text.Generation do
519533
state
520534
end
521535

522-
defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do
536+
defnp init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) do
523537
{batch_size, length} = Nx.shape(decoder_input_ids)
524538

525539
sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
526540
sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)
527541

528542
# For each sequence, we keep track of its final length, where 0
529-
# means that it has not been finished yet
530-
finished_length = Nx.broadcast(0, {batch_size})
543+
# means that it has not been finished yet. If there are padding
544+
# batch inputs, we immediately mark them as finished, otherwise
545+
# they could produce arbitrary tokens until we reach max length.
546+
finished_length = Nx.select(padded_batch_item?, 1, 0)
531547

532548
%{
533549
sequences: sequences,
@@ -631,6 +647,7 @@ defmodule Bumblebee.Text.Generation do
631647
defnp contrastive(
632648
inputs,
633649
decoder_input_ids,
650+
padded_batch_item?,
634651
predict_fun,
635652
params,
636653
logits_processor_fun,
@@ -644,7 +661,7 @@ defmodule Bumblebee.Text.Generation do
644661
top_k = opts[:top_k]
645662
penalty_alpha = opts[:penalty_alpha]
646663

647-
state = init_sequences(decoder_input_ids, max_length, pad_token_id)
664+
state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)
648665

649666
# Step (1)
650667
# Initial pass to obtain hidden state and expand inputs to top-k
@@ -796,6 +813,7 @@ defmodule Bumblebee.Text.Generation do
796813
defnp sampling(
797814
inputs,
798815
decoder_input_ids,
816+
padded_batch_item?,
799817
predict_fun,
800818
params,
801819
seed,
@@ -807,7 +825,7 @@ defmodule Bumblebee.Text.Generation do
807825
pad_token_id = opts[:pad_token_id]
808826
eos_token_id = opts[:eos_token_id]
809827

810-
state = init_sequences(decoder_input_ids, max_length, pad_token_id)
828+
state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)
811829

812830
prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()
813831

0 commit comments

Comments
 (0)