Skip to content

Commit 2af8cf3

Browse files
Move featurizer batch part to serving computation (#243)
1 parent 391fcd0 commit 2af8cf3

15 files changed

+302
-124
lines changed

lib/bumblebee.ex

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,14 @@ defmodule Bumblebee do
560560
@spec apply_featurizer(Bumblebee.Featurizer.t(), any(), keyword()) :: any()
561561
def apply_featurizer(%module{} = featurizer, input, opts \\ []) do
562562
opts = Keyword.validate!(opts, defn_options: [])
563-
module.apply(featurizer, input, opts[:defn_options])
563+
564+
batch = module.process_input(featurizer, input)
565+
566+
if Code.ensure_loaded?(module) and function_exported?(module, :process_batch, 2) do
567+
Nx.Defn.jit_apply(&module.process_batch(featurizer, &1), [batch], opts[:defn_options])
568+
else
569+
batch
570+
end
564571
end
565572

566573
@doc """

lib/bumblebee/audio/speech_to_text_whisper.ex

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,18 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
5050
{generate_opts, generation_config} = generate_opts(generation_config, opts)
5151
generate_fun = Text.Generation.build_generate(model, spec, generation_config, generate_opts)
5252

53+
generate_fun = fn params, inputs ->
54+
inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs)
55+
generate_fun.(params, inputs)
56+
end
57+
5358
Nx.Serving.new(
5459
fn defn_options ->
5560
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)
5661

5762
generate_fun =
5863
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
59-
inputs = %{
60-
"input_features" => Shared.input_template(spec, "input_features", [batch_size])
61-
}
62-
64+
inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
6365
[params, inputs]
6466
end)
6567

@@ -102,7 +104,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
102104
all_chunks = List.flatten(all_chunks)
103105
{all_chunks, lengths} = Enum.unzip(all_chunks)
104106

105-
inputs = Bumblebee.apply_featurizer(featurizer, all_chunks, defn_options: defn_options)
107+
inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks)
106108
{Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks, lengths}}
107109
end)
108110
|> maybe_stream(opts[:stream], spec, featurizer, tokenizer, timestamps?)

lib/bumblebee/audio/whisper_featurizer.ex

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ defmodule Bumblebee.Audio.WhisperFeaturizer do
5353
end
5454

5555
@impl true
56-
def apply(featurizer, raw_samples, defn_options) do
56+
def process_input(featurizer, raw_samples) do
5757
max_length = featurizer.num_seconds * featurizer.sampling_rate
5858

5959
samples =
@@ -67,17 +67,27 @@ defmodule Bumblebee.Audio.WhisperFeaturizer do
6767
Nx.pad(sample, featurizer.padding_value, [{0, pad_size, 0}])
6868
end
6969

70-
samples = samples |> Nx.stack() |> Nx.vectorize(:batch)
70+
Nx.stack(samples)
71+
end
7172

73+
@impl true
def batch_template(featurizer, batch_size) do
  # One row per batch entry; the row length is the number of samples
  # in `num_seconds` of audio at `sampling_rate`.
  num_samples = featurizer.num_seconds * featurizer.sampling_rate
  shape = {batch_size, num_samples}

  Nx.template(shape, :f32)
end
78+
79+
@impl true
80+
def process_batch(featurizer, samples) do
7281
samples =
73-
Nx.Defn.jit(&extract_fbank_features/2, defn_options).(samples,
82+
samples
83+
|> Nx.vectorize(:batch)
84+
|> extract_fbank_features(
7485
fft_length: featurizer.fft_length,
7586
sampling_rate: featurizer.sampling_rate,
7687
mel_bins: featurizer.feature_size,
7788
hop_length: featurizer.hop_length
7889
)
79-
80-
samples = Nx.devectorize(samples)
90+
|> Nx.devectorize()
8191

8292
%{"input_features" => samples}
8393
end

lib/bumblebee/featurizer.ex

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,69 @@ defmodule Bumblebee.Featurizer do
1111
@type t :: Bumblebee.Configurable.t()
1212

1313
@doc """
14-
Performs feature extraction on the given input.
14+
Converts the given input to a batched tensor (or a tensor container).
15+
16+
Numerical batch processing should be moved to `c:process_batch/2`
17+
whenever possible.
18+
"""
19+
@callback process_input(t(), input :: any()) :: Nx.t() | Nx.Container.t()
20+
21+
@doc """
22+
Returns an input template for `c:process_batch/2`.
23+
24+
The shape is effectively the same as the result of `c:process_input/2`,
25+
except for the batch size.
26+
"""
27+
@callback batch_template(t(), batch_size :: pos_integer()) :: Nx.t() | Nx.Container.t()
28+
29+
@doc """
30+
Optional batch processing stage.
31+
32+
This is a numerical function. It receives the result of `c:process_input/2`,
33+
except the batch size may differ.
34+
35+
When using a featurizer as part of `Nx.Serving`, the batch stage can
36+
be merged with the model computation and compiled together.
37+
"""
38+
@callback process_batch(t(), input :: Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t()
39+
40+
@optional_callbacks batch_template: 2, process_batch: 2
41+
42+
@doc """
43+
Converts the given input to a batched tensor (or a tensor container).
44+
"""
45+
@spec process_input(t(), any()) :: Nx.t() | Nx.Container.t()
46+
# Dispatches to the `process_input/2` implementation of the module that
# defines the featurizer struct.
def process_input(%impl{} = featurizer, input), do: impl.process_input(featurizer, input)
49+
50+
@doc """
51+
Returns an input template for `process_batch/2`.
52+
53+
If the featurizer does not define batch processing, `nil` is returned.
54+
"""
55+
@spec batch_template(t(), pos_integer()) :: Nx.t() | Nx.Container.t() | nil
56+
def batch_template(%impl{} = featurizer, batch_size) do
  # `batch_template/2` is an optional callback, so probe for it before
  # dispatching. The module is loaded explicitly first, because
  # `function_exported?/3` does not load modules on its own.
  supported? = Code.ensure_loaded?(impl) and function_exported?(impl, :batch_template, 2)

  if supported?, do: impl.batch_template(featurizer, batch_size), else: nil
end
61+
62+
@doc """
63+
Optional batch processing stage.
64+
65+
This is a numerical function. It receives the result of `c:process_input/2`,
66+
except the batch size may differ.
67+
68+
If the featurizer does not define batch processing, the input is
69+
returned as is.
1570
"""
16-
@callback apply(t(), input :: any(), defn_options :: keyword()) :: any()
71+
@spec process_batch(t(), Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t()
72+
def process_batch(%impl{} = featurizer, batch) do
  # `process_batch/2` is an optional callback; featurizers that do not
  # implement it get the batch back untouched.
  case Code.ensure_loaded?(impl) and function_exported?(impl, :process_batch, 2) do
    true -> impl.process_batch(featurizer, batch)
    false -> batch
  end
end
1779
end

lib/bumblebee/text/generation.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ defmodule Bumblebee.Text.Generation do
6767
"""
6868
@spec extra_config_module(Bumblebee.ModelSpec.t()) :: module() | nil
6969
def extra_config_module(%module{} = spec) do
70-
if function_exported?(module, :extra_config_module, 1) do
70+
if Code.ensure_loaded?(module) and function_exported?(module, :extra_config_module, 1) do
7171
module.extra_config_module(spec)
7272
end
7373
end

lib/bumblebee/vision/blip_featurizer.ex

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,34 @@ defmodule Bumblebee.Vision.BlipFeaturizer do
5454
end
5555

5656
@impl true
57-
def apply(featurizer, images, _defn_options) do
57+
def process_input(featurizer, images) do
5858
images = List.wrap(images)
5959

60-
images =
61-
for image <- images do
62-
images =
63-
image
64-
|> Image.to_batched_tensor()
65-
|> Nx.as_type(:f32)
66-
|> Image.normalize_channels(length(featurizer.image_mean))
67-
68-
if featurizer.resize do
69-
size = Image.normalize_size(featurizer.size)
70-
NxImage.resize(images, size, method: featurizer.resize_method)
71-
else
72-
images
73-
end
60+
for image <- images do
61+
images =
62+
image
63+
|> Image.to_batched_tensor()
64+
|> Nx.as_type(:f32)
65+
|> Image.normalize_channels(length(featurizer.image_mean))
66+
67+
if featurizer.resize do
68+
size = Image.normalize_size(featurizer.size)
69+
NxImage.resize(images, size, method: featurizer.resize_method)
70+
else
71+
images
7472
end
75-
|> Nx.concatenate()
73+
end
74+
|> Nx.concatenate()
75+
end
7676

77+
@impl true
def batch_template(featurizer, batch_size) do
  num_channels = length(featurizer.image_mean)

  # `process_input/2` normalizes `featurizer.size` before resizing, so the
  # template must go through the same normalization. Using the raw value for
  # both spatial dimensions only works when `size` is a plain integer.
  {height, width} = Image.normalize_size(featurizer.size)

  # Channels-last f32 template: {batch, height, width, channels}.
  Nx.template({batch_size, height, width, num_channels}, :f32)
end
82+
83+
@impl true
84+
def process_batch(featurizer, images) do
7785
images = NxImage.to_continuous(images, 0, 1)
7886

7987
images =

lib/bumblebee/vision/clip_featurizer.ex

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,32 +65,40 @@ defmodule Bumblebee.Vision.ClipFeaturizer do
6565
end
6666

6767
@impl true
68-
def apply(featurizer, images, _defn_options) do
68+
def process_input(featurizer, images) do
6969
images = List.wrap(images)
7070

71-
images =
72-
for image <- images do
73-
images =
74-
image
75-
|> Image.to_batched_tensor()
76-
|> Nx.as_type(:f32)
77-
|> Image.normalize_channels(length(featurizer.image_mean))
78-
79-
images =
80-
if featurizer.resize do
81-
NxImage.resize_short(images, featurizer.size, method: featurizer.resize_method)
82-
else
83-
images
84-
end
85-
86-
if featurizer.center_crop do
87-
NxImage.center_crop(images, {featurizer.crop_size, featurizer.crop_size})
71+
for image <- images do
72+
images =
73+
image
74+
|> Image.to_batched_tensor()
75+
|> Nx.as_type(:f32)
76+
|> Image.normalize_channels(length(featurizer.image_mean))
77+
78+
images =
79+
if featurizer.resize do
80+
NxImage.resize_short(images, featurizer.size, method: featurizer.resize_method)
8881
else
8982
images
9083
end
84+
85+
if featurizer.center_crop do
86+
NxImage.center_crop(images, {featurizer.crop_size, featurizer.crop_size})
87+
else
88+
images
9189
end
92-
|> Nx.concatenate()
90+
end
91+
|> Nx.concatenate()
92+
end
9393

94+
@impl true
def batch_template(featurizer, batch_size) do
  num_channels = length(featurizer.image_mean)

  # When center cropping is enabled, `process_input/2` emits images of
  # `crop_size x crop_size`, so the template has to follow `crop_size`
  # rather than `size`. Without cropping, keep the previous `size`-based shape.
  side = if featurizer.center_crop, do: featurizer.crop_size, else: featurizer.size

  # Channels-last f32 template: {batch, height, width, channels}.
  Nx.template({batch_size, side, side, num_channels}, :f32)
end
99+
100+
@impl true
101+
def process_batch(featurizer, images) do
94102
images = NxImage.to_continuous(images, 0, 1)
95103

96104
images =

lib/bumblebee/vision/convnext_featurizer.ex

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -60,36 +60,44 @@ defmodule Bumblebee.Vision.ConvNextFeaturizer do
6060
end
6161

6262
@impl true
63-
def apply(featurizer, images, _defn_options) do
63+
def process_input(featurizer, images) do
6464
images = List.wrap(images)
6565

66-
images =
67-
for image <- images do
68-
images =
69-
image
70-
|> Image.to_batched_tensor()
71-
|> Nx.as_type(:f32)
72-
|> Image.normalize_channels(length(featurizer.image_mean))
73-
74-
cond do
75-
not featurizer.resize ->
76-
images
77-
78-
featurizer.size >= 384 ->
79-
NxImage.resize(images, {featurizer.size, featurizer.size},
80-
method: featurizer.resize_method
81-
)
82-
83-
true ->
84-
scale_size = floor(featurizer.size / featurizer.crop_percentage)
85-
86-
images
87-
|> NxImage.resize_short(scale_size, method: featurizer.resize_method)
88-
|> NxImage.center_crop({featurizer.size, featurizer.size})
89-
end
66+
for image <- images do
67+
images =
68+
image
69+
|> Image.to_batched_tensor()
70+
|> Nx.as_type(:f32)
71+
|> Image.normalize_channels(length(featurizer.image_mean))
72+
73+
cond do
74+
not featurizer.resize ->
75+
images
76+
77+
featurizer.size >= 384 ->
78+
NxImage.resize(images, {featurizer.size, featurizer.size},
79+
method: featurizer.resize_method
80+
)
81+
82+
true ->
83+
scale_size = floor(featurizer.size / featurizer.crop_percentage)
84+
85+
images
86+
|> NxImage.resize_short(scale_size, method: featurizer.resize_method)
87+
|> NxImage.center_crop({featurizer.size, featurizer.size})
9088
end
91-
|> Nx.concatenate()
89+
end
90+
|> Nx.concatenate()
91+
end
9292

93+
@impl true
def batch_template(featurizer, batch_size) do
  # Square, channels-last f32 template: {batch, size, size, channels}.
  channels = length(featurizer.image_mean)
  side = featurizer.size
  shape = {batch_size, side, side, channels}

  Nx.template(shape, :f32)
end
98+
99+
@impl true
100+
def process_batch(featurizer, images) do
93101
images = NxImage.to_continuous(images, 0, 1)
94102

95103
images =

0 commit comments

Comments
 (0)