Move featurizer batch part to serving computation #243

Merged 2 commits on Sep 13, 2023
9 changes: 8 additions & 1 deletion lib/bumblebee.ex
@@ -560,7 +560,14 @@ defmodule Bumblebee do
@spec apply_featurizer(Bumblebee.Featurizer.t(), any(), keyword()) :: any()
def apply_featurizer(%module{} = featurizer, input, opts \\ []) do
opts = Keyword.validate!(opts, defn_options: [])
module.apply(featurizer, input, opts[:defn_options])

batch = module.process_input(featurizer, input)

if Code.ensure_loaded?(module) and function_exported?(module, :process_batch, 2) do
Nx.Defn.jit_apply(&module.process_batch(featurizer, &1), [batch], opts[:defn_options])
else
batch
end
end

@doc """
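For illustration, a minimal usage sketch of the reworked apply_featurizer/3. The repository id, dummy input, and compiler option are hypothetical and not part of this PR:

# Hypothetical usage: load a featurizer and apply it with explicit defn options.
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})

# One second of dummy audio at 16 kHz, standing in for a real recording.
chunk = Nx.broadcast(0.0, {16_000})

# process_input/2 runs eagerly to build the batched tensor; if the featurizer
# module exports process_batch/2, that stage is JIT-applied with the given
# options (the EXLA compiler here is just an example), otherwise the batch is
# returned as is.
inputs = Bumblebee.apply_featurizer(featurizer, [chunk], defn_options: [compiler: EXLA])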
12 changes: 7 additions & 5 deletions lib/bumblebee/audio/speech_to_text_whisper.ex
@@ -50,16 +50,18 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
{generate_opts, generation_config} = generate_opts(generation_config, opts)
generate_fun = Text.Generation.build_generate(model, spec, generation_config, generate_opts)

generate_fun = fn params, inputs ->
inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs)
generate_fun.(params, inputs)
end

Nx.Serving.new(
fn defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
inputs = %{
"input_features" => Shared.input_template(spec, "input_features", [batch_size])
}

inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
[params, inputs]
end)

@@ -102,7 +104,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
all_chunks = List.flatten(all_chunks)
{all_chunks, lengths} = Enum.unzip(all_chunks)

inputs = Bumblebee.apply_featurizer(featurizer, all_chunks, defn_options: defn_options)
inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks)
{Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks, lengths}}
end)
|> maybe_stream(opts[:stream], spec, featurizer, tokenizer, timestamps?)
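Note on the template change above: the compile-time template now describes the raw padded-audio shape rather than the "input_features" shape, because the log-mel extraction (process_batch/2) is folded into the compiled generate function. A rough sketch of what that looks like for Whisper, assuming the default 30 seconds at 16 kHz (the actual config may differ):

# Hypothetical shapes: 30 s * 16_000 Hz = 480_000 samples per example.
template = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
# template is an f32 template of shape {batch_size, 480_000}; the wrapped
# generate_fun runs process_batch/2 on the real batch before the model itself.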
20 changes: 15 additions & 5 deletions lib/bumblebee/audio/whisper_featurizer.ex
@@ -53,7 +53,7 @@ defmodule Bumblebee.Audio.WhisperFeaturizer
end

@impl true
def apply(featurizer, raw_samples, defn_options) do
def process_input(featurizer, raw_samples) do
max_length = featurizer.num_seconds * featurizer.sampling_rate

samples =
@@ -67,17 +67,27 @@
Nx.pad(sample, featurizer.padding_value, [{0, pad_size, 0}])
end

samples = samples |> Nx.stack() |> Nx.vectorize(:batch)
Nx.stack(samples)
end

@impl true
def batch_template(featurizer, batch_size) do
max_length = featurizer.num_seconds * featurizer.sampling_rate
Nx.template({batch_size, max_length}, :f32)
end

@impl true
def process_batch(featurizer, samples) do
samples =
Nx.Defn.jit(&extract_fbank_features/2, defn_options).(samples,
samples
|> Nx.vectorize(:batch)
|> extract_fbank_features(
fft_length: featurizer.fft_length,
sampling_rate: featurizer.sampling_rate,
mel_bins: featurizer.feature_size,
hop_length: featurizer.hop_length
)

samples = Nx.devectorize(samples)
|> Nx.devectorize()

%{"input_features" => samples}
end
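To make the split concrete, a sketch of running the two Whisper stages by hand. It assumes featurizer was loaded via Bumblebee.load_featurizer/1; the silent input is made up:

# One second of silence as dummy input (real inputs are raw audio tensors).
silence = Nx.broadcast(0.0, {featurizer.sampling_rate})

# Stage 1: eager padding and stacking into a {1, num_seconds * sampling_rate} batch.
batch = Bumblebee.Audio.WhisperFeaturizer.process_input(featurizer, [silence])

# Stage 2: the numerical fbank extraction, now compilable together with the model.
%{"input_features" => features} =
  Bumblebee.Audio.WhisperFeaturizer.process_batch(featurizer, batch)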
66 changes: 64 additions & 2 deletions lib/bumblebee/featurizer.ex
@@ -11,7 +11,69 @@ defmodule Bumblebee.Featurizer do
@type t :: Bumblebee.Configurable.t()

@doc """
Performs feature extraction on the given input.
Converts the given input to a batched tensor (or a tensor container).

Numerical batch processing should be moved to `c:process_batch/2`
whenever possible.
"""
@callback process_input(t(), input :: any()) :: Nx.t() | Nx.Container.t()

@doc """
Returns an input template for `c:process_batch/2`.

The shape is effectively the same as the result of `c:process_input/2`,
except for the batch size.
"""
@callback batch_template(t(), batch_size :: pos_integer()) :: Nx.t() | Nx.Container.t()

@doc """
Optional batch processing stage.

This is a numerical function. It receives the result of `c:process_input/2`,
except the batch size may differ.

When using a featurizer as part of `Nx.Serving`, the batch stage can
be merged with the model computation and compiled together.
"""
@callback process_batch(t(), input :: Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t()

@optional_callbacks batch_template: 2, process_batch: 2

@doc """
Converts the given input to a batched tensor (or a tensor container).
"""
@spec process_input(t(), any()) :: Nx.t() | Nx.Container.t()
def process_input(%module{} = featurizer, input) do
module.process_input(featurizer, input)
end

@doc """
Returns an input template for `process_batch/2`.

If the featurizer does not define batch processing, `nil` is returned.
"""
@spec batch_template(t(), pos_integer()) :: Nx.t() | Nx.Container.t() | nil
def batch_template(%module{} = featurizer, batch_size) do
if Code.ensure_loaded?(module) and function_exported?(module, :batch_template, 2) do
module.batch_template(featurizer, batch_size)
end
end

@doc """
Optional batch processing stage.

This is a numerical function. It receives the result of `c:process_input/2`,
except the batch size may differ.

If the featurizer does not define batch processing, the input is
returned as is.
"""
@callback apply(t(), input :: any(), defn_options :: keyword()) :: any()
Review thread on the removed apply/3 callback:

Member Author: I didn't deprecate, because it's very unlikely that someone implements a featurizer outside bumblebee.

Contributor: Beautiful. If you want to keep backwards compatibility, you could keep it as apply and introduce apply_batch.

Member Author: I would rather check for it as a fallback, I just don't think it's worth it in this case.

As for naming, I didn't go with apply_batch because then it seems as if apply_batch were a batched version of apply. I'm not sure the current naming is perfect either, but it's the best I came up with :)

Contributor: Your call, just mentioning for completeness.

@spec process_batch(t(), Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t()
def process_batch(%module{} = featurizer, batch) do
if Code.ensure_loaded?(module) and function_exported?(module, :process_batch, 2) do
module.process_batch(featurizer, batch)
else
batch
end
end
end
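
For completeness, a minimal sketch of a third-party featurizer implementing the new callbacks. The module name, struct fields, and shapes are made up, and the Bumblebee.Configurable plumbing that real featurizers also need is omitted:

defmodule MyApp.ScaleFeaturizer do
  @behaviour Bumblebee.Featurizer

  defstruct max_length: 4

  # Eager stage: turn arbitrary Elixir terms into one batched tensor.
  @impl true
  def process_input(%{max_length: max_length}, inputs) do
    inputs
    |> Enum.map(fn values ->
      tensor = Nx.tensor(values, type: :f32)
      Nx.pad(tensor, 0.0, [{0, max_length - Nx.size(tensor), 0}])
    end)
    |> Nx.stack()
  end

  # Template matching process_input/2, used when compiling the serving.
  @impl true
  def batch_template(%{max_length: max_length}, batch_size) do
    Nx.template({batch_size, max_length}, :f32)
  end

  # Numerical stage: can be merged into the serving computation and compiled.
  @impl true
  def process_batch(_featurizer, batch) do
    %{"input" => Nx.divide(batch, 255)}
  end
end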
2 changes: 1 addition & 1 deletion lib/bumblebee/text/generation.ex
@@ -67,7 +67,7 @@ defmodule Bumblebee.Text.Generation do
"""
@spec extra_config_module(Bumblebee.ModelSpec.t()) :: module() | nil
def extra_config_module(%module{} = spec) do
if function_exported?(module, :extra_config_module, 1) do
if Code.ensure_loaded?(module) and function_exported?(module, :extra_config_module, 1) do
module.extra_config_module(spec)
end
end
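A brief note on why the extra check matters (general Elixir behavior, not specific to this PR): function_exported?/3 does not load the module, so for a module that has not been loaded yet it returns false and the optional callback would be silently skipped; Code.ensure_loaded?/1 forces the load first. Illustrative snippet, with MyModule as a hypothetical not-yet-loaded module:

# Can give a false negative when MyModule has not been loaded yet.
function_exported?(MyModule, :extra_config_module, 1)
#=> false

# Loading first makes the check reliable.
Code.ensure_loaded?(MyModule) and function_exported?(MyModule, :extra_config_module, 1)
#=> true (when MyModule actually exports it)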
40 changes: 24 additions & 16 deletions lib/bumblebee/vision/blip_featurizer.ex
@@ -54,26 +54,34 @@ defmodule Bumblebee.Vision.BlipFeaturizer do
end

@impl true
def apply(featurizer, images, _defn_options) do
def process_input(featurizer, images) do
images = List.wrap(images)

images =
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

if featurizer.resize do
size = Image.normalize_size(featurizer.size)
NxImage.resize(images, size, method: featurizer.resize_method)
else
images
end
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

if featurizer.resize do
size = Image.normalize_size(featurizer.size)
NxImage.resize(images, size, method: featurizer.resize_method)
else
images
end
|> Nx.concatenate()
end
|> Nx.concatenate()
end

@impl true
def batch_template(featurizer, batch_size) do
num_channels = length(featurizer.image_mean)
Nx.template({batch_size, featurizer.size, featurizer.size, num_channels}, :f32)
end

@impl true
def process_batch(featurizer, images) do
images = NxImage.to_continuous(images, 0, 1)

images =
46 changes: 27 additions & 19 deletions lib/bumblebee/vision/clip_featurizer.ex
@@ -65,32 +65,40 @@ defmodule Bumblebee.Vision.ClipFeaturizer do
end

@impl true
def apply(featurizer, images, _defn_options) do
def process_input(featurizer, images) do
images = List.wrap(images)

images =
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

images =
if featurizer.resize do
NxImage.resize_short(images, featurizer.size, method: featurizer.resize_method)
else
images
end

if featurizer.center_crop do
NxImage.center_crop(images, {featurizer.crop_size, featurizer.crop_size})
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

images =
if featurizer.resize do
NxImage.resize_short(images, featurizer.size, method: featurizer.resize_method)
else
images
end

if featurizer.center_crop do
NxImage.center_crop(images, {featurizer.crop_size, featurizer.crop_size})
else
images
end
|> Nx.concatenate()
end
|> Nx.concatenate()
end

@impl true
def batch_template(featurizer, batch_size) do
num_channels = length(featurizer.image_mean)
Nx.template({batch_size, featurizer.size, featurizer.size, num_channels}, :f32)
end

@impl true
def process_batch(featurizer, images) do
images = NxImage.to_continuous(images, 0, 1)

images =
60 changes: 34 additions & 26 deletions lib/bumblebee/vision/convnext_featurizer.ex
@@ -60,36 +60,44 @@ defmodule Bumblebee.Vision.ConvNextFeaturizer do
end

@impl true
def apply(featurizer, images, _defn_options) do
def process_input(featurizer, images) do
images = List.wrap(images)

images =
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

cond do
not featurizer.resize ->
images

featurizer.size >= 384 ->
NxImage.resize(images, {featurizer.size, featurizer.size},
method: featurizer.resize_method
)

true ->
scale_size = floor(featurizer.size / featurizer.crop_percentage)

images
|> NxImage.resize_short(scale_size, method: featurizer.resize_method)
|> NxImage.center_crop({featurizer.size, featurizer.size})
end
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

cond do
not featurizer.resize ->
images

featurizer.size >= 384 ->
NxImage.resize(images, {featurizer.size, featurizer.size},
method: featurizer.resize_method
)

true ->
scale_size = floor(featurizer.size / featurizer.crop_percentage)

images
|> NxImage.resize_short(scale_size, method: featurizer.resize_method)
|> NxImage.center_crop({featurizer.size, featurizer.size})
end
|> Nx.concatenate()
end
|> Nx.concatenate()
end

@impl true
def batch_template(featurizer, batch_size) do
num_channels = length(featurizer.image_mean)
Nx.template({batch_size, featurizer.size, featurizer.size, num_channels}, :f32)
end

@impl true
def process_batch(featurizer, images) do
images = NxImage.to_continuous(images, 0, 1)

images =