@@ -265,6 +265,10 @@ defmodule Bumblebee.Layers.Transformer do
       * `:parallel` - block with attention and FFN independently (in parallel).
        This type doesn't support cross-attention
 
+    Alternatively a custom 3-arity function may be given. The function
+    receives the input hidden state, a map with block steps and a
+    name to prefix any additional layers.
+
     * `:scale_attention_weights` - whether to scale query in the traditional style of
       multi-headed attention. Defaults to `true`
 
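The new doc paragraph describes the custom block hook implemented further down in this commit. As a hedged illustration only (not part of the diff), a caller could pass a 3-arity function like the sketch below as the `:block_type` option. It assumes the steps map carries the keys introduced later in this diff (`:self_attention_norm`, `:self_attention`, `:cross_attention_maybe`, `:cross_attention_norm`, `:cross_attention`, `:output_norm`, `:ffn`) and that `name` is a string usable as an Axon layer name prefix.

```elixir
# Sketch of a custom block layout (illustrative, not part of this commit).
# Assumes the steps map exposes the same keys the built-in block types use
# and that `name` is a string usable as an Axon layer name prefix.
custom_block = fn hidden_state, steps, name ->
  shortcut = hidden_state

  # Pre-norm self-attention, as in the :norm_first layout.
  {hidden_state, attention_info} =
    hidden_state
    |> steps.self_attention_norm.()
    |> steps.self_attention.()

  hidden_state = Axon.add(hidden_state, shortcut)

  # Delegate cross-attention handling to the provided step, so the returned
  # info has the expected shape whether or not cross-attention inputs exist.
  {hidden_state, cross_attention_info} =
    steps.cross_attention_maybe.(hidden_state, fn hidden_state ->
      shortcut = hidden_state

      {hidden_state, cross_attention_info} =
        hidden_state
        |> steps.cross_attention_norm.()
        |> steps.cross_attention.()

      {Axon.add(hidden_state, shortcut), cross_attention_info}
    end)

  shortcut = hidden_state

  # FFN sub-block plus one extra, hypothetical layer prefixed with `name`.
  hidden_state =
    hidden_state
    |> steps.output_norm.()
    |> steps.ffn.()
    |> Axon.add(shortcut)
    |> Axon.layer_norm(name: "#{name}.extra_norm")

  {hidden_state, attention_info, cross_attention_info}
end
```

The function must return `{hidden_state, attention_info, cross_attention_info}`, matching what the caller destructures in the hunk below.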
@@ -469,17 +473,25 @@ defmodule Bumblebee.Layers.Transformer do
 
     ffn = &ffn_fun.(&1, join(name, "ffn"))
 
+    block_impl =
+      case block_type do
+        type when is_atom(type) -> &block_impl(type, &1, &2, &3)
+        fun when is_function(fun) -> fun
+      end
+
     {hidden_state, attention_info, cross_attention_info} =
-      block_impl(
-        block_type,
+      block_impl.(
         hidden_state,
-        self_attention_norm,
-        self_attention,
-        cross_attention_maybe,
-        cross_attention_norm,
-        cross_attention,
-        output_norm,
-        ffn
+        %{
+          self_attention_norm: self_attention_norm,
+          self_attention: self_attention,
+          cross_attention_maybe: cross_attention_maybe,
+          cross_attention_norm: cross_attention_norm,
+          cross_attention: cross_attention,
+          output_norm: output_norm,
+          ffn: ffn
+        },
+        name
       )
 
     {attention, self_attention_cache, attention_relative_bias} = attention_info
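The `case` above normalizes both accepted shapes into a single 3-arity callable: a built-in atom is wrapped as a capture of the private `block_impl/4`, while a user-supplied function is used as-is. A self-contained sketch of the same dispatch pattern (hypothetical module, not Bumblebee code):

```elixir
defmodule DispatchSketch do
  # An atom picks a built-in implementation; a function is used directly.
  # Either way the caller invokes the result with the same three arguments.
  def run(block_type, hidden_state, steps, name) do
    block_impl =
      case block_type do
        type when is_atom(type) -> &block_impl(type, &1, &2, &3)
        fun when is_function(fun, 3) -> fun
      end

    block_impl.(hidden_state, steps, name)
  end

  defp block_impl(:identity, hidden_state, _steps, _name), do: hidden_state
end

# Both forms flow through the same call shape:
DispatchSketch.run(:identity, :state, %{}, "blocks.0")
DispatchSketch.run(fn state, _steps, _name -> state end, :state, %{}, "blocks.0")
```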
@@ -495,36 +507,26 @@ defmodule Bumblebee.Layers.Transformer do
     {hidden_state, attention, cross_attention, block_cache, attention_relative_bias}
   end
 
-  defp block_impl(
-         :standard,
-         hidden_state,
-         self_attention_norm,
-         self_attention,
-         cross_attention_maybe,
-         cross_attention_norm,
-         cross_attention,
-         output_norm,
-         ffn
-       ) do
+  defp block_impl(:standard, hidden_state, steps, _name) do
     shortcut = hidden_state
 
-    {hidden_state, attention_info} = self_attention.(hidden_state)
+    {hidden_state, attention_info} = steps.self_attention.(hidden_state)
 
     hidden_state =
       hidden_state
       |> Axon.add(shortcut)
-      |> self_attention_norm.()
+      |> steps.self_attention_norm.()
 
     {hidden_state, cross_attention_info} =
-      cross_attention_maybe.(hidden_state, fn hidden_state ->
+      steps.cross_attention_maybe.(hidden_state, fn hidden_state ->
         shortcut = hidden_state
 
-        {hidden_state, cross_attention_info} = cross_attention.(hidden_state)
+        {hidden_state, cross_attention_info} = steps.cross_attention.(hidden_state)
 
         hidden_state =
           hidden_state
           |> Axon.add(shortcut)
-          |> cross_attention_norm.()
+          |> steps.cross_attention_norm.()
 
         {hidden_state, cross_attention_info}
       end)
@@ -533,41 +535,31 @@ defmodule Bumblebee.Layers.Transformer do
 
     hidden_state =
       hidden_state
-      |> ffn.()
+      |> steps.ffn.()
       |> Axon.add(shortcut)
-      |> output_norm.()
+      |> steps.output_norm.()
 
     {hidden_state, attention_info, cross_attention_info}
   end
 
-  defp block_impl(
-         :norm_first,
-         hidden_state,
-         self_attention_norm,
-         self_attention,
-         cross_attention_maybe,
-         cross_attention_norm,
-         cross_attention,
-         output_norm,
-         ffn
-       ) do
+  defp block_impl(:norm_first, hidden_state, steps, _name) do
     shortcut = hidden_state
 
     {hidden_state, attention_info} =
       hidden_state
-      |> self_attention_norm.()
-      |> self_attention.()
+      |> steps.self_attention_norm.()
+      |> steps.self_attention.()
 
     hidden_state = Axon.add(hidden_state, shortcut)
 
     {hidden_state, cross_attention_info} =
-      cross_attention_maybe.(hidden_state, fn hidden_state ->
+      steps.cross_attention_maybe.(hidden_state, fn hidden_state ->
         shortcut = hidden_state
 
         {hidden_state, cross_attention_info} =
           hidden_state
-          |> cross_attention_norm.()
-          |> cross_attention.()
+          |> steps.cross_attention_norm.()
+          |> steps.cross_attention.()
 
         hidden_state = Axon.add(hidden_state, shortcut)
 
@@ -578,40 +570,30 @@ defmodule Bumblebee.Layers.Transformer do
 
     hidden_state =
       hidden_state
-      |> output_norm.()
-      |> ffn.()
+      |> steps.output_norm.()
+      |> steps.ffn.()
       |> Axon.add(shortcut)
 
     {hidden_state, attention_info, cross_attention_info}
   end
 
-  defp block_impl(
-         :parallel,
-         hidden_state,
-         self_attention_norm,
-         self_attention,
-         cross_attention_maybe,
-         _cross_attention_norm,
-         _cross_attention,
-         output_norm,
-         ffn
-       ) do
+  defp block_impl(:parallel, hidden_state, steps, _name) do
     shortcut = hidden_state
 
     {attention_hidden_state, attention_info} =
       hidden_state
-      |> self_attention_norm.()
-      |> self_attention.()
+      |> steps.self_attention_norm.()
+      |> steps.self_attention.()
 
     {_hidden_state, cross_attention_info} =
-      cross_attention_maybe.(hidden_state, fn _hidden_state ->
+      steps.cross_attention_maybe.(hidden_state, fn _hidden_state ->
         raise "cross attention not supported"
       end)
 
     ffn_hidden_state =
       hidden_state
-      |> output_norm.()
-      |> ffn.()
+      |> steps.output_norm.()
+      |> steps.ffn.()
 
     hidden_state = Axon.add([shortcut, attention_hidden_state, ffn_hidden_state])
 