
Commit f739e0a

Add DINOv2 model (#334)
Co-authored-by: Jonatan Kłosko <[email protected]>
Parent: 5b3d7ac

11 files changed (+952 -73 lines)

lib/bumblebee.ex

Lines changed: 5 additions & 1 deletion

@@ -123,6 +123,9 @@ defmodule Bumblebee do
      {Bumblebee.Vision.Deit, :for_image_classification_with_teacher},
    "DeiTForMaskedImageModeling" => {Bumblebee.Vision.Deit, :for_masked_image_modeling},
    "DeiTModel" => {Bumblebee.Vision.Deit, :base},
+   "Dinov2Model" => {Bumblebee.Vision.DinoV2, :base},
+   "Dinov2Backbone" => {Bumblebee.Vision.DinoV2, :backbone},
+   "Dinov2ForImageClassification" => {Bumblebee.Vision.DinoV2, :for_image_classification},
    "DistilBertModel" => {Bumblebee.Text.Distilbert, :base},
    "DistilBertForMaskedLM" => {Bumblebee.Text.Distilbert, :for_masked_language_modeling},
    "DistilBertForSequenceClassification" =>

@@ -203,7 +206,8 @@ defmodule Bumblebee do
  }

  @transformers_image_processor_type_to_featurizer %{
-   "BlipImageProcessor" => Bumblebee.Vision.BlipFeaturizer
+   "BlipImageProcessor" => Bumblebee.Vision.BlipFeaturizer,
+   "BitImageProcessor" => Bumblebee.Vision.BitFeaturizer
  }

  @model_type_to_featurizer %{
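
With these mappings in place, a DINOv2 checkpoint exported by Transformers loads through the
usual Bumblebee entry points. A minimal sketch, assuming the facebook/dinov2-base repository
(the repository name and the dummy input are illustrative, not part of this commit):

# Hedged sketch: exercise the newly registered "Dinov2Model" mapping
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/dinov2-base"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "facebook/dinov2-base"})

# Dummy 224x224 RGB image; any HWC image tensor accepted by the featurizer works
image = Nx.broadcast(128, {224, 224, 3}) |> Nx.as_type(:u8)

inputs = Bumblebee.apply_featurizer(featurizer, image)
outputs = Axon.predict(model_info.model, model_info.params, inputs)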

lib/bumblebee/diffusion/layers/unet.ex

Lines changed: 3 additions & 5 deletions

@@ -323,7 +323,7 @@ defmodule Bumblebee.Diffusion.Layers.UNet do
        epsilon: 1.0e-5
      ],
      dropout_rate: dropout,
-     ffn: &ffn_geglu(&1, hidden_size, dropout: dropout, name: &2),
+     ffn: &ffn_geglu(&1, 4 * hidden_size, hidden_size, dropout: dropout, name: &2),
      block_type: :norm_first,
      name: join(name, "blocks")
    )

@@ -347,12 +347,10 @@ defmodule Bumblebee.Diffusion.Layers.UNet do
  end

  # A feed-forward network with GEGLU nonlinearity as in https://arxiv.org/abs/2002.05202
- defp ffn_geglu(x, size, opts) do
+ defp ffn_geglu(x, intermediate_size, output_size, opts) do
    name = opts[:name]
    dropout = opts[:dropout] || 0.0

-   intermediate_size = 4 * size
-
    {x, gate} =
      x
      |> Axon.dense(intermediate_size * 2, name: join(name, "intermediate"))

@@ -362,6 +360,6 @@ defmodule Bumblebee.Diffusion.Layers.UNet do

    x
    |> Axon.dropout(rate: dropout, name: join(name, "dropout"))
-   |> Axon.dense(size, name: join(name, "output"))
+   |> Axon.dense(output_size, name: join(name, "output"))
  end
end
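
For context, GEGLU (from the paper linked in the comment above) projects the input to twice the
intermediate size, splits the result in half, and multiplies one half by the GELU of the other;
the refactor only moves the choice of intermediate size (previously hard-coded as 4 * size) and
of the output size to the caller. A standalone sketch of just the gating step, with illustrative
shapes and values:

# Stand-in for the output of the "intermediate" dense layer: 2 * intermediate_size features
projected = Nx.tensor([[1.0, -2.0, 0.5, 3.0]])
half = div(Nx.axis_size(projected, 1), 2)

x = Nx.slice_along_axis(projected, 0, half, axis: 1)
gate = Nx.slice_along_axis(projected, half, half, axis: 1)

# GEGLU: x * GELU(gate)
geglu = Nx.multiply(x, Axon.Activations.gelu(gate))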

lib/bumblebee/layers/transformer.ex

Lines changed: 43 additions & 61 deletions

@@ -265,6 +265,10 @@ defmodule Bumblebee.Layers.Transformer do
      * `:parallel` - block with attention and FFN independently (in parallel).
        This type doesn't support cross-attention

+   Alternatively a custom 3-arity function may be given. The function
+   receives the input hidden state, a map with block steps and a
+   name to prefix any additional layers.
+
  * `:scale_attention_weights` - whether to scale query in the traditional style of
    multi-headed attention. Defaults to `true`

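As a sketch of the new option (not code from this commit), a caller could pass a function in
place of an atom. The callback below follows the :norm_first layout implemented further down,
handles cross-attention the way the :parallel variant does (by not supporting it), and uses the
given name to prefix one purely illustrative extra layer:

custom_block = fn hidden_state, steps, name ->
  shortcut = hidden_state

  {hidden_state, attention_info} =
    hidden_state
    |> steps.self_attention_norm.()
    |> steps.self_attention.()

  # Illustrative extra layer, namespaced under the provided name prefix
  hidden_state = Axon.dropout(hidden_state, rate: 0.1, name: name <> ".extra_dropout")

  hidden_state = Axon.add(hidden_state, shortcut)

  # This sketch does not support cross-attention
  {_hidden_state, cross_attention_info} =
    steps.cross_attention_maybe.(hidden_state, fn _hidden_state ->
      raise "cross attention not supported"
    end)

  shortcut = hidden_state

  hidden_state =
    hidden_state
    |> steps.output_norm.()
    |> steps.ffn.()
    |> Axon.add(shortcut)

  {hidden_state, attention_info, cross_attention_info}
end

# Passed alongside the other block options, e.g. block_type: custom_block
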
@@ -469,17 +473,25 @@ defmodule Bumblebee.Layers.Transformer do

    ffn = &ffn_fun.(&1, join(name, "ffn"))

+   block_impl =
+     case block_type do
+       type when is_atom(type) -> &block_impl(type, &1, &2, &3)
+       fun when is_function(fun) -> fun
+     end
+
    {hidden_state, attention_info, cross_attention_info} =
-     block_impl(
-       block_type,
+     block_impl.(
        hidden_state,
-       self_attention_norm,
-       self_attention,
-       cross_attention_maybe,
-       cross_attention_norm,
-       cross_attention,
-       output_norm,
-       ffn
+       %{
+         self_attention_norm: self_attention_norm,
+         self_attention: self_attention,
+         cross_attention_maybe: cross_attention_maybe,
+         cross_attention_norm: cross_attention_norm,
+         cross_attention: cross_attention,
+         output_norm: output_norm,
+         ffn: ffn
+       },
+       name
      )

    {attention, self_attention_cache, attention_relative_bias} = attention_info

@@ -495,36 +507,26 @@ defmodule Bumblebee.Layers.Transformer do
    {hidden_state, attention, cross_attention, block_cache, attention_relative_bias}
  end

- defp block_impl(
-        :standard,
-        hidden_state,
-        self_attention_norm,
-        self_attention,
-        cross_attention_maybe,
-        cross_attention_norm,
-        cross_attention,
-        output_norm,
-        ffn
-      ) do
+ defp block_impl(:standard, hidden_state, steps, _name) do
    shortcut = hidden_state

-   {hidden_state, attention_info} = self_attention.(hidden_state)
+   {hidden_state, attention_info} = steps.self_attention.(hidden_state)

    hidden_state =
      hidden_state
      |> Axon.add(shortcut)
-     |> self_attention_norm.()
+     |> steps.self_attention_norm.()

    {hidden_state, cross_attention_info} =
-     cross_attention_maybe.(hidden_state, fn hidden_state ->
+     steps.cross_attention_maybe.(hidden_state, fn hidden_state ->
        shortcut = hidden_state

-       {hidden_state, cross_attention_info} = cross_attention.(hidden_state)
+       {hidden_state, cross_attention_info} = steps.cross_attention.(hidden_state)

        hidden_state =
          hidden_state
          |> Axon.add(shortcut)
-         |> cross_attention_norm.()
+         |> steps.cross_attention_norm.()

        {hidden_state, cross_attention_info}
      end)

@@ -533,41 +535,31 @@ defmodule Bumblebee.Layers.Transformer do

    hidden_state =
      hidden_state
-     |> ffn.()
+     |> steps.ffn.()
      |> Axon.add(shortcut)
-     |> output_norm.()
+     |> steps.output_norm.()

    {hidden_state, attention_info, cross_attention_info}
  end

- defp block_impl(
-        :norm_first,
-        hidden_state,
-        self_attention_norm,
-        self_attention,
-        cross_attention_maybe,
-        cross_attention_norm,
-        cross_attention,
-        output_norm,
-        ffn
-      ) do
+ defp block_impl(:norm_first, hidden_state, steps, _name) do
    shortcut = hidden_state

    {hidden_state, attention_info} =
      hidden_state
-     |> self_attention_norm.()
-     |> self_attention.()
+     |> steps.self_attention_norm.()
+     |> steps.self_attention.()

    hidden_state = Axon.add(hidden_state, shortcut)

    {hidden_state, cross_attention_info} =
-     cross_attention_maybe.(hidden_state, fn hidden_state ->
+     steps.cross_attention_maybe.(hidden_state, fn hidden_state ->
        shortcut = hidden_state

        {hidden_state, cross_attention_info} =
          hidden_state
-         |> cross_attention_norm.()
-         |> cross_attention.()
+         |> steps.cross_attention_norm.()
+         |> steps.cross_attention.()

        hidden_state = Axon.add(hidden_state, shortcut)

@@ -578,40 +570,30 @@ defmodule Bumblebee.Layers.Transformer do

    hidden_state =
      hidden_state
-     |> output_norm.()
-     |> ffn.()
+     |> steps.output_norm.()
+     |> steps.ffn.()
      |> Axon.add(shortcut)

    {hidden_state, attention_info, cross_attention_info}
  end

- defp block_impl(
-        :parallel,
-        hidden_state,
-        self_attention_norm,
-        self_attention,
-        cross_attention_maybe,
-        _cross_attention_norm,
-        _cross_attention,
-        output_norm,
-        ffn
-      ) do
+ defp block_impl(:parallel, hidden_state, steps, _name) do
    shortcut = hidden_state

    {attention_hidden_state, attention_info} =
      hidden_state
-     |> self_attention_norm.()
-     |> self_attention.()
+     |> steps.self_attention_norm.()
+     |> steps.self_attention.()

    {_hidden_state, cross_attention_info} =
-     cross_attention_maybe.(hidden_state, fn _hidden_state ->
+     steps.cross_attention_maybe.(hidden_state, fn _hidden_state ->
        raise "cross attention not supported"
      end)

    ffn_hidden_state =
      hidden_state
-     |> output_norm.()
-     |> ffn.()
+     |> steps.output_norm.()
+     |> steps.ffn.()

    hidden_state = Axon.add([shortcut, attention_hidden_state, ffn_hidden_state])