Commit a268cd2

better tinyAtt
1 parent de8bae7 commit a268cd2

2 files changed: +11 −10 lines

RWKV-v4neo/src/model.py

Lines changed: 11 additions & 9 deletions
@@ -234,10 +234,11 @@ def __init__(self, args, layer_id):
         self.ffn = RWKV_ChannelMix(args, layer_id)
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            self.head_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
-            self.register_buffer("head_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+            self.tiny_ln = nn.LayerNorm(args.n_embd)
+            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
+            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
     def forward(self, x, x_emb=None):
         args = self.args
@@ -255,11 +256,12 @@ def forward(self, x, x_emb=None):
         x = x + self.ffn(self.ln2(x))
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            q = self.head_q(x)[:, :T, :]
-            k = self.head_k(x)[:, :T, :]
-            c = (q @ k.transpose(-2, -1)) * (1.0 / args.tiny_att_downscale)
-            c = c.masked_fill(self.head_mask[:T, :T] == 0, 0)
-            x = x + c @ self.head_v(x_emb)
+            xx = self.tiny_ln(x)
+            q = self.tiny_q(xx)[:, :T, :]
+            k = self.tiny_k(xx)[:, :T, :]
+            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
+            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
+            x = x + c @ self.tiny_v(x_emb)
         return x
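
Taken together, the two hunks rename the old head_* projections to tiny_*, layer-normalize the hidden state before the Q/K projections, and replace the configurable 1.0 / args.tiny_att_downscale factor with the standard 1/sqrt(tiny_att_dim) scaling, so the extra hyperparameter is no longer needed. Below is a minimal self-contained sketch of the new branch, assuming a standalone module; the TinyAttention name and the toy sizes are illustrative, not from the repo.

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    """Hypothetical standalone version of the tiny attention branch above."""

    def __init__(self, n_embd, tiny_att_dim, ctx_len):
        super().__init__()
        self.tiny_att_dim = tiny_att_dim
        self.tiny_ln = nn.LayerNorm(n_embd)  # new in this commit: normalize before Q/K
        self.tiny_q = nn.Linear(n_embd, tiny_att_dim, bias=False)
        self.tiny_k = nn.Linear(n_embd, tiny_att_dim, bias=False)
        self.tiny_v = nn.Linear(n_embd, n_embd, bias=False)
        # Causal mask: position t may only attend to positions <= t.
        self.register_buffer("tiny_mask", torch.tril(torch.ones(ctx_len, ctx_len)))

    def forward(self, x, x_emb):
        B, T, C = x.size()
        xx = self.tiny_ln(x)
        q = self.tiny_q(xx)[:, :T, :]
        k = self.tiny_k(xx)[:, :T, :]
        # Fixed 1/sqrt(d) scaling replaces the old --tiny_att_downscale flag.
        c = (q @ k.transpose(-2, -1)) * (self.tiny_att_dim ** (-0.5))
        # No softmax here: masked positions are zeroed out, not set to -inf.
        c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
        # Values are projected from the raw token embeddings (x_emb), not from x.
        return x + c @ self.tiny_v(x_emb)

# Toy usage (shapes only):
# att = TinyAttention(n_embd=768, tiny_att_dim=64, ctx_len=1024)
# y = att(torch.randn(2, 16, 768), torch.randn(2, 16, 768))  # -> (2, 16, 768)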

RWKV-v4neo/train.py

Lines changed: 0 additions & 1 deletion
@@ -70,7 +70,6 @@
     parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
     parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
     parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
-    parser.add_argument("--tiny_att_downscale", default=0, type=float)
 
     parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
     parser.add_argument("--lr_final", default=1e-5, type=float)
