Skip to content

Add Starcoder #294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ defmodule Bumblebee do
"GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
"GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
"GPT2Model" => {BumbleBee.Text.Gpt2, :base},
"GPTBigCodeModel" => {Bumblebee.Text.GptBigCode, :base},
"GPTBigCodeForCausalLM" => {Bumblebee.Text.GptBigCode, :for_causal_language_modeling},
"GPTBigCodeForSequenceClassification" =>
{Bumblebee.Text.GptBigCode, :for_sequence_classification},
"GPTBigCodeForTokenClassification" => {Bumblebee.Text.GptBigCode, :for_token_classification},
"GPTNeoXModel" => {Bumblebee.Text.GptNeoX, :base},
"GPTNeoXForCausalLM" => {Bumblebee.Text.GptNeoX, :for_causal_language_modeling},
"GPTNeoXForSequenceClassification" => {Bumblebee.Text.GptNeoX, :for_sequence_classification},
Expand Down Expand Up @@ -215,6 +220,7 @@ defmodule Bumblebee do
"clip" => Bumblebee.Text.ClipTokenizer,
"gpt_neox" => Bumblebee.Text.GptNeoXTokenizer,
"gpt2" => Bumblebee.Text.Gpt2Tokenizer,
"gpt_bigcode" => Bumblebee.Text.Gpt2Tokenizer,
"layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
"llama" => Bumblebee.Text.LlamaTokenizer,
"mistral" => Bumblebee.Text.LlamaTokenizer,
Expand Down
4 changes: 2 additions & 2 deletions lib/bumblebee/audio/whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ defmodule Bumblebee.Audio.Whisper do
cross_hidden_state: encoder_hidden_state,
cross_attention_head_mask: cross_attention_head_mask,
cache: cache,
causal?: true,
causal: true,
num_blocks: spec.decoder_num_blocks,
num_attention_heads: spec.decoder_num_attention_heads,
hidden_size: spec.hidden_size,
Expand Down Expand Up @@ -520,7 +520,7 @@ defmodule Bumblebee.Audio.Whisper do
decoder_num_attention_heads: {"decoder_attention_heads", number()},
encoder_intermediate_size: {"encoder_ffn_dim", number()},
decoder_intermediate_size: {"decoder_ffn_dim", number()},
activation: {"activation_function", atom()},
activation: {"activation_function", activation()},
dropout_rate: {"dropout", number()},
attention_dropout_rate: {"attention_dropout", number()},
activation_dropout_rate: {"activation_dropout", number()},
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/diffusion/unet_2d_conditional.ex
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do
num_attention_heads: {"attention_head_dim", one_of([number(), list(number())])},
cross_attention_size: {"cross_attention_dim", number()},
use_linear_projection: {"use_linear_projection", boolean()},
activation: {"act_fn", atom()},
activation: {"act_fn", activation()},
group_norm_num_groups: {"norm_num_groups", number()},
group_norm_epsilon: {"norm_eps", number()}
)
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/diffusion/vae_kl.ex
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ defmodule Bumblebee.Diffusion.VaeKl do
down_block_types:
{"down_block_types", list(mapping(%{"DownEncoderBlock2D" => :down_block}))},
up_block_types: {"up_block_types", list(mapping(%{"UpDecoderBlock2D" => :up_block}))},
activation: {"act_fn", atom()}
activation: {"act_fn", activation()}
)

@for.config(spec, opts)
Expand Down
37 changes: 25 additions & 12 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defmodule Bumblebee.Layers do

import Nx.Defn

@unsupported_activations [:gelu_new, :quick_gelu]
@unsupported_activations [:gelu_approx_tanh, :gelu_approx_sigmoid]

@pi :math.pi()

Expand All @@ -30,17 +30,29 @@ defmodule Bumblebee.Layers do
end

