forked from dgaddy/silent_speech
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharchitecture.py
More file actions
442 lines (389 loc) · 16.8 KB
/
architecture.py
File metadata and controls
442 lines (389 loc) · 16.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers.weight_init import trunc_normal_
class LearnedRelativePositionalEmbedding(nn.Module):
    # from https://github.com/pytorch/fairseq/pull/2225/commits/a7fb63f2b84d5b20c8855e9c3372a95e5d0ea073
    """
    This module learns relative positional embeddings up to a fixed
    maximum size. These are masked for decoder and unmasked for encoder
    self attention.
    By default the embeddings are added to keys, but could be added to
    values as well.

    Args:
        max_relative_pos (int): the maximum relative positions to compute embeddings for
        num_heads (int): number of attention heads
        embedding_dim (int): depth of embeddings
        unmasked (bool): if the attention is unmasked (for transformer encoder)
        heads_share_embeddings (bool): if heads share the same relative positional embeddings
        add_to_values (bool): compute embeddings to be added to values as well
    """

    def __init__(
        self,
        max_relative_pos: int,
        num_heads: int,
        embedding_dim: int,
        unmasked: bool = False,
        heads_share_embeddings: bool = False,
        add_to_values: bool = False,
    ):
        super().__init__()
        self.max_relative_pos = max_relative_pos
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.unmasked = unmasked
        self.heads_share_embeddings = heads_share_embeddings
        self.add_to_values = add_to_values
        # Unmasked (encoder) attention sees both negative and positive offsets,
        # so there are 2*max - 1 distinct relative positions; masked (decoder)
        # attention only ever looks backwards, so max positions suffice.
        num_embeddings = 2 * max_relative_pos - 1 if unmasked else max_relative_pos
        # Trailing dimension selects the target: index 0 = added to keys,
        # index 1 (only when add_to_values) = added to values.
        embedding_size = (
            [num_embeddings, embedding_dim, 1]
            if heads_share_embeddings
            else [num_heads, num_embeddings, embedding_dim, 1]
        )
        if add_to_values:
            embedding_size[-1] = 2
        # Std of d^-0.5 keeps positional logits on the same scale as the
        # (1/sqrt(d))-scaled content logits.
        initial_stddev = embedding_dim ** (-0.5)
        self.embeddings = nn.Parameter(torch.zeros(*embedding_size))
        nn.init.normal_(self.embeddings, mean=0.0, std=initial_stddev)

    def forward(self, query: torch.Tensor, saved_state: dict = None):
        """
        Computes relative positional embeddings to be added to keys (and optionally values),
        multiplies the embeddings for keys with queries to create positional logits,
        returns the positional logits, along with embeddings for values (optionally)
        which could be added to values outside this module.

        Args:
            query (torch.Tensor): query tensor
            saved_state (dict): saved state from previous time step

        Shapes:
            query: `(length, batch_size*num_heads, embed_dim)`

        Returns:
            tuple(torch.Tensor):
                - positional logits
                - relative positional embeddings to be added to values
        """
        # During inference when previous states are cached
        if saved_state is not None and "prev_key" in saved_state:
            assert not self.unmasked, "This should only be for decoder attention"
            length = saved_state["prev_key"].shape[-2] + 1  # `length - 1` keys are cached,
            # `+ 1` for the current time step
            decoder_step = True
        else:
            length = query.shape[0]
            decoder_step = False
        used_embeddings = self.get_embeddings_for_query(length)
        # Trailing index 0 = key embeddings, index 1 = value embeddings (see __init__).
        values_embeddings = used_embeddings[..., 1] if self.add_to_values else None
        positional_logits = self.calculate_positional_logits(query, used_embeddings[..., 0])
        positional_logits = self.relative_to_absolute_indexing(positional_logits, decoder_step)
        return (positional_logits, values_embeddings)

    def get_embeddings_for_query(self, length: int) -> torch.Tensor:
        """
        Extract the required embeddings. The maximum relative position between two time steps is
        `length` for masked case or `2*length - 1` for the unmasked case. If `length` is greater than
        `max_relative_pos`, we first pad the embeddings tensor with zero-embeddings, which represent
        embeddings when relative position is greater than `max_relative_pos`. In case `length` is
        less than `max_relative_pos`, we don't use the first `max_relative_pos - length embeddings`.

        Args:
            length (int): length of the query

        Returns:
            torch.Tensor: embeddings used by the query
        """
        pad_length = max(length - self.max_relative_pos, 0)
        start_pos = max(self.max_relative_pos - length, 0)
        if self.unmasked:
            # NOTE(review): building the padded table under no_grad detaches
            # `used_embeddings` from `self.embeddings`, so no gradient reaches
            # the embedding table through this path — confirm this is intended.
            with torch.no_grad():
                padded_embeddings = nn.functional.pad(self.embeddings, (0, 0, 0, 0, pad_length, pad_length))
            used_embeddings = padded_embeddings.narrow(-3, start_pos, 2 * length - 1)
        else:
            # Masked case only pads on the "past" side: relative positions run
            # from -(length-1) to 0.
            with torch.no_grad():
                padded_embeddings = nn.functional.pad(self.embeddings, (0, 0, 0, 0, pad_length, 0))
            used_embeddings = padded_embeddings.narrow(-3, start_pos, length)
        return used_embeddings

    def calculate_positional_logits(self, query: torch.Tensor, relative_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Multiplies query with the relative positional embeddings to create relative
        positional logits

        Args:
            query (torch.Tensor): Input tensor representing queries
            relative_embeddings (torch.Tensor): relative embeddings compatible with query

        Shapes:
            query: `(length, batch_size*num_heads, embed_dim)` if heads share embeddings
                   else `(length, batch_size, num_heads, embed_dim)`
            relative_embeddings: `(max_allowed_relative_positions, embed_dim)` if heads share embeddings
                                 else `(num_heads, max_allowed_relative_positions, embed_dim)`
                                 where `max_allowed_relative_positions` is `length` if masked
                                 else `2*length - 1`

        Returns:
            torch.Tensor: relative positional logits
        """
        if self.heads_share_embeddings:
            positional_logits = torch.einsum("lbd,md->lbm", query, relative_embeddings)
        else:
            # Un-flatten the batch*heads axis so each head can use its own table.
            query = query.view(query.shape[0], -1, self.num_heads, self.embedding_dim)
            positional_logits = torch.einsum("lbhd,hmd->lbhm", query, relative_embeddings)
            # Re-flatten back to (length, batch*heads, positions).
            positional_logits = positional_logits.contiguous().view(
                positional_logits.shape[0], -1, positional_logits.shape[-1]
            )
        # mask out tokens out of range
        length = query.size(0)
        if length > self.max_relative_pos:
            # there is some padding
            pad_length = length - self.max_relative_pos
            # Subtract a large constant so softmax assigns ~0 probability to the
            # zero-padded (out-of-range) relative positions.
            positional_logits[:, :, :pad_length] -= 1e8
            if self.unmasked:
                positional_logits[:, :, -pad_length:] -= 1e8
        return positional_logits

    def relative_to_absolute_indexing(self, x: torch.Tensor, decoder_step: bool) -> torch.Tensor:
        """
        Index tensor x (relative positional logits) in terms of absolute positions
        rather than relative positions. Last dimension of x represents relative position
        with respect to the first dimension, whereas returned tensor has both the first
        and last dimension indexed with absolute positions.

        Args:
            x (torch.Tensor): positional logits indexed by relative positions
            decoder_step (bool): is this is a single decoder step (during inference)

        Shapes:
            x: `(length, batch_size*num_heads, length)` for masked case or
               `(length, batch_size*num_heads, 2*length - 1)` for unmasked

        Returns:
            torch.Tensor: positional logits represented using absolute positions
        """
        length, bsz_heads, _ = x.shape
        if decoder_step:
            # Single query step: all relative positions are already absolute.
            return x.contiguous().view(bsz_heads, 1, -1)
        if self.unmasked:
            # Standard "skewing" trick: pad, flatten, re-pad, then reshape so the
            # relative axis lines up with absolute key positions.
            x = nn.functional.pad(x, (0, 1))
            x = x.transpose(0, 1)
            x = x.contiguous().view(bsz_heads, length * 2 * length)
            x = nn.functional.pad(x, (0, length - 1))
            # Reshape and slice out the padded elements.
            x = x.view(bsz_heads, length + 1, 2 * length - 1)
            return x[:, :length, length - 1 :]
        else:
            # Masked variant of the same skewing trick.
            x = nn.functional.pad(x, (1, 0))
            x = x.transpose(0, 1)
            x = x.contiguous().view(bsz_heads, length + 1, length)
            return x[:, 1:, :]
class ResBlock(nn.Module):
    """1-D convolutional residual block with optional ResNet-v2 pre-activation.

    Args:
        num_ins: number of input channels.
        num_outs: number of output channels.
        stride: stride of the first conv (temporal downsampling factor).
        pre_activation: use ResNet v2 pre-activation ordering
            (norm -> act -> conv), https://arxiv.org/pdf/1603.05027.pdf
        beta: scale applied to the residual branch before the skip add.
    """

    def __init__(self, num_ins: int, num_outs: int, stride: int = 1, pre_activation: bool = False, beta: float = 1.0):
        super().__init__()
        self.conv1 = nn.Conv1d(num_ins, num_outs, 3, padding=1, stride=stride)
        # BUG FIX: in pre-activation ordering, norm1 runs on the block INPUT
        # (num_ins channels), not on conv1's output — BatchNorm1d(num_outs)
        # crashed whenever num_ins != num_outs.
        self.norm1 = nn.BatchNorm1d(num_ins if pre_activation else num_outs)
        self.conv2 = nn.Conv1d(num_outs, num_outs, 3, padding=1)
        self.norm2 = nn.BatchNorm1d(num_outs)
        # self.act = nn.ReLU()
        self.act = nn.GELU()  # TODO: test which is better
        self.beta = beta
        if stride != 1 or num_ins != num_outs:
            # Projection shortcut: 1x1 conv to match channels/stride.
            self.residual_path = nn.Conv1d(num_ins, num_outs, 1, stride=stride)
            if pre_activation:
                # BUG FIX: the shortcut norm precedes the 1x1 conv here, so it
                # also normalizes the input (num_ins channels).
                self.res_norm = nn.BatchNorm1d(num_ins)
                self.skip = nn.Sequential(self.res_norm, self.residual_path)
            else:
                self.res_norm = nn.BatchNorm1d(num_outs)
                self.skip = nn.Sequential(self.residual_path, self.res_norm)
        else:
            self.skip = nn.Identity()
        # ResNet v2 style pre-activation https://arxiv.org/pdf/1603.05027.pdf
        self.pre_activation = pre_activation
        if pre_activation:
            self.block = nn.Sequential(self.norm1, self.act, self.conv1, self.norm2, self.act, self.conv2)
        else:
            self.block = nn.Sequential(self.conv1, self.norm1, self.act, self.conv2, self.norm2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to x of shape (batch, channels, time)."""
        res = self.block(x) * self.beta
        x = self.skip(x)
        if self.pre_activation:
            # v2 blocks leave the residual sum un-activated.
            return x + res
        else:
            return self.act(x + res)
class LRPEAttention(nn.Module):
    """
    Multi Head Attention with Learned Relative Positional Encoding (LRPE) applied to the logits.

    Args:
        dim: model dimension (must be divisible by num_heads).
        num_heads: number of attention heads.
        qkv_bias: add bias to the fused QKV projection.
        attn_drop: dropout probability on the attention weights.
        relative_positional_distance: maximum relative offset the LRPE table covers.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 3,
        qkv_bias: bool = True,
        attn_drop: float = 0.1,
        relative_positional_distance: int = 100,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads, self.dim = num_heads, dim
        self.hd = dim // num_heads  # per-head dimension
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        # unmasked=True: encoder-style (bidirectional) relative positions.
        self.relative_positional = LearnedRelativePositionalEmbedding(
            relative_positional_distance, num_heads, self.hd, True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the multi-head self-attention layer.

        Args:
            x: the input to the layer, a tensor of shape [batch_size, length, d_model]

        Returns:
            A single tensor containing the output from this layer
        """
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.hd).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # each of q, k, v: [B, n_h, N, h_d]
        scale_factor = 1 / math.sqrt(q.size(-1))
        logits = q @ k.transpose(-2, -1) * scale_factor
        # The LRPE module expects (length, batch*heads, head_dim) with the
        # flattened axis ordered batch-major / head-minor.
        # BUG FIX: the length axis must be moved to the front BEFORE flattening;
        # the previous reshape from (B, N, n_h, h_d) directly to (N, B*n_h, h_d)
        # scrambled batch and time whenever B != N.
        q_pos = q.permute(2, 0, 1, 3)  # [N, B, n_h, h_d]
        position_logits, _ = self.relative_positional(q_pos.reshape(N, B * self.num_heads, self.hd))
        # position_logits is (B*n_h, N, N); un-flatten to (B, n_h, N, N).
        position_logits = position_logits.view(B, self.num_heads, N, N)
        logits = logits + position_logits
        probs = F.softmax(logits, dim=-1)
        probs = self.attn_drop(probs)
        out = (probs @ v).transpose(1, 2).reshape(B, N, D)
        out = self.proj(out)
        return out
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear.

    Args:
        in_features: input feature dimension.
        hidden_features: hidden layer width.
        out_features: output feature dimension.
        dropout: dropout probability applied after the activation.
        act_layer: activation module class (instantiated here).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        dropout: float = 0.1,
        act_layer: nn.Module = nn.GELU,
    ):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project to the hidden width, activate, drop, and project back out."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
