@@ -234,10 +234,11 @@ def __init__(self, args, layer_id):
234
234
self .ffn = RWKV_ChannelMix (args , layer_id )
235
235
236
236
if args .tiny_att_dim > 0 and self .layer_id == args .tiny_att_layer :
237
- self .head_q = nn .Linear (args .n_embd , args .tiny_att_dim , bias = False )
238
- self .head_k = nn .Linear (args .n_embd , args .tiny_att_dim , bias = False )
239
- self .head_v = nn .Linear (args .n_embd , args .n_embd , bias = False )
240
- self .register_buffer ("head_mask" , torch .tril (torch .ones (args .ctx_len , args .ctx_len )))
237
+ self .tiny_ln = nn .LayerNorm (args .n_embd )
238
+ self .tiny_q = nn .Linear (args .n_embd , args .tiny_att_dim , bias = False )
239
+ self .tiny_k = nn .Linear (args .n_embd , args .tiny_att_dim , bias = False )
240
+ self .tiny_v = nn .Linear (args .n_embd , args .n_embd , bias = False )
241
+ self .register_buffer ("tiny_mask" , torch .tril (torch .ones (args .ctx_len , args .ctx_len )))
241
242
242
243
def forward (self , x , x_emb = None ):
243
244
args = self .args
@@ -255,11 +256,12 @@ def forward(self, x, x_emb=None):
255
256
x = x + self .ffn (self .ln2 (x ))
256
257
257
258
if args .tiny_att_dim > 0 and self .layer_id == args .tiny_att_layer :
258
- q = self .head_q (x )[:, :T , :]
259
- k = self .head_k (x )[:, :T , :]
260
- c = (q @ k .transpose (- 2 , - 1 )) * (1.0 / args .tiny_att_downscale )
261
- c = c .masked_fill (self .head_mask [:T , :T ] == 0 , 0 )
262
- x = x + c @ self .head_v (x_emb )
259
+ xx = self .tiny_ln (x )
260
+ q = self .tiny_q (xx )[:, :T , :]
261
+ k = self .tiny_k (xx )[:, :T , :]
262
+ c = (q @ k .transpose (- 2 , - 1 )) * (args .tiny_att_dim ** (- 0.5 ))
263
+ c = c .masked_fill (self .tiny_mask [:T , :T ] == 0 , 0 )
264
+ x = x + c @ self .tiny_v (x_emb )
263
265
return x
264
266
265
267
0 commit comments