
Commit cf4dff6

Add Starcoder (#294)
1 parent 3cc0ff9 commit cf4dff6

31 files changed: +840, -154 lines

lib/bumblebee.ex

Lines changed: 6 additions & 0 deletions
@@ -128,6 +128,11 @@ defmodule Bumblebee do
     "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
     "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
     "GPT2Model" => {Bumblebee.Text.Gpt2, :base},
+    "GPTBigCodeModel" => {Bumblebee.Text.GptBigCode, :base},
+    "GPTBigCodeForCausalLM" => {Bumblebee.Text.GptBigCode, :for_causal_language_modeling},
+    "GPTBigCodeForSequenceClassification" =>
+      {Bumblebee.Text.GptBigCode, :for_sequence_classification},
+    "GPTBigCodeForTokenClassification" => {Bumblebee.Text.GptBigCode, :for_token_classification},
     "GPTNeoXModel" => {Bumblebee.Text.GptNeoX, :base},
     "GPTNeoXForCausalLM" => {Bumblebee.Text.GptNeoX, :for_causal_language_modeling},
     "GPTNeoXForSequenceClassification" => {Bumblebee.Text.GptNeoX, :for_sequence_classification},

@@ -215,6 +220,7 @@ defmodule Bumblebee do
     "clip" => Bumblebee.Text.ClipTokenizer,
     "gpt_neox" => Bumblebee.Text.GptNeoXTokenizer,
     "gpt2" => Bumblebee.Text.Gpt2Tokenizer,
+    "gpt_bigcode" => Bumblebee.Text.Gpt2Tokenizer,
     "layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
     "llama" => Bumblebee.Text.LlamaTokenizer,
     "mistral" => Bumblebee.Text.LlamaTokenizer,

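With the GPTBigCode architectures and the `gpt_bigcode` tokenizer type registered above (the tokenizer entry reuses `Bumblebee.Text.Gpt2Tokenizer`, since GPT-BigCode checkpoints ship a GPT-2 style BPE tokenizer), a Starcoder-family checkpoint can be loaded by name. A minimal sketch, assuming any `gpt_bigcode` checkpoint from the Hugging Face Hub; the repository id below is an example, not prescribed by this commit:

    # Assumed checkpoint id; substitute any gpt_bigcode model
    repo = {:hf, "bigcode/gpt_bigcode-santacoder"}

    {:ok, model_info} = Bumblebee.load_model(repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
    {:ok, generation_config} = Bumblebee.load_generation_config(repo)

    serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
    Nx.Serving.run(serving, "def fibonacci(n):")
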
lib/bumblebee/audio/whisper.ex

Lines changed: 2 additions & 2 deletions
@@ -469,7 +469,7 @@ defmodule Bumblebee.Audio.Whisper do
         cross_hidden_state: encoder_hidden_state,
         cross_attention_head_mask: cross_attention_head_mask,
         cache: cache,
-        causal?: true,
+        causal: true,
         num_blocks: spec.decoder_num_blocks,
         num_attention_heads: spec.decoder_num_attention_heads,
         hidden_size: spec.hidden_size,

@@ -520,7 +520,7 @@ defmodule Bumblebee.Audio.Whisper do
       decoder_num_attention_heads: {"decoder_attention_heads", number()},
       encoder_intermediate_size: {"encoder_ffn_dim", number()},
       decoder_intermediate_size: {"decoder_ffn_dim", number()},
-      activation: {"activation_function", atom()},
+      activation: {"activation_function", activation()},
       dropout_rate: {"dropout", number()},
       attention_dropout_rate: {"attention_dropout", number()},
       activation_dropout_rate: {"activation_dropout", number()},

lib/bumblebee/diffusion/unet_2d_conditional.ex

Lines changed: 1 addition & 1 deletion
@@ -397,7 +397,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do
       num_attention_heads: {"attention_head_dim", one_of([number(), list(number())])},
       cross_attention_size: {"cross_attention_dim", number()},
       use_linear_projection: {"use_linear_projection", boolean()},
-      activation: {"act_fn", atom()},
+      activation: {"act_fn", activation()},
       group_norm_num_groups: {"norm_num_groups", number()},
       group_norm_epsilon: {"norm_eps", number()}
     )

lib/bumblebee/diffusion/vae_kl.ex

Lines changed: 1 addition & 1 deletion
@@ -435,7 +435,7 @@ defmodule Bumblebee.Diffusion.VaeKl do
       down_block_types:
         {"down_block_types", list(mapping(%{"DownEncoderBlock2D" => :down_block}))},
       up_block_types: {"up_block_types", list(mapping(%{"UpDecoderBlock2D" => :up_block}))},
-      activation: {"act_fn", atom()}
+      activation: {"act_fn", activation()}
     )

     @for.config(spec, opts)

lib/bumblebee/layers.ex

Lines changed: 25 additions & 12 deletions
@@ -3,7 +3,7 @@ defmodule Bumblebee.Layers do

   import Nx.Defn

-  @unsupported_activations [:gelu_new, :quick_gelu]
+  @unsupported_activations [:gelu_approx_tanh, :gelu_approx_sigmoid]

   @pi :math.pi()


@@ -30,17 +30,29 @@ defmodule Bumblebee.Layers do
   end

   @doc """
-  Implements the GeLU new activation from huggingface/transformers.
+  Implements the GeLU activation approximated with tanh.
+
+  ## References
+
+    * [Gaussian Error Linear Units (GeLUs)](https://arxiv.org/pdf/1606.08415.pdf)
+
   """
-  defn gelu_new(input, _opts \\ []) do
+  defn gelu_approx_tanh(input, _opts \\ []) do
     0.5 * input *
       (1.0 + Nx.tanh(Nx.sqrt(2.0 / @pi) * (input + 0.044715 * Nx.pow(input, 3.0))))
   end

   @doc """
-  Implements the GeLU quick activation from huggingface/transformers.
+  Implements the GeLU activation approximated with sigmoid.
+
+  Note that this approximation is less accurate than `gelu_approx_tanh/2`.
+
+  ## References
+
+    * [Gaussian Error Linear Units (GeLUs)](https://arxiv.org/pdf/1606.08415.pdf)
+
   """
-  defn quick_gelu(input, _opts \\ []) do
+  defn gelu_approx_sigmoid(input, _opts \\ []) do
     input * Nx.sigmoid(1.702 * input)
   end

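The renamed functions compute the same approximations as before; only the names changed, describing the approximation instead of mirroring the huggingface/transformers naming. A quick sketch of calling them directly (`Bumblebee.Layers` is an internal module, so this is for illustration only; outputs are rounded):

    x = Nx.tensor([-1.0, 0.0, 1.0])

    Bumblebee.Layers.gelu_approx_tanh(x)
    #=> roughly [-0.159, 0.0, 0.841]

    Bumblebee.Layers.gelu_approx_sigmoid(x)
    #=> roughly [-0.154, 0.0, 0.846]
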
@@ -184,28 +196,29 @@ defmodule Bumblebee.Layers do

   ## Options

-    * `:scale_query?` - whether to scale the query. Defaults to `true`
+    * `:scale` - whether to scale the weights. Defaults to `true`

   """
   def attention_weights(query, key, bias, opts \\ []) do
     Axon.layer(&attention_weights_impl/4, [query, key, bias], opts)
   end

   defnp attention_weights_impl(query, key, bias, opts \\ []) do
-    opts = keyword!(opts, mode: :train, scale_query?: true)
+    opts = keyword!(opts, mode: :train, scale: true)

     key = Nx.transpose(key, axes: [0, 2, 1, 3])
     query = Nx.transpose(query, axes: [0, 2, 1, 3])

-    query =
-      if opts[:scale_query?] do
+    weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])
+
+    weights =
+      if opts[:scale] do
         depth = Nx.axis_size(query, -1)
-        query / Nx.sqrt(depth)
+        weights / Nx.sqrt(depth)
       else
-        query
+        weights
       end

-    weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])
     weights = weights + bias
     Axon.Activations.softmax(weights, axis: -1)
   end
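This refactor also changes where the 1/sqrt(depth) scaling is applied: previously the query was scaled before the dot product, now the raw query-key product is computed first and the resulting weights are scaled, which is what the renamed `:scale` option toggles. Numerically both orderings produce softmax of the scaled dot product. A minimal plain-Nx sketch of the new flow (shapes are illustrative):

    # {batch, heads, seq, head_depth}, matching the layout after the transposes above
    query = Nx.iota({1, 2, 4, 8}, type: :f32)
    key = Nx.iota({1, 2, 4, 8}, type: :f32)

    weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])
    depth = Nx.axis_size(query, -1)
    weights = Nx.divide(weights, Nx.sqrt(depth))
    Axon.Activations.softmax(weights, axis: -1)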

lib/bumblebee/layers/transformer.ex

Lines changed: 18 additions & 18 deletions
@@ -43,7 +43,7 @@ defmodule Bumblebee.Layers.Transformer do
     block_opts_keys = [
       :num_attention_heads,
       :num_key_value_heads,
-      :causal?,
+      :causal,
       :hidden_size,
       :ffn,
       :kernel_initializer,

@@ -56,7 +56,7 @@ defmodule Bumblebee.Layers.Transformer do
       :output_use_bias,
       :layer_norm,
       :block_type,
-      :scale_query?,
+      :scale_attention_weights,
       :rotary_embedding
     ]

@@ -216,7 +216,7 @@ defmodule Bumblebee.Layers.Transformer do

     * `:offset` - offset in the input sequence during iterative decoding

-    * `:causal?` - whether the self-attention block should be causal.
+    * `:causal` - whether the self-attention block should be causal.
       Defaults to `false`

     * `:kernel_initializer` - initializer for kernel weights. Defaults

@@ -265,7 +265,7 @@ defmodule Bumblebee.Layers.Transformer do
       * `:parallel` - block with attention and FFN independently (in parallel).
         This type doesn't support cross-attention

-    * `:scale_query?` - whether to scale query in the traditional style of
+    * `:scale_attention_weights` - whether to scale query in the traditional style of
       multi-headed attention. Defaults to `true`

     * `:rotary_embedding` - configuration of rotary embedding. If set,

@@ -308,7 +308,7 @@ defmodule Bumblebee.Layers.Transformer do
       cross_attention_head_mask: Layers.none(),
       block_cache: Layers.none(),
       offset: Layers.none(),
-      causal?: false,
+      causal: false,
       kernel_initializer: :glorot_uniform,
       attention_head_size: nil,
       dropout_rate: 0.0,

@@ -319,7 +319,7 @@ defmodule Bumblebee.Layers.Transformer do
       output_use_bias: true,
       block_type: :standard,
       layer_norm: [],
-      scale_query?: true,
+      scale_attention_weights: true,
       rotary_embedding: nil
     ])

@@ -328,7 +328,7 @@ defmodule Bumblebee.Layers.Transformer do
     num_key_value_heads = opts[:num_key_value_heads] || num_attention_heads
     hidden_size = opts[:hidden_size]
     ffn = opts[:ffn]
-    causal? = opts[:causal?]
+    causal = opts[:causal]
     kernel_initializer = opts[:kernel_initializer]
     attention_head_size = opts[:attention_head_size]
     dropout_rate = opts[:dropout_rate]

@@ -347,7 +347,7 @@ defmodule Bumblebee.Layers.Transformer do
     offset = opts[:offset]
     layer_norm = opts[:layer_norm]
     block_type = opts[:block_type]
-    scale_query? = opts[:scale_query?]
+    scale_attention_weights = opts[:scale_attention_weights]
     rotary_embedding = opts[:rotary_embedding]

     ffn_fun =

@@ -393,7 +393,7 @@ defmodule Bumblebee.Layers.Transformer do
         attention_relative_bias: attention_relative_bias,
         attention_cache: self_attention_cache,
         offset: offset,
-        causal?: causal?,
+        causal: causal,
         num_heads: num_attention_heads,
         num_key_value_heads: num_key_value_heads,
         hidden_size: hidden_size,

@@ -404,7 +404,7 @@ defmodule Bumblebee.Layers.Transformer do
         key_use_bias: key_use_bias,
         value_use_bias: value_use_bias,
         output_use_bias: output_use_bias,
-        scale_query?: scale_query?,
+        scale_attention_weights: scale_attention_weights,
         rotary_embedding: rotary_embedding,
         name: join(name, "self_attention")
       )

@@ -448,7 +448,7 @@ defmodule Bumblebee.Layers.Transformer do
         key_use_bias: key_use_bias,
         value_use_bias: value_use_bias,
         output_use_bias: output_use_bias,
-        scale_query?: scale_query?,
+        scale_attention_weights: scale_attention_weights,
         rotary_embedding: rotary_embedding,
         name: join(name, "cross_attention")
       )

@@ -673,7 +673,7 @@ defmodule Bumblebee.Layers.Transformer do

     * `:offset` - offset in the input sequence during iterative decoding

-    * `:causal?` - whether to apply causal attention mask, so that tokens
+    * `:causal` - whether to apply causal attention mask, so that tokens
       are attended to only in a single direction. Defaults to `false`

     * `:kernel_initializer` - initializer for kernel weights. Defaults

@@ -727,8 +727,8 @@ defmodule Bumblebee.Layers.Transformer do
       attention_relative_bias: Layers.none(),
       attention_cache: Layers.none(),
       offset: Layers.none(),
-      causal?: false,
-      scale_query?: true,
+      causal: false,
+      scale_attention_weights: true,
       kernel_initializer: :glorot_uniform,
       dropout_rate: 0.0,
       attention_head_size: nil,

@@ -749,8 +749,8 @@ defmodule Bumblebee.Layers.Transformer do
     num_key_value_heads = opts[:num_key_value_heads] || num_heads
     hidden_size = opts[:hidden_size]
     kernel_initializer = opts[:kernel_initializer]
-    causal? = opts[:causal?]
-    scale_query? = opts[:scale_query?]
+    causal = opts[:causal]
+    scale_attention_weights = opts[:scale_attention_weights]
     dropout_rate = opts[:dropout_rate]
     rotary_embedding = opts[:rotary_embedding]

@@ -850,7 +850,7 @@ defmodule Bumblebee.Layers.Transformer do
     attention_mask = Layers.expand_attention_mask(attention_mask)

     attention_mask =
-      if causal? do
+      if causal do
         Layers.Decoder.apply_causal_mask(attention_mask, query, offset)
       else
         attention_mask

@@ -884,7 +884,7 @@ defmodule Bumblebee.Layers.Transformer do
     end

     attention_weights =
-      Layers.attention_weights(query, key, attention_bias, scale_query?: scale_query?)
+      Layers.attention_weights(query, key, attention_bias, scale: scale_attention_weights)
       |> Axon.dropout(rate: dropout_rate)
       |> Layers.apply_attention_head_mask(attention_head_mask)


lib/bumblebee/multimodal/layout_lm.ex

Lines changed: 1 addition & 1 deletion
@@ -491,7 +491,7 @@ defmodule Bumblebee.Multimodal.LayoutLm do
       num_blocks: {"num_hidden_layers", number()},
       num_attention_heads: {"num_attention_heads", number()},
       intermediate_size: {"intermediate_size", number()},
-      activation: {"hidden_act", atom()},
+      activation: {"hidden_act", activation()},
       dropout_rate: {"hidden_dropout_prob", number()},
       attention_dropout_rate: {"attention_probs_dropout_prob", number()},
       classifier_dropout_rate: {"classifier_dropout", optional(number())},

lib/bumblebee/shared.ex

Lines changed: 45 additions & 0 deletions
@@ -512,4 +512,49 @@ defmodule Bumblebee.Shared do
       end
     end
   end
+
+  @doc """
+  Slices a subset of dense layer parameters.
+
+  Expects `out_template` to be a tuple representing a "shape" of the
+  output units. The tuple should include a list in place of the axis
+  along which the parameters are concatenated. The list should contain
+  chunk sizes. `chunk_idx` indicates which chunk to slice.
+  """
+  def sliced_dense_params_source(source_layer_name, out_template, chunk_idx) do
+    out_template = Tuple.to_list(out_template)
+    chunk_axis = Enum.find_index(out_template, &is_list/1)
+    chunk_sizes = Enum.at(out_template, chunk_axis)
+    {prev_chunk_sizes, [chunk_size | _]} = Enum.split(chunk_sizes, chunk_idx)
+    offset = Enum.sum(prev_chunk_sizes)
+    out_shape = List.replace_at(out_template, chunk_axis, Enum.sum(chunk_sizes))
+
+    %{
+      "kernel" => {
+        [{source_layer_name, "weight"}],
+        fn [kernel] ->
+          in_size = Nx.axis_size(kernel, -1)
+
+          kernel =
+            kernel
+            |> Nx.reshape(List.to_tuple(out_shape ++ [in_size]))
+            |> Nx.slice_along_axis(offset, chunk_size, axis: chunk_axis)
+            |> Nx.reshape({:auto, in_size})
+
+          # Transpose the kernel
+          [out_features, in_features] = Nx.axes(kernel)
+          Nx.transpose(kernel, axes: [in_features, out_features])
+        end
+      },
+      "bias" => {
+        [{source_layer_name, "bias"}],
+        fn [bias] ->
+          bias
+          |> Nx.reshape(List.to_tuple(out_shape))
+          |> Nx.slice_along_axis(offset, chunk_size, axis: chunk_axis)
+          |> Nx.flatten()
+        end
+      }
+    }
+  end
 end
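This helper exists because checkpoints such as GPT-BigCode store query, key and value as a single fused dense layer, while the Axon model uses separate dense layers. A hedged sketch of how a params mapping might slice such a fused projection; the source layer name, target layer names and sizes below are illustrative assumptions, not the exact GPT-BigCode mapping:

    hidden_size = 768
    head_size = 64

    # Hypothetical fused layer: query (hidden_size units), then key and value
    # (one head each) concatenated along the output axis
    qkv_template = {[hidden_size, head_size, head_size]}
    source = "transformer.h.0.attn.c_attn"

    %{
      "decoder.blocks.0.self_attention.query" =>
        Bumblebee.Shared.sliced_dense_params_source(source, qkv_template, 0),
      "decoder.blocks.0.self_attention.key" =>
        Bumblebee.Shared.sliced_dense_params_source(source, qkv_template, 1),
      "decoder.blocks.0.self_attention.value" =>
        Bumblebee.Shared.sliced_dense_params_source(source, qkv_template, 2)
    }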

lib/bumblebee/shared/converters.ex

Lines changed: 15 additions & 0 deletions
@@ -203,4 +203,19 @@ defmodule Bumblebee.Shared.Converters do
       end
     end
   end
+
+  def activation() do
+    mapping = %{
+      "gelu_new" => :gelu_approx_tanh,
+      "gelu_pytorch_tanh" => :gelu_approx_tanh,
+      "quick_gelu" => :gelu_approx_sigmoid
+    }
+
+    fn name, value ->
+      case Map.fetch(mapping, value) do
+        {:ok, replacement} -> {:ok, replacement}
+        :error -> atom().(name, value)
+      end
+    end
+  end
 end
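The converter normalizes huggingface/transformers activation names onto the renamed Bumblebee atoms and falls back to the plain `atom()` converter for everything else. An illustrative sketch (the module is internal, shown here only to make the mapping concrete):

    convert = Bumblebee.Shared.Converters.activation()

    convert.("hidden_act", "gelu_new")    #=> {:ok, :gelu_approx_tanh}
    convert.("hidden_act", "quick_gelu")  #=> {:ok, :gelu_approx_sigmoid}
    convert.("hidden_act", "silu")        #=> {:ok, :silu}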

lib/bumblebee/text/albert.ex

Lines changed: 1 addition & 1 deletion
@@ -483,7 +483,7 @@ defmodule Bumblebee.Text.Albert do
       block_depth: {"inner_group_num", number()},
       num_attention_heads: {"num_attention_heads", number()},
       intermediate_size: {"intermediate_size", number()},
-      activation: {"hidden_act", atom()},
+      activation: {"hidden_act", activation()},
       dropout_rate: {"hidden_dropout_prob", number()},
       attention_dropout_rate: {"attention_probs_dropout_prob", number()},
       classifier_dropout_rate: {"classifier_dropout_prob", optional(number())},

lib/bumblebee/text/bart.ex

Lines changed: 2 additions & 2 deletions
@@ -589,7 +589,7 @@ defmodule Bumblebee.Text.Bart do
         cross_attention_mask: encoder_attention_mask,
         cross_attention_head_mask: cross_attention_head_mask,
         cache: cache,
-        causal?: true,
+        causal: true,
         num_blocks: spec.decoder_num_blocks,
         num_attention_heads: spec.decoder_num_attention_heads,
         hidden_size: spec.hidden_size,

@@ -639,7 +639,7 @@ defmodule Bumblebee.Text.Bart do
       encoder_intermediate_size: {"encoder_ffn_dim", number()},
       decoder_intermediate_size: {"decoder_ffn_dim", number()},
       scale_embedding: {"scale_embedding", boolean()},
-      activation: {"activation_function", atom()},
+      activation: {"activation_function", activation()},
       dropout_rate: {"dropout", number()},
       attention_dropout_rate: {"attention_dropout", number()},
       activation_dropout_rate: {"activation_dropout", number()},