class CustomAttentionBlock(nn.Module):
    """Post-norm transformer encoder block: LRPE self-attention and an MLP,
    each followed by dropout, a residual add, and LayerNorm.

    Args:
        dim: model dimension.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of dim.
        qkv_bias: add bias to the attention QKV projection.
        proj_drop: dropout on the residual branches (and inside the MLP).
        attn_drop: dropout on the attention weights.
        act_layer: activation module class for the MLP.
        norm_layer: normalization module class.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        # BUG FIX: LRPEAttention.__init__ has no `norm_layer` parameter;
        # passing `norm_layer=norm_layer` raised TypeError at construction.
        self.attn = LRPEAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            relative_positional_distance=100,
        )
        self.norm1 = norm_layer(dim)
        ffn_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=ffn_dim,
            out_features=dim,
            dropout=proj_drop,
            act_layer=act_layer,
        )
        self.dropout1 = nn.Dropout(proj_drop)
        self.dropout2 = nn.Dropout(proj_drop)
        self.norm2 = norm_layer(dim)
        # NOTE(review): `self.activation` is never used in forward (the MLP has
        # its own activation); kept for checkpoint/attribute compatibility.
        self.activation = act_layer()

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP sub-blocks with post-norm residuals.

        Args:
            src: tensor of shape (batch, length, dim).

        Returns:
            Tensor of the same shape as `src`.
        """
        src2 = self.attn(src)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.mlp(src)
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
class EMGTransformer(nn.Module):
    """Convolutional front-end plus transformer encoder for raw EMG signals.

    Three strided ResBlocks downsample raw EMG 8x in time and lift it to
    `embed_dim` channels; a stack of LRPE attention blocks then encodes it,
    and linear heads produce the main (and optional auxiliary) outputs.

    Args:
        num_features: kept for interface compatibility (unused here).
        num_outs: size of the main output head.
        num_aux_outs: size of the auxiliary head; None disables it.
        in_chans: number of raw EMG channels (electrodes).
        embed_dim: transformer model dimension.
        n_layer: number of attention blocks.
        n_head: attention heads per block.
        mlp_ratio: MLP hidden width multiple.
        qkv_bias: add bias to attention QKV projections.
        attn_drop: dropout on attention weights.
        proj_drop: dropout on residual branches.
        act_layer: activation module class.
        norm_layer: normalization module class.
        freeze_blocks: if True, freeze the attention blocks' parameters.
    """

    def __init__(
        self,
        num_features: int,
        num_outs: int,
        num_aux_outs: int = None,
        in_chans: int = 8,
        embed_dim: int = 192,
        n_layer: int = 8,
        n_head: int = 3,
        mlp_ratio: int = 4,
        qkv_bias: bool = True,
        attn_drop: float = 0.1,
        proj_drop: float = 0.1,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        freeze_blocks: bool = False,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.n_layer = n_layer
        self.n_head = n_head
        self.embed_dim = embed_dim
        # Each stride-2 ResBlock halves the time axis: 8x total downsampling.
        self.conv_blocks = nn.Sequential(
            ResBlock(in_chans, embed_dim, 2),
            ResBlock(embed_dim, embed_dim, 2),
            ResBlock(embed_dim, embed_dim, 2),
        )
        self.w_raw_in = nn.Linear(embed_dim, embed_dim)
        self.blocks = nn.ModuleList(
            [
                CustomAttentionBlock(
                    dim=embed_dim,
                    num_heads=n_head,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
                for _ in range(n_layer)
            ]
        )
        self.w_out = nn.Linear(embed_dim, num_outs)
        self.has_aux_out = num_aux_outs is not None
        if self.has_aux_out:
            self.w_aux = nn.Linear(embed_dim, num_aux_outs)
        # ----------------------------------------------
        self.initialize_weights()
        # Freeze multi-head attention blocks
        if freeze_blocks:
            for param in self.blocks.parameters():
                param.requires_grad = False

    def initialize_weights(self):
        """Initializes the model weights."""
        # Encodings Initializations code taken from the LaBraM paper
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        # Truncated-normal init for Linear weights, zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x_feat: torch.Tensor, x_raw: torch.Tensor, session_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        x_feat and session_ids are kept for compatibility but unused.

        Returns the main head's output, or a (main, aux) tuple when an
        auxiliary head is configured.
        """
        # x shape is (batch, time, electrode)
        if self.training:
            # Random left-shift augmentation (0-7 samples), zero-padded at the end.
            r = random.randrange(8)
            if r > 0:
                # BUG FIX: build the shifted tensor out-of-place instead of
                # mutating the caller's `x_raw` in place.
                x_raw = torch.cat(
                    [x_raw[:, r:, :], torch.zeros_like(x_raw[:, :r, :])], dim=1
                )
        x_raw = x_raw.transpose(1, 2)  # put channel before time for conv
        x_raw = self.conv_blocks(x_raw)  # (B, D, N) — channels-first conv output
        x_raw = x_raw.transpose(1, 2)  # B N D
        x_raw = self.w_raw_in(x_raw)  # B N D
        x = x_raw
        for blk in self.blocks:
            x = blk(x)
        if self.has_aux_out:
            return self.w_out(x), self.w_aux(x)
        return self.w_out(x)