Skip to content

Commit 4e0e178

Browse files
Transfer serving computation result to binary backend upfront (#282)
1 parent 57bdcce commit 4e0e178

15 files changed

+32
-35
lines changed

lib/bumblebee/audio/speech_to_text_whisper.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
6767

6868
fn inputs ->
6969
inputs = Shared.maybe_pad(inputs, batch_size)
70-
generate_fun.(params, inputs)
70+
generate_fun.(params, inputs) |> Shared.serving_post_computation()
7171
end
7272
end,
7373
defn_options

lib/bumblebee/diffusion/stable_diffusion.ex

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,9 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
284284
%{image: image}
285285
end
286286

287-
Bumblebee.Utils.Nx.composite_unflatten_batch(output, inputs.size)
287+
output
288+
|> Bumblebee.Utils.Nx.composite_unflatten_batch(inputs.size)
289+
|> Shared.serving_post_computation()
288290
end
289291
end
290292

@@ -318,9 +320,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
318320
end
319321

320322
defp client_postprocessing({outputs, _metadata}, multi?, safety_checker?) do
321-
# We use binary backend so we are not blocked by the serving computation
322-
outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend)
323-
324323
for outputs <- Bumblebee.Utils.Nx.batch_to_list(outputs) do
325324
results =
326325
for outputs = %{image: image} <- Bumblebee.Utils.Nx.batch_to_list(outputs) do

lib/bumblebee/shared.ex

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,19 @@ defmodule Bumblebee.Shared do
276276
Nx.Batch.pad(batch, batch_size - size)
277277
end
278278

