@@ -214,26 +214,26 @@ defmodule Bumblebee do
214
214
"whisper" => Bumblebee.Audio.WhisperFeaturizer
215
215
}
216
216
217
- @ model_type_to_tokenizer % {
218
- "albert" => Bumblebee.Text.AlbertTokenizer ,
219
- "bart" => Bumblebee.Text.BartTokenizer ,
220
- "bert" => Bumblebee.Text.BertTokenizer ,
221
- "blenderbot" => Bumblebee.Text.BlenderbotTokenizer ,
222
- "blip" => Bumblebee.Text.BertTokenizer ,
223
- "distilbert" => Bumblebee.Text.DistilbertTokenizer ,
224
- "camembert" => Bumblebee.Text.CamembertTokenizer ,
225
- "clip" => Bumblebee.Text.ClipTokenizer ,
226
- "gpt_neox" => Bumblebee.Text.GptNeoXTokenizer ,
227
- "gpt2" => Bumblebee.Text.Gpt2Tokenizer ,
228
- "gpt_bigcode" => Bumblebee.Text.Gpt2Tokenizer ,
229
- "layoutlm" => Bumblebee.Text.LayoutLmTokenizer ,
230
- "llama" => Bumblebee.Text.LlamaTokenizer ,
231
- "mistral" => Bumblebee.Text.LlamaTokenizer ,
232
- "mbart" => Bumblebee.Text.MbartTokenizer ,
233
- "roberta" => Bumblebee.Text.RobertaTokenizer ,
234
- "t5" => Bumblebee.Text.T5Tokenizer ,
235
- "whisper" => Bumblebee.Text.WhisperTokenizer ,
236
- "xlm-roberta" => Bumblebee.Text.XlmRobertaTokenizer
217
+ @ model_type_to_tokenizer_type % {
218
+ "albert" => :albert ,
219
+ "bart" => :bart ,
220
+ "bert" => :bert ,
221
+ "blenderbot" => :blenderbot ,
222
+ "blip" => :bert ,
223
+ "distilbert" => :distilbert ,
224
+ "camembert" => :camembert ,
225
+ "clip" => :clip ,
226
+ "gpt_neox" => :gpt_neo_x ,
227
+ "gpt2" => :gpt2 ,
228
+ "gpt_bigcode" => :gpt2 ,
229
+ "layoutlm" => :layout_lm ,
230
+ "llama" => :llama ,
231
+ "mistral" => :llama ,
232
+ "mbart" => :mbart ,
233
+ "roberta" => :roberta ,
234
+ "t5" => :t5 ,
235
+ "whisper" => :whisper ,
236
+ "xlm-roberta" => :xlm_roberta
237
237
}
238
238
239
239
@ diffusers_class_to_scheduler % {
@@ -766,31 +766,6 @@ defmodule Bumblebee do
766
766
@ doc """
767
767
Tokenizes and encodes `input` with the given tokenizer.
768
768
769
- ## Options
770
-
771
- * `:add_special_tokens` - whether to add special tokens. Defaults
772
- to `true`
773
-
774
- * `:pad_direction` - the padding direction, either `:right` or
775
- `:left`. Defaults to `:right`
776
-
777
- * `:return_attention_mask` - whether to return attention mask for
778
- encoded sequence. Defaults to `true`
779
-
780
- * `:return_token_type_ids` - whether to return token type ids for
781
- encoded sequence. Defaults to `true`
782
-
783
- * `:return_special_tokens_mask` - whether to return special tokens
784
- mask for encoded sequence. Defaults to `false`
785
-
786
- * `:return_offsets` - whether to return token offsets for encoded
787
- sequence. Defaults to `false`
788
-
789
- * `:length` - applies fixed length padding or truncation to the
790
- given input if set. Can be either a specific number or a list
791
- of numbers. When a list is given, the smallest number that
792
- exceeds all input lengths is used as the padding length
793
-
794
769
## Examples
795
770
796
771
tokenizer = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})
@@ -804,27 +779,28 @@ defmodule Bumblebee do
804
779
keyword ( )
805
780
) :: any ( )
806
781
def apply_tokenizer ( % module { } = tokenizer , input , opts \\ [ ] ) do
807
- opts =
808
- Keyword . validate! ( opts ,
809
- add_special_tokens: true ,
810
- pad_direction: :right ,
811
- truncate_direction: :right ,
812
- length: nil ,
813
- return_attention_mask: true ,
814
- return_token_type_ids: true ,
815
- return_special_tokens_mask: false ,
816
- return_offsets: false
817
- )
782
+ tokenizer =
783
+ if opts == [ ] do
784
+ tokenizer
785
+ else
786
+ # TODO: remove options on v0.6
787
+ IO . warn (
788
+ "passing options to Bumblebee.apply_tokenizer/3 is deprecated," <>
789
+ " please use Bumblebee.configure/2 to set tokenizer options"
790
+ )
791
+
792
+ Bumblebee . configure ( tokenizer , opts )
793
+ end
818
794
819
- module . apply ( tokenizer , input , opts )
795
+ module . apply ( tokenizer , input )
820
796
end
821
797
822
798
@ doc """
823
799
Loads tokenizer from a model repository.
824
800
825
801
## Options
826
802
827
- * `:module ` - the tokenizer module . By default it is inferred from
803
+ * `:type` - the tokenizer type. By default it is inferred from
828
804
the configuration files, if that is not possible, it must be
829
805
specified explicitly
830
806
@@ -838,17 +814,17 @@ defmodule Bumblebee do
838
814
{ :ok , Bumblebee.Tokenizer . t ( ) } | { :error , String . t ( ) }
839
815
def load_tokenizer ( repository , opts \\ [ ] ) do
840
816
repository = normalize_repository! ( repository )
841
- opts = Keyword . validate! ( opts , [ :module ] )
842
- module = opts [ :module ]
817
+ opts = Keyword . validate! ( opts , [ :type ] )
818
+ type = opts [ :type ]
843
819
844
820
case get_repo_files ( repository ) do
845
821
{ :ok , % { @ tokenizer_filename => etag } = repo_files } ->
846
822
with { :ok , path } <- download ( repository , @ tokenizer_filename , etag ) do
847
- module =
848
- module ||
823
+ type =
824
+ type ||
849
825
case infer_tokenizer_type ( repository , repo_files ) do
850
- { :ok , module } ->
851
- module
826
+ { :ok , type } ->
827
+ type
852
828
853
829
{ :error , error } ->
854
830
raise ArgumentError , "#{ error } , please specify the :type option"
@@ -878,7 +854,7 @@ defmodule Bumblebee do
878
854
879
855
with { :ok , tokenizer_config } <- tokenizer_config_result ,
880
856
{ :ok , special_tokens_map } <- special_tokens_map_result do
881
- tokenizer = struct! ( module )
857
+ tokenizer = struct! ( Bumblebee.Text.PreTrainedTokenizer , type: type )
882
858
883
859
tokenizer =
884
860
HuggingFace.Transformers.Config . load ( tokenizer , % {
@@ -912,13 +888,13 @@ defmodule Bumblebee do
912
888
{ :ok , tokenizer_data } <- decode_config ( path ) do
913
889
case tokenizer_data do
914
890
% { "model_type" => model_type } ->
915
- case @ model_type_to_tokenizer [ model_type ] do
891
+ case @ model_type_to_tokenizer_type [ model_type ] do
916
892
nil ->
917
893
{ :error ,
918
- "could not match model type #{ inspect ( model_type ) } to any of the supported tokenizers " }
894
+ "could not match model type #{ inspect ( model_type ) } to any of the supported tokenizer types " }
919
895
920
- module ->
921
- { :ok , module }
896
+ type ->
897
+ { :ok , type }
922
898
end
923
899
924
900
_ ->
0 commit comments