Commit 8b52612

Group all tokenizers under a single module and configure upfront (#310)
1 parent b141042 commit 8b52612
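
In short: tokenizer options that used to be passed on every apply_tokenizer call now live on the tokenizer struct itself, set once via Bumblebee.configure/2. A minimal sketch of the new flow (the model name is illustrative):

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

    # Configure once, upfront...
    tokenizer = Bumblebee.configure(tokenizer, length: 128, return_token_type_ids: false)

    # ...then apply with no per-call options
    inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello world"])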

46 files changed, +1050 -1234 lines

lib/bumblebee.ex

Lines changed: 45 additions & 69 deletions
@@ -214,26 +214,26 @@ defmodule Bumblebee do
     "whisper" => Bumblebee.Audio.WhisperFeaturizer
   }
 
-  @model_type_to_tokenizer %{
-    "albert" => Bumblebee.Text.AlbertTokenizer,
-    "bart" => Bumblebee.Text.BartTokenizer,
-    "bert" => Bumblebee.Text.BertTokenizer,
-    "blenderbot" => Bumblebee.Text.BlenderbotTokenizer,
-    "blip" => Bumblebee.Text.BertTokenizer,
-    "distilbert" => Bumblebee.Text.DistilbertTokenizer,
-    "camembert" => Bumblebee.Text.CamembertTokenizer,
-    "clip" => Bumblebee.Text.ClipTokenizer,
-    "gpt_neox" => Bumblebee.Text.GptNeoXTokenizer,
-    "gpt2" => Bumblebee.Text.Gpt2Tokenizer,
-    "gpt_bigcode" => Bumblebee.Text.Gpt2Tokenizer,
-    "layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
-    "llama" => Bumblebee.Text.LlamaTokenizer,
-    "mistral" => Bumblebee.Text.LlamaTokenizer,
-    "mbart" => Bumblebee.Text.MbartTokenizer,
-    "roberta" => Bumblebee.Text.RobertaTokenizer,
-    "t5" => Bumblebee.Text.T5Tokenizer,
-    "whisper" => Bumblebee.Text.WhisperTokenizer,
-    "xlm-roberta" => Bumblebee.Text.XlmRobertaTokenizer
+  @model_type_to_tokenizer_type %{
+    "albert" => :albert,
+    "bart" => :bart,
+    "bert" => :bert,
+    "blenderbot" => :blenderbot,
+    "blip" => :bert,
+    "distilbert" => :distilbert,
+    "camembert" => :camembert,
+    "clip" => :clip,
+    "gpt_neox" => :gpt_neo_x,
+    "gpt2" => :gpt2,
+    "gpt_bigcode" => :gpt2,
+    "layoutlm" => :layout_lm,
+    "llama" => :llama,
+    "mistral" => :llama,
+    "mbart" => :mbart,
+    "roberta" => :roberta,
+    "t5" => :t5,
+    "whisper" => :whisper,
+    "xlm-roberta" => :xlm_roberta
   }
 
   @diffusers_class_to_scheduler %{
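
With this mapping, every text tokenizer now resolves to a type atom rather than a dedicated module; all types are served by the single Bumblebee.Text.PreTrainedTokenizer struct (see the struct! call later in this diff). So, presumably, loading a BERT repository now yields something like:

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})
    # tokenizer is a %Bumblebee.Text.PreTrainedTokenizer{type: :bert, ...}
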
@@ -766,31 +766,6 @@ defmodule Bumblebee do
   @doc """
   Tokenizes and encodes `input` with the given tokenizer.
 
-  ## Options
-
-    * `:add_special_tokens` - whether to add special tokens. Defaults
-      to `true`
-
-    * `:pad_direction` - the padding direction, either `:right` or
-      `:left`. Defaults to `:right`
-
-    * `:return_attention_mask` - whether to return attention mask for
-      encoded sequence. Defaults to `true`
-
-    * `:return_token_type_ids` - whether to return token type ids for
-      encoded sequence. Defaults to `true`
-
-    * `:return_special_tokens_mask` - whether to return special tokens
-      mask for encoded sequence. Defaults to `false`
-
-    * `:return_offsets` - whether to return token offsets for encoded
-      sequence. Defaults to `false`
-
-    * `:length` - applies fixed length padding or truncation to the
-      given input if set. Can be either a specific number or a list
-      of numbers. When a list is given, the smallest number that
-      exceeds all input lengths is used as the padding length
-
   ## Examples
 
       tokenizer = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})
@@ -804,27 +779,28 @@ defmodule Bumblebee do
           keyword()
         ) :: any()
   def apply_tokenizer(%module{} = tokenizer, input, opts \\ []) do
-    opts =
-      Keyword.validate!(opts,
-        add_special_tokens: true,
-        pad_direction: :right,
-        truncate_direction: :right,
-        length: nil,
-        return_attention_mask: true,
-        return_token_type_ids: true,
-        return_special_tokens_mask: false,
-        return_offsets: false
-      )
+    tokenizer =
+      if opts == [] do
+        tokenizer
+      else
+        # TODO: remove options on v0.6
+        IO.warn(
+          "passing options to Bumblebee.apply_tokenizer/3 is deprecated," <>
+            " please use Bumblebee.configure/2 to set tokenizer options"
+        )
+
+        Bumblebee.configure(tokenizer, opts)
+      end
 
-    module.apply(tokenizer, input, opts)
+    module.apply(tokenizer, input)
   end
 
   @doc """
   Loads tokenizer from a model repository.
 
   ## Options
 
-    * `:module` - the tokenizer module. By default it is inferred from
+    * `:type` - the tokenizer type. By default it is inferred from
       the configuration files, if that is not possible, it must be
       specified explicitly
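
During the deprecation window, passing options still works but only warns and delegates to configure/2, so the two calls below should be equivalent (sketch):

    texts = ["Hello world"]

    # Deprecated: emits a warning, configures under the hood
    Bumblebee.apply_tokenizer(tokenizer, texts, length: 128)

    # Preferred: configure upfront
    tokenizer |> Bumblebee.configure(length: 128) |> Bumblebee.apply_tokenizer(texts)
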
@@ -838,17 +814,17 @@ defmodule Bumblebee do
           {:ok, Bumblebee.Tokenizer.t()} | {:error, String.t()}
   def load_tokenizer(repository, opts \\ []) do
     repository = normalize_repository!(repository)
-    opts = Keyword.validate!(opts, [:module])
-    module = opts[:module]
+    opts = Keyword.validate!(opts, [:type])
+    type = opts[:type]
 
     case get_repo_files(repository) do
       {:ok, %{@tokenizer_filename => etag} = repo_files} ->
         with {:ok, path} <- download(repository, @tokenizer_filename, etag) do
-          module =
-            module ||
+          type =
+            type ||
               case infer_tokenizer_type(repository, repo_files) do
-                {:ok, module} ->
-                  module
+                {:ok, type} ->
+                  type
 
                 {:error, error} ->
                   raise ArgumentError, "#{error}, please specify the :module option"
@@ -878,7 +854,7 @@ defmodule Bumblebee do
 
     with {:ok, tokenizer_config} <- tokenizer_config_result,
          {:ok, special_tokens_map} <- special_tokens_map_result do
-      tokenizer = struct!(module)
+      tokenizer = struct!(Bumblebee.Text.PreTrainedTokenizer, type: type)
 
       tokenizer =
         HuggingFace.Transformers.Config.load(tokenizer, %{
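
When the type cannot be inferred from the repository files, it can now be forced with the renamed option (repository name illustrative):

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "some/custom-llama"}, type: :llama)
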
@@ -912,13 +888,13 @@ defmodule Bumblebee do
          {:ok, tokenizer_data} <- decode_config(path) do
       case tokenizer_data do
         %{"model_type" => model_type} ->
-          case @model_type_to_tokenizer[model_type] do
+          case @model_type_to_tokenizer_type[model_type] do
             nil ->
               {:error,
-               "could not match model type #{inspect(model_type)} to any of the supported tokenizers"}
+               "could not match model type #{inspect(model_type)} to any of the supported tokenizer types"}
 
-            module ->
-              {:ok, module}
+            type ->
+              {:ok, type}
           end
 
         _ ->

lib/bumblebee/diffusion/stable_diffusion.ex

Lines changed: 34 additions & 28 deletions
@@ -162,6 +162,13 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
     batch_size = compile[:batch_size]
     sequence_length = compile[:sequence_length]
 
+    tokenizer =
+      Bumblebee.configure(tokenizer,
+        length: sequence_length,
+        return_token_type_ids: false,
+        return_attention_mask: false
+      )
+
     {_, encoder_predict} = Axon.build(encoder.model)
     {_, vae_predict} = Axon.build(vae.model)
     {_, unet_predict} = Axon.build(unet.model)
@@ -213,7 +220,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
       defn_options
     )
     |> Nx.Serving.batch_size(batch_size)
-    |> Nx.Serving.client_preprocessing(&client_preprocessing(&1, tokenizer, sequence_length))
+    |> Nx.Serving.client_preprocessing(&client_preprocessing(&1, tokenizer))
     |> Nx.Serving.client_postprocessing(&client_postprocessing(&1, &2, safety_checker))
   end
 
@@ -235,13 +242,10 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
 
     image_fun =
       Shared.compile_or_jit(image_fun, defn_options, compile?, fn ->
-        text_inputs = %{
-          "input_ids" => Nx.template({batch_size, sequence_length}, :u32)
-        }
-
         inputs = %{
-          "unconditional" => text_inputs,
-          "conditional" => text_inputs,
+          "conditional_and_unconditional" => %{
+            "input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32)
+          },
           "seed" => Nx.template({batch_size}, :s64)
         }
 
@@ -285,32 +289,22 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
     end
   end
 
-  defp client_preprocessing(input, tokenizer, sequence_length) do
+  defp client_preprocessing(input, tokenizer) do
     {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1)
 
-    prompts = Enum.map(inputs, & &1.prompt)
-    negative_prompts = Enum.map(inputs, & &1.negative_prompt)
     seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend)
 
-    conditional =
-      Nx.with_default_backend(Nx.BinaryBackend, fn ->
-        Bumblebee.apply_tokenizer(tokenizer, prompts,
-          length: sequence_length,
-          return_token_type_ids: false,
-          return_attention_mask: false
-        )
-      end)
+    # Note: we need to tokenize all sequences together, so that
+    # they are padded to the same length (if not specified)
+    prompts = Enum.flat_map(inputs, &[&1.prompt, &1.negative_prompt])
 
-    unconditional =
+    prompt_pairs =
       Nx.with_default_backend(Nx.BinaryBackend, fn ->
-        Bumblebee.apply_tokenizer(tokenizer, negative_prompts,
-          length: Nx.axis_size(conditional["input_ids"], 1),
-          return_attention_mask: false,
-          return_token_type_ids: false
-        )
+        inputs = Bumblebee.apply_tokenizer(tokenizer, prompts)
+        Utils.Nx.composite_unflatten_batch(inputs, Nx.axis_size(seed, 0))
       end)
 
-    inputs = %{"unconditional" => unconditional, "conditional" => conditional, "seed" => seed}
+    inputs = %{"conditional_and_unconditional" => prompt_pairs, "seed" => seed}
 
     {Nx.Batch.concatenate([inputs]), multi?}
   end
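
The preprocessing now interleaves each prompt with its negative prompt, tokenizes them all in a single call so they share one padding length, and unflattens the leading axis back into per-input pairs. For each tensor in the container, composite_unflatten_batch amounts to a reshape of the leading axis, roughly (shape-only sketch in plain Nx):

    # 3 inputs, each contributing [prompt, negative_prompt] => 6 sequences
    input_ids = Nx.iota({6, 4})
    pairs = Nx.reshape(input_ids, {3, 2, 4})
    # pairs[i][0] is input i's prompt, pairs[i][1] its negative prompt
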
@@ -360,7 +354,11 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
 
     seed = inputs["seed"]
 
-    inputs = Utils.Nx.composite_concatenate(inputs["unconditional"], inputs["conditional"])
+    inputs =
+      inputs["conditional_and_unconditional"]
+      # Transpose conditional and unconditional to separate blocks
+      |> composite_transpose_leading()
+      |> Utils.Nx.composite_flatten_batch()
 
     %{hidden_state: text_embeddings} = encoder_predict.(encoder_params, inputs)
 
@@ -399,7 +397,8 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
 
     %{sample: noise_pred} = unet_predict.(unet_params, unet_inputs)
 
-    {noise_pred_unconditional, noise_pred_text} = split_in_half(noise_pred)
+    {noise_pred_text, noise_pred_unconditional} =
+      split_conditional_and_unconditional(noise_pred)
 
     noise_pred =
       noise_pred_unconditional + guidance_scale * (noise_pred_text - noise_pred_unconditional)
@@ -416,7 +415,14 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
     NxImage.from_continuous(image, -1, 1)
   end
 
-  defnp split_in_half(tensor) do
+  deftransformp composite_transpose_leading(container) do
+    Utils.Nx.map(container, fn tensor ->
+      [first, second | rest] = Nx.axes(tensor)
+      Nx.transpose(tensor, axes: [second, first | rest])
+    end)
+  end
+
+  defnp split_conditional_and_unconditional(tensor) do
     batch_size = Nx.axis_size(tensor, 0)
     half_size = div(batch_size, 2)
     {tensor[0..(half_size - 1)//1], tensor[half_size..-1//1]}
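
composite_transpose_leading swaps the first two axes, so after flattening the conditional sequences occupy the first half of the batch and the unconditional ones the second half, which is exactly what split_conditional_and_unconditional peels apart. A shape-only illustration in plain Nx (values arbitrary):

    t = Nx.iota({3, 2, 4})                # {inputs, cond/uncond pair, length}
    t = Nx.transpose(t, axes: [1, 0, 2])  # {pair, inputs, length}
    flat = Nx.reshape(t, {6, 4})          # rows 0..2 conditional, rows 3..5 unconditional
    half = div(Nx.axis_size(flat, 0), 2)
    {conditional, unconditional} = {flat[0..(half - 1)//1], flat[half..-1//1]}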
