models.py
import torch
import torch.nn as nn
from functools import partial
from vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, to_2tuple
import torch.nn.functional as F
import math

__all__ = ['deit_small_MCTformerPlus']

class MCTformerPlus(VisionTransformer):
    def __init__(self, decay_parameter=0.996, input_size=224, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the linear classifier with a 3x3 conv head applied to the 2D patch-token map.
        self.head = nn.Conv2d(self.embed_dim, self.num_classes, kernel_size=3, stride=1, padding=1)
        self.head.apply(self._init_weights)

        img_size = to_2tuple(input_size)
        patch_size = to_2tuple(self.patch_embed.patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.num_patches = num_patches

        # One class token per class, with separate positional embeddings for class and patch tokens.
        self.cls_token = nn.Parameter(torch.zeros(1, self.num_classes, self.embed_dim))
        self.pos_embed_cls = nn.Parameter(torch.zeros(1, self.num_classes, self.embed_dim))
        self.pos_embed_pat = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))

        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embed_cls, std=.02)
        trunc_normal_(self.pos_embed_pat, std=.02)
        print(self.training)

        self.decay_parameter = decay_parameter

    def interpolate_pos_encoding(self, x, w, h):
        # Resize the patch positional embeddings when the input resolution differs from training.
        npatch = x.shape[1] - self.num_classes
        N = self.num_patches
        if npatch == N and w == h:
            return self.pos_embed_pat
        patch_pos_embed = self.pos_embed_pat
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size[0]
        h0 = h // self.patch_embed.patch_size[0]
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward_features(self, x, n=12):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)

        if not self.training:
            pos_embed_pat = self.interpolate_pos_encoding(x, w, h)
            x = x + pos_embed_pat
        else:
            x = x + self.pos_embed_pat

        cls_tokens = self.cls_token.expand(B, -1, -1)
        cls_tokens = cls_tokens + self.pos_embed_cls
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)

        attn_weights = []
        class_embeddings = []
        for i, blk in enumerate(self.blocks):
            x, weights_i = blk(x)
            attn_weights.append(weights_i)
            class_embeddings.append(x[:, 0:self.num_classes])

        # Class tokens, patch tokens, per-block attention maps and per-block class embeddings.
        return x[:, 0:self.num_classes], x[:, self.num_classes:], attn_weights, class_embeddings

    def forward(self, x, return_att=False, n_layers=12, attention_type='fused'):
        w, h = x.shape[2:]
        x_cls, x_patch, attn_weights, all_x_cls = self.forward_features(x)
        n, p, c = x_patch.shape

        # Reshape the patch tokens back into a 2D feature map before the conv head.
        if w != h:
            w0 = w // self.patch_embed.patch_size[0]
            h0 = h // self.patch_embed.patch_size[0]
            x_patch = torch.reshape(x_patch, [n, w0, h0, c])
        else:
            x_patch = torch.reshape(x_patch, [n, int(p ** 0.5), int(p ** 0.5), c])
        x_patch = x_patch.permute([0, 3, 1, 2])
        x_patch = x_patch.contiguous()
        x_patch = self.head(x_patch)

        # Rank the patch scores per class and pool them with exponentially decaying weights.
        x_patch_flattened = x_patch.view(x_patch.shape[0], x_patch.shape[1], -1).permute(0, 2, 1)
        sorted_patch_token, indices = torch.sort(x_patch_flattened, -2, descending=True)
        weights = torch.logspace(start=0, end=x_patch_flattened.size(-2) - 1,
                                 steps=x_patch_flattened.size(-2), base=self.decay_parameter).cuda()
        x_patch_logits = torch.sum(sorted_patch_token * weights.unsqueeze(0).unsqueeze(-1), dim=-2) / weights.sum()

        x_cls_logits = x_cls.mean(-1)

        output = []
        output.append(x_cls_logits)
        output.append(torch.stack(all_x_cls))
        output.append(x_patch_logits)

        if return_att:
            feature_map = x_patch.detach().clone()  # B * C * 14 * 14
            feature_map = F.relu(feature_map)
            n, c, h, w = feature_map.shape
            attn_weights = torch.stack(attn_weights)  # 12 * B * H * N * N
            attn_weights = torch.mean(attn_weights, dim=2)  # 12 * B * N * N

            # Class-to-patch attention averaged over the last n_layers blocks.
            mtatt = attn_weights[-n_layers:].mean(0)[:, 0:self.num_classes, self.num_classes:].reshape([n, c, h, w])
            patch_attn = attn_weights[:, :, self.num_classes:, self.num_classes:]

            if attention_type == 'fused':
                cams = mtatt * feature_map  # B * C * 14 * 14
                cams = torch.sqrt(cams)
            elif attention_type == 'patchcam':
                cams = feature_map
            elif attention_type == 'mct':
                cams = mtatt
            else:
                raise ValueError(f'Error! {attention_type} is not defined!')

            x_logits = (x_cls_logits + x_patch_logits) / 2
            return x_logits, cams, patch_attn
        else:
            return output

@register_model
def deit_small_MCTformerPlus(pretrained=False, **kwargs):
    model = MCTformerPlus(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )['model']
        model_dict = model.state_dict()

        # Drop classifier weights whose shapes no longer match the conv head.
        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
            if k in checkpoint and checkpoint[k].shape != model_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint[k]

        # Keep only weights present in this model, excluding the class token and positional embedding.
        pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict}
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k not in ['cls_token', 'pos_embed']}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    return model
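

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): builds the model
# through timm's registry and runs a single forward pass. The 20-class count,
# the 224x224 dummy input and attention_type='fused' are assumptions for
# demonstration only. Everything is moved to the GPU because the ranked
# pooling weights in forward() are created with .cuda().
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from timm.models import create_model

    device = 'cuda'
    model = create_model('deit_small_MCTformerPlus', pretrained=False, num_classes=20).to(device)
    model.eval()

    dummy = torch.randn(1, 3, 224, 224, device=device)
    with torch.no_grad():
        # With return_att=True the model returns fused logits, class-specific
        # activation maps and patch-to-patch attention.
        logits, cams, patch_attn = model(dummy, return_att=True, attention_type='fused')

    # Expected shapes for this configuration: logits (1, 20),
    # cams (1, 20, 14, 14), patch_attn (12, 1, 196, 196).
    print(logits.shape, cams.shape, patch_attn.shape)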