Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit ec18863

Browse files
rajprateekfacebook-github-bot
authored andcommitted
Add SE to AnyNet
Summary: We are going to open-source RegNet models. RegNet-Y models use SE, so we're adding this into anynet.py to allow these models to be constructed. Reviewed By: pdollar Differential Revision: D20950208 fbshipit-source-id: a52946e0d51ddd8c9bacfc268ba5b7871a4e890a
1 parent 8817629 commit ec18863

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

pycls/core/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@
8888
# Group param for each stage (number of groups or group width)
8989
_C.ANYNET.GROUPS = []
9090

91+
# Whether SE is enabled for res_bottleneck_block
92+
_C.ANYNET.SE_ENABLED = False
93+
94+
# SE ratio
95+
_C.ANYNET.SE_RATIO = 0.25
9196

9297
# ---------------------------------------------------------------------------- #
9398
# EfficientNet options

pycls/models/anynet.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,28 @@ def get_block_fun(block_type):
4242
return block_funs[block_type]
4343

4444

45+
class SE(nn.Module):
46+
"""Squeeze-and-Excitation (SE) block"""
47+
48+
def __init__(self, dim_in, dim_se):
49+
super(SE, self).__init__()
50+
self._construct(dim_in, dim_se)
51+
52+
def _construct(self, dim_in, dim_se):
53+
# AvgPool
54+
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
55+
# FC, Activation, FC, Sigmoid
56+
self.f_ex = nn.Sequential(
57+
nn.Conv2d(dim_in, dim_se, kernel_size=1, bias=True),
58+
nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
59+
nn.Conv2d(dim_se, dim_in, kernel_size=1, bias=True),
60+
nn.Sigmoid(),
61+
)
62+
63+
def forward(self, x):
64+
return x * self.f_ex(self.avg_pool(x))
65+
66+
4567
class AnyHead(nn.Module):
4668
"""AnyNet head."""
4769

@@ -165,6 +187,12 @@ def _construct(self, w_in, w_out, stride, bm, g):
165187
)
166188
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
167189
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
190+
191+
if cfg.ANYNET.SE_ENABLED:
192+
se_r = cfg.ANYNET.SE_RATIO
193+
dim_se = int(round(w_in * se_r))
194+
self.se = SE(w_b, dim_se)
195+
168196
# 1x1, BN
169197
self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
170198
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)

0 commit comments

Comments
 (0)