This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit ad01b81

rajprateek authored and facebook-github-bot committed
Add prec_time to test_net.py and activations count
Summary:
- This diff introduces changes to run the prec_time computation in test_net as well.
- We don't currently have code to count activations in pycls. This will be useful for users to verify the activation counts of the upcoming open-sourced models.

Reviewed By: pdollar

Differential Revision: D20971486

fbshipit-source-id: b2319e8fde9e2d862f4fc020bd642b31e498c453
1 parent: 4e470e2 · commit: ad01b81

File tree

3 files changed: +35, -1 lines changed


pycls/utils/metrics.py

Lines changed: 23 additions & 1 deletion
@@ -72,7 +72,29 @@ def flops_count(model):
             h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
             w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
         elif isinstance(m, nn.Linear):
-            count += m.in_features * m.out_features
+            count += m.in_features * m.out_features + m.bias.numel()
+    return count.item()
+
+
+def acts_count(model):
+    """Computes the number of activations statically."""
+    h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
+    count = 0
+    for n, m in model.named_modules():
+        if isinstance(m, nn.Conv2d):
+            if "se." in n:
+                count += m.out_channels
+                continue
+            h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
+            w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
+            count += np.prod([m.out_channels, h_out, w_out])
+            if ".proj" not in n:
+                h, w = h_out, w_out
+        elif isinstance(m, nn.MaxPool2d):
+            h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
+            w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
+        elif isinstance(m, nn.Linear):
+            count += m.out_features
     return count.item()
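
To make the new counting logic easier to follow, here is a minimal, self-contained sketch of the same static idea: walk named_modules(), track the spatial size through Conv2d and MaxPool2d, and sum the output elements of Conv2d and Linear layers. It hard-codes the input size instead of reading cfg.TRAIN.IM_SIZE and omits the SE-block ("se.") and projection-shortcut (".proj") special cases handled by acts_count above, so treat it as an illustration rather than the pycls implementation.

import torch.nn as nn

def count_acts_sketch(model, im_size):
    # Illustration only: simplified version of the static activation count above.
    # Assumes square inputs and plain Conv2d / MaxPool2d / Linear layers.
    h = w = im_size
    count = 0
    for _, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            h = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
            w = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
            count += m.out_channels * h * w  # elements in the conv output
        elif isinstance(m, nn.MaxPool2d):
            h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
            w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
        elif isinstance(m, nn.Linear):
            count += m.out_features  # elements in the linear output
    return count

# Toy check: 8*32*32 (conv output) + 10 (linear output) = 8,202
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(8 * 16 * 16, 10),
)
print(count_acts_sketch(net, im_size=32))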

tools/test_net.py

Lines changed: 11 additions & 0 deletions
@@ -11,13 +11,16 @@
 import sys
 
 import numpy as np
+import pycls.core.losses as losses
 import pycls.core.model_builder as model_builder
 import pycls.datasets.loader as loader
+import pycls.utils.benchmark as bu
 import pycls.utils.checkpoint as cu
 import pycls.utils.distributed as du
 import pycls.utils.logging as lu
 import pycls.utils.metrics as mu
 import pycls.utils.multiprocessing as mpu
+import pycls.utils.net as nu
 import torch
 from pycls.core.config import assert_and_infer_cfg, cfg
 from pycls.utils.meters import TestMeter
@@ -49,6 +52,7 @@ def log_model_info(model):
     logger.info("Model:\n{}".format(model))
     logger.info("Params: {:,}".format(mu.params_count(model)))
     logger.info("Flops: {:,}".format(mu.flops_count(model)))
+    logger.info("Acts: {:,}".format(mu.acts_count(model)))
 
 
 @torch.no_grad()
@@ -89,6 +93,13 @@ def test_model():
     model = model_builder.build_model()
     log_model_info(model)
 
+    # Compute precise time
+    if cfg.PREC_TIME.ENABLED:
+        logger.info("Computing precise time...")
+        loss_fun = losses.get_loss_fun()
+        bu.compute_precise_time(model, loss_fun)
+        nu.reset_bn_stats(model)
+
     # Load model weights
     cu.load_checkpoint(cfg.TEST.WEIGHTS, model)
     logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))

tools/train_net.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ def log_model_info(model):
     logger.info("Model:\n{}".format(model))
     logger.info("Params: {:,}".format(mu.params_count(model)))
     logger.info("Flops: {:,}".format(mu.flops_count(model)))
+    logger.info("Acts: {:,}".format(mu.acts_count(model)))
 
 
 def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
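
With this change, log_model_info reports Params, Flops, and Acts side by side in both train_net and test_net. As a reference point for the first of these, a parameter count is just the total number of elements across a model's parameters; the generic sketch below shows that idea and is not necessarily the exact body of mu.params_count.

import torch.nn as nn

def params_count_sketch(model):
    # Generic parameter count: total number of elements across all parameters.
    return sum(p.numel() for p in model.parameters())

# nn.Linear(100, 10): 100 * 10 weights + 10 biases = 1,010 parameters
print(params_count_sketch(nn.Linear(100, 10)))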