@doc """
Implements the GeLU new activation from huggingface/transformers.
Implements the GeLU activation approximated with tanh.

## References

* [Gaussian Error Linear Units (GeLUs)](https://arxiv.org/pdf/1606.08415.pdf)

"""
defn gelu_new(input, _opts \\ []) do
defn gelu_approx_tanh(input, _opts \\ []) do
0.5 * input *
(1.0 + Nx.tanh(Nx.sqrt(2.0 / @pi) * (input + 0.044715 * Nx.pow(input, 3.0))))
end

@doc """
Implements the GeLU quick activation from huggingface/transformers.
Implements the GeLU activation approximated with sigmoid.

Note that this approximation is less accurate than `gelu_approx_tanh/2`.

## References

* [Gaussian Error Linear Units (GeLUs)](https://arxiv.org/pdf/1606.08415.pdf)

"""
defn quick_gelu(input, _opts \\ []) do
defn gelu_approx_sigmoid(input, _opts \\ []) do
input * Nx.sigmoid(1.702 * input)
end

Expand Down Expand Up @@ -184,28 +196,29 @@ defmodule Bumblebee.Layers do

## Options

* `:scale_query?` - whether to scale the query. Defaults to `true`
* `:scale` - whether to scale the weights. Defaults to `true`

"""
def attention_weights(query, key, bias, opts \\ []) do
Axon.layer(&attention_weights_impl/4, [query, key, bias], opts)
end

defnp attention_weights_impl(query, key, bias, opts \\ []) do
opts = keyword!(opts, mode: :train, scale_query?: true)
opts = keyword!(opts, mode: :train, scale: true)

key = Nx.transpose(key, axes: [0, 2, 1, 3])
query = Nx.transpose(query, axes: [0, 2, 1, 3])

query =
if opts[:scale_query?] do
weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])

weights =
if opts[:scale] do
depth = Nx.axis_size(query, -1)
query / Nx.sqrt(depth)
weights / Nx.sqrt(depth)
else
query
weights
end

weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])
weights = weights + bias
Axon.Activations.softmax(weights, axis: -1)
end
Expand Down
36 changes: 18 additions & 18 deletions lib/bumblebee/layers/transformer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ defmodule Bumblebee.Layers.Transformer do
block_opts_keys = [
:num_attention_heads,
:num_key_value_heads,
:causal?,
:causal,
:hidden_size,
:ffn,
:kernel_initializer,
Expand All @@ -56,7 +56,7 @@ defmodule Bumblebee.Layers.Transformer do
:output_use_bias,
:layer_norm,
:block_type,
:scale_query?,
:scale_attention_weights,
:rotary_embedding
]

Expand Down Expand Up @@ -216,7 +216,7 @@ defmodule Bumblebee.Layers.Transformer do

* `:offset` - offset in the input sequence during iterative decoding

* `:causal?` - whether the self-attention block should be causal.
* `:causal` - whether the self-attention block should be causal.
Defaults to `false`

* `:kernel_initializer` - initializer for kernel weights. Defaults
Expand Down Expand Up @@ -265,7 +265,7 @@ defmodule Bumblebee.Layers.Transformer do
* `:parallel` - block with attention and FFN independently (in parallel).
This type doesn't support cross-attention

* `:scale_query?` - whether to scale query in the traditional style of
* `:scale_attention_weights` - whether to scale query in the traditional style of
multi-headed attention. Defaults to `true`

* `:rotary_embedding` - configuration of rotary embedding. If set,
Expand Down Expand Up @@ -308,7 +308,7 @@ defmodule Bumblebee.Layers.Transformer do
cross_attention_head_mask: Layers.none(),
block_cache: Layers.none(),
offset: Layers.none(),
causal?: false,
causal: false,
kernel_initializer: :glorot_uniform,
attention_head_size: nil,
dropout_rate: 0.0,
Expand All @@ -319,7 +319,7 @@ defmodule Bumblebee.Layers.Transformer do
output_use_bias: true,
block_type: :standard,
layer_norm: [],
scale_query?: true,
scale_attention_weights: true,
rotary_embedding: nil
])

Expand All @@ -328,7 +328,7 @@ defmodule Bumblebee.Layers.Transformer do
num_key_value_heads = opts[:num_key_value_heads] || num_attention_heads
hidden_size = opts[:hidden_size]
ffn = opts[:ffn]
causal? = opts[:causal?]
causal = opts[:causal]
kernel_initializer = opts[:kernel_initializer]
attention_head_size = opts[:attention_head_size]
dropout_rate = opts[:dropout_rate]
Expand All @@ -347,7 +347,7 @@ defmodule Bumblebee.Layers.Transformer do
offset = opts[:offset]
layer_norm = opts[:layer_norm]
block_type = opts[:block_type]
scale_query? = opts[:scale_query?]
scale_attention_weights = opts[:scale_attention_weights]
rotary_embedding = opts[:rotary_embedding]

ffn_fun =
Expand Down Expand Up @@ -393,7 +393,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_relative_bias: attention_relative_bias,
attention_cache: self_attention_cache,
offset: offset,
causal?: causal?,
causal: causal,
num_heads: num_attention_heads,
num_key_value_heads: num_key_value_heads,
hidden_size: hidden_size,
Expand All @@ -404,7 +404,7 @@ defmodule Bumblebee.Layers.Transformer do
key_use_bias: key_use_bias,
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
scale_query?: scale_query?,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
name: join(name, "self_attention")
)
Expand Down Expand Up @@ -448,7 +448,7 @@ defmodule Bumblebee.Layers.Transformer do
key_use_bias: key_use_bias,
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
scale_query?: scale_query?,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
name: join(name, "cross_attention")
)
Expand Down Expand Up @@ -673,7 +673,7 @@ defmodule Bumblebee.Layers.Transformer do

* `:offset` - offset in the input sequence during iterative decoding

* `:causal?` - whether to apply causal attention mask, so that tokens
* `:causal` - whether to apply causal attention mask, so that tokens
are attended to only in a single direction. Defaults to `false`

* `:kernel_initializer` - initializer for kernel weights. Defaults
Expand Down Expand Up @@ -727,8 +727,8 @@ defmodule Bumblebee.Layers.Transformer do
attention_relative_bias: Layers.none(),
attention_cache: Layers.none(),
offset: Layers.none(),
causal?: false,
scale_query?: true,
causal: false,
scale_attention_weights: true,
kernel_initializer: :glorot_uniform,
dropout_rate: 0.0,
attention_head_size: nil,
Expand All @@ -749,8 +749,8 @@ defmodule Bumblebee.Layers.Transformer do
num_key_value_heads = opts[:num_key_value_heads] || num_heads
hidden_size = opts[:hidden_size]
kernel_initializer = opts[:kernel_initializer]
causal? = opts[:causal?]
scale_query? = opts[:scale_query?]
causal = opts[:causal]
scale_attention_weights = opts[:scale_attention_weights]
dropout_rate = opts[:dropout_rate]
rotary_embedding = opts[:rotary_embedding]

Expand Down Expand Up @@ -850,7 +850,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_mask = Layers.expand_attention_mask(attention_mask)

attention_mask =
if causal? do
if causal do
Layers.Decoder.apply_causal_mask(attention_mask, query, offset)
else
attention_mask
Expand Down Expand Up @@ -884,7 +884,7 @@ defmodule Bumblebee.Layers.Transformer do
end

attention_weights =
Layers.attention_weights(query, key, attention_bias, scale_query?: scale_query?)
Layers.attention_weights(query, key, attention_bias, scale: scale_attention_weights)
|> Axon.dropout(rate: dropout_rate)
|> Layers.apply_attention_head_mask(attention_head_mask)

Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/multimodal/layout_lm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ defmodule Bumblebee.Multimodal.LayoutLm do
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
intermediate_size: {"intermediate_size", number()},
activation: {"hidden_act", atom()},
activation: {"hidden_act", activation()},
dropout_rate: {"hidden_dropout_prob", number()},
attention_dropout_rate: {"attention_probs_dropout_prob", number()},
classifier_dropout_rate: {"classifier_dropout", optional(number())},
Expand Down
45 changes: 45 additions & 0 deletions lib/bumblebee/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -512,4 +512,49 @@ defmodule Bumblebee.Shared do
end
end
end

@doc """
Slices a subset of dense layer parameters.

