Skip to content

Commit 50d3ef1

Browse files
Deprecating TiedEmbeddingTransformerDecoder (#1815)
1 parent 543f698 commit 50d3ef1

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

torchtune/modules/transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch import nn
1212
from torchtune.modules import MultiHeadAttention
1313
from torchtune.modules.attention_utils import _MaskType
14+
from torchtune.utils._logging import deprecated
1415

1516

1617
class TransformerSelfAttentionLayer(nn.Module):
@@ -619,6 +620,11 @@ def forward(
619620
return output
620621

621622

623+
@deprecated(
624+
msg="Please use torchtune.modules.TransformerDecoder instead. \
625+
If you need an example, see torchtune.models.qwen2._component_builders.py \
626+
on how to use torch.modules.TiedLinear for the output projection."
627+
)
622628
class TiedEmbeddingTransformerDecoder(nn.Module):
623629
"""
624630
Transformer Decoder with tied embedding weight. A key difference between

0 commit comments

Comments
 (0)