Commit f69c775

dskhudia authored and facebook-github-bot committed
Improve im2col for certain cases (pytorch#715)
Summary:
Pull Request resolved: pytorch#715

Copy across pixels along the width dimension for im2col. This should help convolutions with a small number of input channels.

Copy across pixels of input width if we can. We can only do this if the following conditions are met:
1) The number of groups is 1. For number of groups > 1, im2col doesn't copy data across groups.
2) Dilation is 1. For dilation > 1, copying from input across channels is not sequential.
3) For a copy from the last channel (end of filter or end of image width) for the current filter, only copy if we have enough in the current channel.

Differential Revision: D31227743
fbshipit-source-id: eadc62308f221f8731f566cd35eaeccd40781e6a
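The idea in the summary can be illustrated with a minimal sketch. It assumes an NHWC input layout (which the source-offset arithmetic in this commit implies), a single group, and dilation 1, so that the channels of neighbouring width pixels sit back to back in memory. The helper names `gather_per_pixel` and `gather_merged` are hypothetical and are not part of FBGEMM; they only show that one large `memcpy` over `kw * ic` elements yields the same bytes as `kw` separate copies of `ic` elements each.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper (not FBGEMM code): gathers the im2col segment for one
// filter row from a single NHWC image row, one memcpy per filter tap.
// `row` holds iw * ic elements, `w_start` is the first input column read.
std::vector<std::uint8_t> gather_per_pixel(
    const std::vector<std::uint8_t>& row, int ic, int kw, int w_start) {
  assert(static_cast<std::size_t>(w_start + kw) * ic <= row.size());
  std::vector<std::uint8_t> out(static_cast<std::size_t>(kw) * ic);
  for (int s = 0; s < kw; ++s) {
    // Each tap copies only the `ic` channels of one pixel.
    std::memcpy(out.data() + s * ic, row.data() + (w_start + s) * ic, ic);
  }
  return out;
}

// Hypothetical helper (not FBGEMM code): same result with a single memcpy.
// Valid only because, with one group and dilation 1, the kw pixels are
// contiguous in the NHWC row.
std::vector<std::uint8_t> gather_merged(
    const std::vector<std::uint8_t>& row, int ic, int kw, int w_start) {
  assert(static_cast<std::size_t>(w_start + kw) * ic <= row.size());
  std::vector<std::uint8_t> out(static_cast<std::size_t>(kw) * ic);
  std::memcpy(out.data(), row.data() + w_start * ic,
              static_cast<std::size_t>(kw) * ic);
  return out;
}
```

With few input channels (for example `ic = 3`), the per-pixel version issues many tiny copies, which is the case the summary says this change should help.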
1 parent 7bee788 commit f69c775

1 file changed: +24 -5 lines changed

src/PackAWithIm2Col.cc

Lines changed: 24 additions & 5 deletions
@@ -573,14 +573,33 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
                 a_zero_pt_,
                 sizeof(T) * (j_blk_end - j_blk_start));
           } else {
+            int chn_start_idx = j_blk_start % ic_per_group;
+            int src_offset =
+                ((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] + w_in) *
+                    conv_p_.IC + g * ic_per_group + chn_start_idx;
+            // fast path
+            // Copy across pixels of input width if we can. We can only do this
+            // if the following conditions are met. 1) If the number of groups
+            // is 1. For number of groups > 1, im2col
+            // doesn't copy data across groups.
+            // 2) If dilation is 1. For dilation > 1, copying from input
+            // across channels is not sequential.
+            // 3) For copy from the last channel (end of filter or
+            // end of image width) for the current filter,
+            // only copy if we have enough in the current channel.
+            //
+            if (conv_p_.G == 1 && conv_p_.dilation[1] == 1 &&
+                ((s < (conv_p_.K[1] - 1) && w_in < (conv_p_.IN_DIM[1] - 1)) ||
+                 ((chn_start_idx + block.col_size) <= ic_per_group))) {
+              j_blk_end = std::min(
+                  (j_blk_id + conv_p_.K[1]) * ic_per_group,
+                  block.col_start + block.col_size);
+              j += ic_per_group * (conv_p_.K[1] - 1);
+            }
             std::memcpy(
                 out + (i - block.row_start) * BaseType::blockColSize() +
                     j_blk_start - block.col_start,
-                sdata_ +
-                    ((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] +
-                     w_in) *
-                        conv_p_.IC +
-                    g * ic_per_group + (j_blk_start % ic_per_group),
+                sdata_ + src_offset,
                 sizeof(T) * (j_blk_end - j_blk_start));
           }
         }
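
For readers who want to experiment with the fast-path condition outside of FBGEMM, here is a small sketch that restates the source-offset arithmetic and the eligibility check from the hunk above as free functions. The struct `Conv2dShape` and the functions `nhwc_offset` and `can_copy_across_width` are hypothetical names introduced only for illustration; they stand in for the `conv_p_` fields and loop variables used in the diff, and the sketch is not a drop-in replacement for the real code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, simplified stand-in for the 2D fields of conv_p_ used above.
struct Conv2dShape {
  int G;          // number of groups
  int IC;         // total input channels
  int in_h;       // IN_DIM[0]
  int in_w;       // IN_DIM[1]
  int k_w;        // K[1], filter width
  int dilation_w; // dilation[1]
};

// Element offset into an NHWC input, mirroring src_offset in the diff.
std::int64_t nhwc_offset(const Conv2dShape& p, int n, int h_in, int w_in,
                         int g, int chn_start_idx) {
  assert(p.G > 0 && p.IC % p.G == 0);
  const int ic_per_group = p.IC / p.G;
  return ((static_cast<std::int64_t>(n) * p.in_h + h_in) * p.in_w + w_in) *
             p.IC +
         g * ic_per_group + chn_start_idx;
}

// Mirrors the `if` condition guarding the fast path in the diff.
bool can_copy_across_width(const Conv2dShape& p, int s, int w_in,
                           int chn_start_idx, int col_size) {
  assert(p.G > 0 && p.IC % p.G == 0);
  const int ic_per_group = p.IC / p.G;
  // Not at the last filter tap and not at the last input column.
  const bool more_pixels_follow = s < p.k_w - 1 && w_in < p.in_w - 1;
  // The requested block fits within the remaining channels of this pixel.
  const bool fits_in_channels = chn_start_idx + col_size <= ic_per_group;
  return p.G == 1 && p.dilation_w == 1 &&
         (more_pixels_follow || fits_in_channels);
}
```

Keeping the check as a pure function of the shape and loop indices makes it easy to probe edge cases such as the last filter tap or the last image column in isolation.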
