Group all tokenizers under a single module and configure upfront #310

Merged
merged 2 commits on Dec 15, 2023
lib/bumblebee.ex (114 changes: 45 additions & 69 deletions)
@@ -214,26 +214,26 @@ defmodule Bumblebee do
"whisper" => Bumblebee.Audio.WhisperFeaturizer
}

@model_type_to_tokenizer %{
"albert" => Bumblebee.Text.AlbertTokenizer,
"bart" => Bumblebee.Text.BartTokenizer,
"bert" => Bumblebee.Text.BertTokenizer,
"blenderbot" => Bumblebee.Text.BlenderbotTokenizer,
"blip" => Bumblebee.Text.BertTokenizer,
"distilbert" => Bumblebee.Text.DistilbertTokenizer,
"camembert" => Bumblebee.Text.CamembertTokenizer,
"clip" => Bumblebee.Text.ClipTokenizer,
"gpt_neox" => Bumblebee.Text.GptNeoXTokenizer,
"gpt2" => Bumblebee.Text.Gpt2Tokenizer,
"gpt_bigcode" => Bumblebee.Text.Gpt2Tokenizer,
"layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
"llama" => Bumblebee.Text.LlamaTokenizer,
"mistral" => Bumblebee.Text.LlamaTokenizer,
"mbart" => Bumblebee.Text.MbartTokenizer,
"roberta" => Bumblebee.Text.RobertaTokenizer,
"t5" => Bumblebee.Text.T5Tokenizer,
"whisper" => Bumblebee.Text.WhisperTokenizer,
"xlm-roberta" => Bumblebee.Text.XlmRobertaTokenizer
@model_type_to_tokenizer_type %{
"albert" => :albert,
"bart" => :bart,
"bert" => :bert,
"blenderbot" => :blenderbot,
"blip" => :bert,
"distilbert" => :distilbert,
"camembert" => :camembert,
"clip" => :clip,
"gpt_neox" => :gpt_neo_x,
"gpt2" => :gpt2,
"gpt_bigcode" => :gpt2,
"layoutlm" => :layout_lm,
"llama" => :llama,
"mistral" => :llama,
"mbart" => :mbart,
"roberta" => :roberta,
"t5" => :t5,
"whisper" => :whisper,
"xlm-roberta" => :xlm_roberta
}

@diffusers_class_to_scheduler %{
@@ -766,31 +766,6 @@ defmodule Bumblebee do
@doc """
Tokenizes and encodes `input` with the given tokenizer.

## Options

* `:add_special_tokens` - whether to add special tokens. Defaults
to `true`

* `:pad_direction` - the padding direction, either `:right` or
`:left`. Defaults to `:right`

* `:return_attention_mask` - whether to return attention mask for
encoded sequence. Defaults to `true`

* `:return_token_type_ids` - whether to return token type ids for
encoded sequence. Defaults to `true`

* `:return_special_tokens_mask` - whether to return special tokens
mask for encoded sequence. Defaults to `false`

* `:return_offsets` - whether to return token offsets for encoded
sequence. Defaults to `false`

* `:length` - applies fixed length padding or truncation to the
given input if set. Can be either a specific number or a list
of numbers. When a list is given, the smallest number that
exceeds all input lengths is used as the padding length

## Examples

tokenizer = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})
@@ -804,27 +779,28 @@
keyword()
) :: any()
def apply_tokenizer(%module{} = tokenizer, input, opts \\ []) do
opts =
Keyword.validate!(opts,
add_special_tokens: true,
pad_direction: :right,
truncate_direction: :right,
length: nil,
return_attention_mask: true,
return_token_type_ids: true,
return_special_tokens_mask: false,
return_offsets: false
)
tokenizer =
if opts == [] do
tokenizer
else
# TODO: remove options on v0.6
IO.warn(
"passing options to Bumblebee.apply_tokenizer/3 is deprecated," <>
" please use Bumblebee.configure/2 to set tokenizer options"
)

Bumblebee.configure(tokenizer, opts)
end

module.apply(tokenizer, input, opts)
module.apply(tokenizer, input)
end
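
With this change, tokenizer options move off the apply call and onto the struct itself. A minimal sketch of the new call sequence (the option values here are illustrative):

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

    # Options are now set once via configure/2 instead of per apply_tokenizer call
    tokenizer = Bumblebee.configure(tokenizer, length: 128, return_special_tokens_mask: true)

    inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello, world!"])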

@doc """
Loads tokenizer from a model repository.

## Options

* `:module` - the tokenizer module. By default it is inferred from
* `:type` - the tokenizer type. By default it is inferred from
the configuration files, if that is not possible, it must be
specified explicitly

@@ -838,17 +814,17 @@
{:ok, Bumblebee.Tokenizer.t()} | {:error, String.t()}
def load_tokenizer(repository, opts \\ []) do
repository = normalize_repository!(repository)
opts = Keyword.validate!(opts, [:module])
module = opts[:module]
opts = Keyword.validate!(opts, [:type])
type = opts[:type]

case get_repo_files(repository) do
{:ok, %{@tokenizer_filename => etag} = repo_files} ->
with {:ok, path} <- download(repository, @tokenizer_filename, etag) do
module =
module ||
type =
type ||
case infer_tokenizer_type(repository, repo_files) do
{:ok, module} ->
module
{:ok, type} ->
type

{:error, error} ->
raise ArgumentError, "#{error}, please specify the :type option"
@@ -878,7 +854,7 @@ defmodule Bumblebee do

with {:ok, tokenizer_config} <- tokenizer_config_result,
{:ok, special_tokens_map} <- special_tokens_map_result do
tokenizer = struct!(module)
tokenizer = struct!(Bumblebee.Text.PreTrainedTokenizer, type: type)

tokenizer =
HuggingFace.Transformers.Config.load(tokenizer, %{
@@ -912,13 +888,13 @@ defmodule Bumblebee do
{:ok, tokenizer_data} <- decode_config(path) do
case tokenizer_data do
%{"model_type" => model_type} ->
case @model_type_to_tokenizer[model_type] do
case @model_type_to_tokenizer_type[model_type] do
nil ->
{:error,
"could not match model type #{inspect(model_type)} to any of the supported tokenizers"}
"could not match model type #{inspect(model_type)} to any of the supported tokenizer types"}

module ->
{:ok, module}
type ->
{:ok, type}
end

_ ->
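When the configuration files do not declare a recognized model type, the inference above fails and the type has to be passed by hand. A sketch of the explicit form (the repository name is hypothetical; :type takes one of the atoms from @model_type_to_tokenizer_type):

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "acme/custom-llama"}, type: :llama)
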
lib/bumblebee/diffusion/stable_diffusion.ex (62 changes: 34 additions & 28 deletions)
@@ -162,6 +162,13 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]

tokenizer =
Bumblebee.configure(tokenizer,
length: sequence_length,
return_token_type_ids: false,
return_attention_mask: false
)

{_, encoder_predict} = Axon.build(encoder.model)
{_, vae_predict} = Axon.build(vae.model)
{_, unet_predict} = Axon.build(unet.model)
@@ -213,7 +220,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
defn_options
)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.client_preprocessing(&client_preprocessing(&1, tokenizer, sequence_length))
|> Nx.Serving.client_preprocessing(&client_preprocessing(&1, tokenizer))
|> Nx.Serving.client_postprocessing(&client_postprocessing(&1, &2, safety_checker))
end

@@ -235,13 +242,10 @@ defmodule Bumblebee.Diffusion.StableDiffusion do

image_fun =
Shared.compile_or_jit(image_fun, defn_options, compile?, fn ->
text_inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32)
}

inputs = %{
"unconditional" => text_inputs,
"conditional" => text_inputs,
"conditional_and_unconditional" => %{
"input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32)
},
"seed" => Nx.template({batch_size}, :s64)
}

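The prompt and its negative prompt now travel through the serving as a single tensor, so the compiled template gains a pair axis. A small illustration of the resulting shape (the sizes are made up):

    # Assuming batch_size = 2 and sequence_length = 16:
    Nx.template({2, 2, 16}, :u32)   # {batch, conditional/unconditional pair, sequence}
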
@@ -285,32 +289,22 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
end
end

defp client_preprocessing(input, tokenizer, sequence_length) do
defp client_preprocessing(input, tokenizer) do
{inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1)

prompts = Enum.map(inputs, & &1.prompt)
negative_prompts = Enum.map(inputs, & &1.negative_prompt)
seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend)

conditional =
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, prompts,
length: sequence_length,
return_token_type_ids: false,
return_attention_mask: false
)
end)
# Note: we need to tokenize all sequences together, so that
# they are padded to the same length (if not specified)
prompts = Enum.flat_map(inputs, &[&1.prompt, &1.negative_prompt])

unconditional =
prompt_pairs =
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, negative_prompts,
length: Nx.axis_size(conditional["input_ids"], 1),
return_attention_mask: false,
return_token_type_ids: false
)
inputs = Bumblebee.apply_tokenizer(tokenizer, prompts)
Utils.Nx.composite_unflatten_batch(inputs, Nx.axis_size(seed, 0))
end)

inputs = %{"unconditional" => unconditional, "conditional" => conditional, "seed" => seed}
inputs = %{"conditional_and_unconditional" => prompt_pairs, "seed" => seed}

{Nx.Batch.concatenate([inputs]), multi?}
end
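
Tokenizing prompts and negative prompts as one batch guarantees they are padded to a shared length; the flat batch is then regrouped so that each input owns a [prompt, negative_prompt] pair. A rough Nx-only sketch of the regrouping (Utils.Nx.composite_unflatten_batch is internal to Bumblebee; the assumption here is that it reshapes the leading axis of every tensor in the container this way):

    ids = Nx.iota({4, 3})              # stand-in for tokenized input ids, shape {2 * batch, length}
    pairs = Nx.reshape(ids, {2, 2, 3}) # {batch, 2, length}: axis 1 holds [prompt, negative_prompt]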
@@ -360,7 +354,11 @@ defmodule Bumblebee.Diffusion.StableDiffusion do

seed = inputs["seed"]

inputs = Utils.Nx.composite_concatenate(inputs["unconditional"], inputs["conditional"])
inputs =
inputs["conditional_and_unconditional"]
# Transpose conditional and unconditional to separate blocks
|> composite_transpose_leading()
|> Utils.Nx.composite_flatten_batch()

%{hidden_state: text_embeddings} = encoder_predict.(encoder_params, inputs)

@@ -399,7 +397,8 @@ defmodule Bumblebee.Diffusion.StableDiffusion do

%{sample: noise_pred} = unet_predict.(unet_params, unet_inputs)

{noise_pred_unconditional, noise_pred_text} = split_in_half(noise_pred)
{noise_pred_text, noise_pred_unconditional} =
split_conditional_and_unconditional(noise_pred)

noise_pred =
noise_pred_unconditional + guidance_scale * (noise_pred_text - noise_pred_unconditional)
@@ -416,7 +415,14 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
NxImage.from_continuous(image, -1, 1)
end

defnp split_in_half(tensor) do
deftransformp composite_transpose_leading(container) do
Utils.Nx.map(container, fn tensor ->
[first, second | rest] = Nx.axes(tensor)
Nx.transpose(tensor, axes: [second, first | rest])
end)
end

defnp split_conditional_and_unconditional(tensor) do
batch_size = Nx.axis_size(tensor, 0)
half_size = div(batch_size, 2)
{tensor[0..(half_size - 1)//1], tensor[half_size..-1//1]}
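Transposing the pair axis to the front and flattening puts every conditional sequence in the first half of the batch and every unconditional one in the second, which is why split_conditional_and_unconditional only needs to cut the batch in two. A standalone Nx sketch of that ordering:

    t = Nx.iota({2, 2, 3})                # {batch, 2, len}, axis 1 is [conditional, unconditional]
    t = Nx.transpose(t, axes: [1, 0, 2])  # {2, batch, len}: conditional block first
    flat = Nx.reshape(t, {4, 3})          # rows 0..1 conditional, rows 2..3 unconditional
    {conditional, unconditional} = {flat[0..1//1], flat[2..3//1]}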