279+
@doc """
280+
Shared logic applied after serving computation to the resulting tensor
281+
or container.
282+
"""
283+
@spec serving_post_computation(result) :: result when result: Nx.Tensor.t() | Nx.Container.t()
284+
def serving_post_computation(result) do
285+
# We transfer to binary backend so tensor access in post-processing
286+
# is not blocked by the serving the serving computation. It is also
287+
# necessary when partitions are enabled since we may need to
288+
# concatenate results for input exceeding the expected batch size.
289+
Nx.backend_transfer(result, Nx.BinaryBackend)
290+
end
291+
279292
@doc """
280293
Compiles or wraps the function with just-in-time compilation.
281294

lib/bumblebee/text/conversation.ex

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ defmodule Bumblebee.Text.Conversation do
7777
end
7878

7979
sequences[[.., start_idx..-1//1]]
80+
|> Shared.serving_post_computation()
8081
end
8182
end,
8283
defn_options

lib/bumblebee/text/fill_mask.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ defmodule Bumblebee.Text.FillMask do
7575

7676
fn inputs ->
7777
inputs = Shared.maybe_pad(inputs, batch_size)
78-
scores_fun.(params, inputs)
78+
scores_fun.(params, inputs) |> Shared.serving_post_computation()
7979
end
8080
end,
8181
defn_options

lib/bumblebee/text/generation.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ defmodule Bumblebee.Text.Generation do
864864

865865
fn inputs ->
866866
inputs = Shared.maybe_pad(inputs, batch_size)
867-
generate_fun.(params, inputs)
867+
generate_fun.(params, inputs) |> Shared.serving_post_computation()
868868
end
869869
end,
870870
defn_options

lib/bumblebee/text/question_answering.ex

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ defmodule Bumblebee.Text.QuestionAnswering do
6666

6767
fn inputs ->
6868
inputs = Shared.maybe_pad(inputs, batch_size)
69-
70-
predict_fun.(params, inputs)
69+
predict_fun.(params, inputs) |> Shared.serving_post_computation()
7170
end
7271
end,
7372
defn_options
@@ -103,10 +102,6 @@ defmodule Bumblebee.Text.QuestionAnswering do
103102
{batch, {all_inputs, raw_inputs, multi?}}
104103
end)
105104
|> Nx.Serving.client_postprocessing(fn {outputs, _metadata}, {inputs, raw_inputs, multi?} ->
106-
# We use binary backend so we are not blocked by the serving computation
107-
inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend)
108-
outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend)
109-
110105
Enum.zip_with(
111106
[raw_inputs, Utils.Nx.batch_to_list(inputs), Utils.Nx.batch_to_list(outputs)],
112107
fn [{_question_text, context_text}, inputs, outputs] ->

lib/bumblebee/text/text_classification.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ defmodule Bumblebee.Text.TextClassification do
6161

6262
fn inputs ->
6363
inputs = Shared.maybe_pad(inputs, batch_size)
64-
scores_fun.(params, inputs)
64+
scores_fun.(params, inputs) |> Shared.serving_post_computation()
6565
end
6666
end,
6767
defn_options

lib/bumblebee/text/text_embedding.ex

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ defmodule Bumblebee.Text.TextEmbedding do
107107

108108
fn inputs ->
109109
inputs = Shared.maybe_pad(inputs, batch_size)
110-
embedding_fun.(params, inputs)
110+
embedding_fun.(params, inputs) |> Shared.serving_post_computation()
111111
end
112112
end,
113113
defn_options
@@ -131,9 +131,6 @@ defmodule Bumblebee.Text.TextEmbedding do
131131
{batch, multi?}
132132
end)
133133
|> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
134-
# We use binary backend so we are not blocked by the serving computation
135-
embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend)
136-
137134
for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
138135
%{embedding: embedding}
139136
end

lib/bumblebee/text/token_classification.ex

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ defmodule Bumblebee.Text.TokenClassification do
6161

6262
fn inputs ->
6363
inputs = Shared.maybe_pad(inputs, batch_size)
64-
scores_fun.(params, inputs)
64+
scores_fun.(params, inputs) |> Shared.serving_post_computation()
6565
end
6666
end,
6767
defn_options
@@ -88,10 +88,6 @@ defmodule Bumblebee.Text.TokenClassification do
8888
{batch, {all_inputs, multi?}}
8989
end)
9090
|> Nx.Serving.client_postprocessing(fn {scores, _metadata}, {inputs, multi?} ->
91-
# We use binary backend so we are not blocked by the serving computation
92-
scores = Nx.backend_transfer(scores, Nx.BinaryBackend)
93-
inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend)
94-
9591
Enum.zip_with(
9692
Utils.Nx.batch_to_list(inputs),
9793
Utils.Nx.batch_to_list(scores),
@@ -110,9 +106,6 @@ defmodule Bumblebee.Text.TokenClassification do
110106
end
111107

112108
defp gather_raw_entities(scores, tokenizer, inputs) do
113-
# We use binary backend so we are not blocked by the serving computation
114-
scores = Nx.backend_transfer(scores, Nx.BinaryBackend)
115-
116109
{sequence_length, _} = Nx.shape(scores)
117110
flat_special_tokens_mask = Nx.to_flat_list(inputs["special_tokens_mask"])
118111
flat_input_ids = Nx.to_flat_list(inputs["input_ids"])

lib/bumblebee/text/zero_shot_classification.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ defmodule Bumblebee.Text.ZeroShotClassification do
8181
scores = Axon.Activations.softmax(logits[[.., .., entailment_id]])
8282
k = min(top_k, Nx.axis_size(scores, 1))
8383
{top_scores, top_indices} = Nx.top_k(scores, k: k)
84-
{top_scores, top_indices}
84+
{top_scores, top_indices} |> Shared.serving_post_computation()
8585
end
8686
end,
8787
defn_options

lib/bumblebee/vision/image_classification.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ defmodule Bumblebee.Vision.ImageClassification do
5757

5858
fn inputs ->
5959
inputs = Shared.maybe_pad(inputs, batch_size)
60-
scores_fun.(params, inputs)
60+
scores_fun.(params, inputs) |> Shared.serving_post_computation()
6161
end
6262
end,
6363
defn_options

lib/bumblebee/vision/image_embedding.ex

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ defmodule Bumblebee.Vision.ImageEmbedding do
8080

8181
fn inputs ->
8282
inputs = Shared.maybe_pad(inputs, batch_size)
83-
embedding_fun.(params, inputs)
83+
embedding_fun.(params, inputs) |> Shared.serving_post_computation()
8484
end
8585
end,
8686
defn_options
@@ -94,9 +94,6 @@ defmodule Bumblebee.Vision.ImageEmbedding do
9494
{Nx.Batch.concatenate([inputs]), multi?}
9595
end)
9696
|> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
97-
# We use binary backend so we are not blocked by the serving computation
98-
embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend)
99-
10097
for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
10198
%{embedding: embedding}
10299
end

lib/bumblebee/vision/image_to_text.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ defmodule Bumblebee.Vision.ImageToText do
4949

5050
fn inputs ->
5151
inputs = Shared.maybe_pad(inputs, batch_size)
52-
generate_fun.(params, inputs)
52+
generate_fun.(params, inputs) |> Shared.serving_post_computation()
5353
end
5454
end,
5555
defn_options

test/bumblebee/text/text_embedding_test.exs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ defmodule Bumblebee.Text.TextEmbeddingTest do
9696

9797
text = "query: Cats are cute."
9898

99-
assert %{embedding: %Nx.Tensor{} = embedding1} = Nx.Serving.batched_run(test, text)
100-
assert %{embedding: %Nx.Tensor{} = embedding2} = Nx.Serving.batched_run(test, text)
99+
assert [
100+
%{embedding: %Nx.Tensor{} = embedding1},
101+
%{embedding: %Nx.Tensor{} = embedding2}
102+
] = Nx.Serving.batched_run(test, [text, text])
101103

102104
assert_equal(embedding1, embedding2)
103105
end

0 commit comments

Comments
 (0)