Migrate optional outputs to use global layer options #360

Merged · 3 commits · Mar 6, 2024
14 changes: 5 additions & 9 deletions lib/bumblebee/audio/whisper.ex
@@ -83,11 +83,7 @@ defmodule Bumblebee.Audio.Whisper do
         doc:
           "the standard deviation of the normal initializer used for initializing kernel parameters"
       ]
-    ] ++
-      Shared.common_options([
-        :output_hidden_states,
-        :output_attentions
-      ])
+    ]

   @moduledoc """
   Whisper model family.
@@ -161,6 +157,10 @@ defmodule Bumblebee.Audio.Whisper do
     pass. The cache should be treated as opaque and initialized with
     `Bumblebee.Text.Generation.init_cache/4`.

+  ## Global layer options
+
+  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
+
   ## Configuration

   #{Shared.options_doc(options)}
@@ -436,8 +436,6 @@ defmodule Bumblebee.Audio.Whisper do
         activation: spec.activation
       ],
       block_type: :norm_first,
-      output_hidden_states: spec.output_hidden_states,
-      output_attentions: spec.output_attentions,
       name: join(name, "blocks")
     )

@@ -485,8 +483,6 @@ defmodule Bumblebee.Audio.Whisper do
         activation: spec.activation
       ],
       block_type: :norm_first,
-      output_hidden_states: spec.output_hidden_states,
-      output_attentions: spec.output_attentions,
       name: join(name, "blocks")
     )

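For callers, the upshot of this migration is that optional outputs are no longer fixed on the model spec but requested when the model is compiled. A minimal usage sketch, assuming the build-side key for Axon's global layer options is `:global_layer_options` (that key name is an assumption of this sketch, not shown in the diff):

# Hypothetical usage; :global_layer_options is assumed to be the
# Axon.build/2 option that feeds layers declaring :global_options.
{:ok, %{model: model}} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})

# Before: spec = Bumblebee.configure(spec, output_hidden_states: true)
# After: opt in at build time instead
{_init_fn, predict_fn} =
  Axon.build(model, global_layer_options: [output_hidden_states: true])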
44 changes: 33 additions & 11 deletions lib/bumblebee/layers.ex
@@ -895,15 +895,22 @@ defmodule Bumblebee.Layers do
   end

   @doc """
-  Returns a container layer if `condition` is truthy, otherwise returns
-  a none layer.
+  Adds a layer that passes the input through only if the given global
+  layer option is set.
   """
-  def maybe_container(container, condition) do
-    if condition do
-      Axon.container(container)
-    else
-      none()
-    end
+  def global_opt_in(%Axon{} = input, global_option_name) do
+    Axon.layer(
+      fn input, opts ->
+        if opts[global_option_name] do
+          input
+        else
+          %Axon.None{}
+        end
+      end,
+      [input],
+      op_name: :global_opt_in,
+      global_options: [global_option_name]
+    )
   end

   @doc """
@@ -933,17 +940,32 @@ defmodule Bumblebee.Layers do

   All values are wrapped with `Axon.optional/2`, so if any of them is
   missing, it gets returned as `%Axon.None{}`.
+
+  Also, guards known optional outputs behind a global layer option
+  using `global_opt_in/2`.
   """
   @spec output(map()) :: Axon.t()
   def output(outputs) do
     outputs
-    |> Map.new(fn
-      {key, %Axon{} = val} -> {key, Axon.optional(val)}
-      {key, val} -> {key, val}
+    |> Map.new(fn {key, %Axon{} = val} ->
+      {key, val |> maybe_opt_in_output(key) |> Axon.optional()}
     end)
     |> Axon.container()
   end

+  @opt_in_outputs %{
+    :hidden_states => :output_hidden_states,
+    :attentions => :output_attentions
+  }
+
+  defp maybe_opt_in_output(%Axon{} = input, key) do
+    if option_name = @opt_in_outputs[key] do
+      global_opt_in(input, option_name)
+    else
+      input
+    end
+  end
+
   @doc """
   Computes a 1-full mask matching the first two dimensions of `input`
   (batch size and sequence length).
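To see the mechanism in isolation, here is a self-contained sketch of the same pattern, mirroring `global_opt_in/2` above. It assumes an Axon version with global layer options, i.e. `Axon.layer/3` accepting `:global_options` (as in the diff) and a corresponding `:global_layer_options` build option (the build-side name is an assumption):

# A layer that forwards its input only when the :output_hidden_states
# global option is set at build time; otherwise it yields %Axon.None{}.
input = Axon.input("x", shape: {nil, 4})

gated =
  Axon.layer(
    fn x, opts -> if opts[:output_hidden_states], do: x, else: %Axon.None{} end,
    [input],
    op_name: :global_opt_in,
    global_options: [:output_hidden_states]
  )

# Axon.optional/2 lets %Axon.None{} flow through to the container output.
model = Axon.container(%{hidden: Axon.optional(gated)})

# Without the option, the :hidden entry comes back as %Axon.None{};
# with it, the input tensor is passed straight through.
{_init_fn, predict_fn} =
  Axon.build(model, global_layer_options: [output_hidden_states: true])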
20 changes: 4 additions & 16 deletions lib/bumblebee/layers/transformer.ex
@@ -21,14 +21,6 @@ defmodule Bumblebee.Layers.Transformer do
       is configured, this option controls whether the bias from the
       first block is used for all other blocks. Defaults to `false`

-    * `:output_hidden_states` - when `true`, the output includes a
-      tuple with intermediate hidden states from each transformer
-      block. Defaults to `false`
-
-    * `:output_attentions` - when `true`, the output includes a tuple
-      with attention weights from each transformer block. Defaults
-      to `false`
-
     * `:name` - the prefix for layer names

   For all other options (including required options) see `block/2`.
@@ -75,16 +67,12 @@ defmodule Bumblebee.Layers.Transformer do
         cross_hidden_state: nil,
         cross_attention_mask: Layers.none(),
         cross_attention_head_mask: Layers.none(),
-        cache: Layers.none(),
-        output_hidden_states: false,
-        output_attentions: false
+        cache: Layers.none()
       ]
     )

     name = opts[:name]
     num_blocks = opts[:num_blocks]
-    output_hidden_states = opts[:output_hidden_states]
-    output_attentions = opts[:output_attentions]

     attention_mask = opts[:attention_mask]
     attention_head_mask = opts[:attention_head_mask]
@@ -100,9 +88,9 @@ defmodule Bumblebee.Layers.Transformer do

     state = %{
       hidden_state: hidden_state,
-      hidden_states: Layers.maybe_container({hidden_state}, output_hidden_states),
-      attentions: Layers.maybe_container({}, output_attentions),
-      cross_attentions: Layers.maybe_container({}, output_attentions),
+      hidden_states: Axon.container({hidden_state}),
+      attentions: Axon.container({}),
+      cross_attentions: Axon.container({}),
       cache: cache,
       attention_relative_bias: Layers.none()
     }
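Note the design shift: the per-block collections (`hidden_states`, `attentions`, `cross_attentions`) are now always accumulated as containers, and the decision to expose them is made once, at the model boundary, by `Layers.output/1` via `global_opt_in/2`, instead of each model threading `output_*` flags into every transformer call. When the caller does not opt in, the gated outputs resolve to `%Axon.None{}`.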
18 changes: 5 additions & 13 deletions lib/bumblebee/multimodal/blip.ex
@@ -20,11 +20,7 @@ defmodule Bumblebee.Multimodal.Blip do
       default: 2.6592,
       doc: "the initial value for the scaling layer used to scale similarity logits"
     ]
-  ] ++
-    Shared.common_options([
-      :output_hidden_states,
-      :output_attentions
-    ])
+  ]

   @moduledoc """
   The BLIP model for text-image similarity.
@@ -72,6 +68,10 @@ defmodule Bumblebee.Multimodal.Blip do
     pass. The cache should be treated as opaque and initialized with
     `Bumblebee.Text.Generation.init_cache/4`.

+  ## Global layer options
+
+  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
+
   ## Configuration

   #{Shared.options_doc(options)}
@@ -128,10 +128,6 @@ defmodule Bumblebee.Multimodal.Blip do

     vision_model =
       vision_spec
-      |> Bumblebee.configure(
-        output_hidden_states: spec.output_hidden_states,
-        output_attentions: spec.output_hidden_states
-      )
       |> Bumblebee.build_model()
       |> Bumblebee.Utils.Axon.prefix_names("vision_model.")
       |> Bumblebee.Utils.Axon.plug_inputs(%{
@@ -155,10 +151,6 @@ defmodule Bumblebee.Multimodal.Blip do

     text_decoder =
       text_spec
-      |> Bumblebee.configure(
-        output_hidden_states: spec.output_hidden_states,
-        output_attentions: spec.output_hidden_states
-      )
       |> Bumblebee.build_model()
       |> Bumblebee.Utils.Axon.prefix_names("text_decoder.")
       |> Bumblebee.Utils.Axon.plug_inputs(%{
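Worth noting in the removed lines above: the old code set `output_attentions: spec.output_hidden_states`, wiring the attentions flag to the hidden-states option, an apparent copy-paste slip. The migration removes it along with the `Bumblebee.configure/2` call itself; the same pattern appears in the CLIP diff below.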
18 changes: 5 additions & 13 deletions lib/bumblebee/multimodal/clip.ex
@@ -20,11 +20,7 @@ defmodule Bumblebee.Multimodal.Clip do
       default: 2.6592,
       doc: "the initial value for the scaling layer used to scale similarity logits"
     ]
-  ] ++
-    Shared.common_options([
-      :output_hidden_states,
-      :output_attentions
-    ])
+  ]

   @moduledoc """
   The CLIP model for text-image similarity.
@@ -54,6 +50,10 @@ defmodule Bumblebee.Multimodal.Clip do

     Featurized image pixel values.

+  ## Global layer options
+
+  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
+
   ## Configuration

   #{Shared.options_doc(options)}
@@ -108,10 +108,6 @@ defmodule Bumblebee.Multimodal.Clip do

     text_model =
       text_spec
-      |> Bumblebee.configure(
-        output_hidden_states: spec.output_hidden_states,
-        output_attentions: spec.output_hidden_states
-      )
       |> Bumblebee.build_model()
       |> Bumblebee.Utils.Axon.prefix_names("text_model.")
       |> Bumblebee.Utils.Axon.plug_inputs(%{
@@ -122,10 +118,6 @@ defmodule Bumblebee.Multimodal.Clip do

     vision_model =
       vision_spec
-      |> Bumblebee.configure(
-        output_hidden_states: spec.output_hidden_states,
-        output_attentions: spec.output_hidden_states
-      )
       |> Bumblebee.build_model()
       |> Bumblebee.Utils.Axon.prefix_names("vision_model.")
       |> Bumblebee.Utils.Axon.plug_inputs(%{
14 changes: 5 additions & 9 deletions lib/bumblebee/multimodal/layout_lm.ex
@@ -77,13 +77,7 @@ defmodule Bumblebee.Multimodal.LayoutLm do
       default: 1.0e-12,
       doc: "the epsilon used by the layer normalization layers"
     ]
-  ] ++
-    Shared.common_options([
-      :output_hidden_states,
-      :output_attentions,
-      :num_labels,
-      :id_to_label
-    ])
+  ] ++ Shared.common_options([:num_labels, :id_to_label])

   @moduledoc """
   LayoutLM Model family.
@@ -140,6 +134,10 @@ defmodule Bumblebee.Multimodal.LayoutLm do
     `{x0, y0, x1, y1}` where `{x0, y0}` is the upper left corner and
     `{x1, y1}` is the lower right corner.

+  ## Global layer options
+
+  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
+
   ## Configuration

   #{Shared.options_doc(options)}
@@ -426,8 +424,6 @@ defmodule Bumblebee.Multimodal.LayoutLm do
         intermediate_size: spec.intermediate_size,
         activation: spec.activation
       ],
-      output_hidden_states: spec.output_hidden_states,
-      output_attentions: spec.output_attentions,
       name: join(name, "blocks")
     )
   end
16 changes: 16 additions & 0 deletions lib/bumblebee/shared.ex
@@ -64,6 +64,22 @@ defmodule Bumblebee.Shared do
     Enum.join(items, "\n\n")
   end

+  @doc """
+  Generates documentation string for the given global layer options.
+  """
+  @spec global_layer_options_doc(list(atom())) :: String.t()
+  def global_layer_options_doc(names) do
+    docs = [
+      output_hidden_states: "when `true`, the model output includes all hidden states",
+      output_attentions: "when `true`, the model output includes all attention weights"
+    ]
+
+    Enum.map_join(names, "\n\n", fn name ->
+      doc = Keyword.fetch!(docs, name)
+      "  * `#{inspect(name)}` - #{doc}"
+    end)
+  end
+
   @doc """
   Returns option defaults form the options specification.

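Since `global_layer_options_doc/1` is plain string interpolation, its output is easy to preview. Based on the implementation above (modulo exact leading whitespace):

iex> Bumblebee.Shared.global_layer_options_doc([:output_hidden_states])
"  * `:output_hidden_states` - when `true`, the model output includes all hidden states"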
16 changes: 7 additions & 9 deletions lib/bumblebee/text/albert.ex
@@ -83,13 +83,7 @@ defmodule Bumblebee.Text.Albert do
       doc:
         "the standard deviation of the normal initializer used for initializing kernel parameters"
     ]
-  ] ++
-    Shared.common_options([
-      :output_hidden_states,
-      :output_attentions,
-      :num_labels,
-      :id_to_label
-    ])
+  ] ++ Shared.common_options([:num_labels, :id_to_label])

   @moduledoc """
   ALBERT model family.
@@ -148,6 +142,10 @@ defmodule Bumblebee.Text.Albert do
   The `:for_multiple_choice` model accepts groups of sequences, so the
   expected sequence shape is `{batch_size, num_choices, sequence_length}`.

+  ## Global layer options
+
+  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
+
   ## Configuration

   #{Shared.options_doc(options)}
@@ -389,8 +387,8 @@ defmodule Bumblebee.Text.Albert do
         name: join(name, "embedding_projection")
       )

-    hidden_states = Layers.maybe_container({hidden_state}, spec.output_hidden_states)
-    attentions = Layers.maybe_container({}, spec.output_attentions)
+    hidden_states = Axon.container({hidden_state})
+    attentions = Axon.container({})

     for block_idx <- 0..(spec.num_blocks - 1),
         inner_idx <- 0..(spec.block_depth - 1),
15 changes: 5 additions & 10 deletions lib/bumblebee/text/bart.ex
@@ -78,12 +78,7 @@ defmodule Bumblebee.Text.Bart do
         "the standard deviation of the normal initializer used for initializing kernel parameters"
     ]
   ] ++
-    Shared.common_options([
-      :output_hidden_states,
-      :output_attentions,
-      :num_labels,
-      :id_to_label
-    ]) ++
+    Shared.common_options([:num_labels, :id_to_label]) ++
     Shared.token_options(
       eos_token_id: 2,
       decoder_start_token_id: 2
@@ -197,6 +192,10 @@ defmodule Bumblebee.Text.Bart do
     `"position_ids"`, `"attention_head_mask"`, `"input_embeddings"`, `"encoder_hidden_state"`,
     `"encoder_attention_mask"`, `"cross_attention_head_mask"`, `"cache"`.

+  ## Global layer options
+
+  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
+
   ## Configuration

   #{Shared.options_doc(options)}
@@ -563,8 +562,6 @@ defmodule Bumblebee.Text.Bart do
         intermediate_size: spec.encoder_intermediate_size,
         activation: spec.activation
       ],
-      output_hidden_states: spec.output_hidden_states,
-      output_attentions: spec.output_attentions,
       name: join(name, "blocks")
     )
   end
@@ -603,8 +600,6 @@ defmodule Bumblebee.Text.Bart do
         intermediate_size: spec.decoder_intermediate_size,
         activation: spec.activation
       ],
-      output_hidden_states: spec.output_hidden_states,
-      output_attentions: spec.output_attentions,
       name: join(name, "blocks")
     )
   end
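The pattern repeats across the model files: `:output_hidden_states` and `:output_attentions` leave `Shared.common_options` and the transformer call sites, while genuinely spec-level options such as `:num_labels` and `:id_to_label` stay behind, and each moduledoc gains the same "Global layer options" section.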