Expects `out_template` to be a tuple representing a "shape" of the
output units. The tuple should include a list in place of the axis
along which the parameters are concatenated. The list should contain
chunk sizes. `chunk_idx` indicates which chunk to slice.
"""
def sliced_dense_params_source(source_layer_name, out_template, chunk_idx) do
out_template = Tuple.to_list(out_template)
chunk_axis = Enum.find_index(out_template, &is_list/1)
chunk_sizes = Enum.at(out_template, chunk_axis)
{prev_chunk_sizes, [chunk_size | _]} = Enum.split(chunk_sizes, chunk_idx)
offset = Enum.sum(prev_chunk_sizes)
out_shape = List.replace_at(out_template, chunk_axis, Enum.sum(chunk_sizes))

%{
"kernel" => {
[{source_layer_name, "weight"}],
fn [kernel] ->
in_size = Nx.axis_size(kernel, -1)

kernel =
kernel
|> Nx.reshape(List.to_tuple(out_shape ++ [in_size]))
|> Nx.slice_along_axis(offset, chunk_size, axis: chunk_axis)
|> Nx.reshape({:auto, in_size})

# Transpose the kernel
[out_features, in_features] = Nx.axes(kernel)
Nx.transpose(kernel, axes: [in_features, out_features])
end
},
"bias" => {
[{source_layer_name, "bias"}],
fn [bias] ->
bias
|> Nx.reshape(List.to_tuple(out_shape))
|> Nx.slice_along_axis(offset, chunk_size, axis: chunk_axis)
|> Nx.flatten()
end
}
}
end
end
15 changes: 15 additions & 0 deletions lib/bumblebee/shared/converters.ex
Original file line number Diff line number Diff line change
Expand Up @@ -203,4 +203,19 @@ defmodule Bumblebee.Shared.Converters do
end
end
end

def activation() do
mapping = %{
"gelu_new" => :gelu_approx_tanh,
"gelu_pytorch_tanh" => :gelu_approx_tanh,
"quick_gelu" => :gelu_approx_sigmoid
}

fn name, value ->
case Map.fetch(mapping, value) do
{:ok, replacement} -> {:ok, replacement}
:error -> atom().(name, value)
end
end
end
end
2 changes: 1 addition & 1 deletion lib/bumblebee/text/albert.ex
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ defmodule Bumblebee.Text.Albert do
block_depth: {"inner_group_num", number()},
num_attention_heads: {"num_attention_heads", number()},
intermediate_size: {"intermediate_size", number()},
activation: {"hidden_act", atom()},
activation: {"hidden_act", activation()},
dropout_rate: {"hidden_dropout_prob", number()},
attention_dropout_rate: {"attention_probs_dropout_prob", number()},
classifier_dropout_rate: {"classifier_dropout_prob", optional(number())},
Expand Down
4 changes: 2 additions & 2 deletions lib/bumblebee/text/bart.ex
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ defmodule Bumblebee.Text.Bart do
cross_attention_mask: encoder_attention_mask,
cross_attention_head_mask: cross_attention_head_mask,
cache: cache,
causal?: true,
causal: true,
num_blocks: spec.decoder_num_blocks,
num_attention_heads: spec.decoder_num_attention_heads,
hidden_size: spec.hidden_size,
Expand Down Expand Up @@ -639,7 +639,7 @@ defmodule Bumblebee.Text.Bart do
encoder_intermediate_size: {"encoder_ffn_dim", number()},
decoder_intermediate_size: {"decoder_ffn_dim", number()},
scale_embedding: {"scale_embedding", boolean()},
activation: {"activation_function", atom()},
activation: {"activation_function", activation()},
dropout_rate: {"dropout", number()},
attention_dropout_rate: {"attention_dropout", number()},
activation_dropout_rate: {"activation_dropout", number()},
Expand Down
Loading