aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/reference_util.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-02-07 11:19:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-07 11:28:12 -0800
commitf89185b655d380074c3e1e932e90d80cd2b01241 (patch)
tree726cffe7d575904b3928e50fb2b6c2f01f9d2b88 /tensorflow/compiler/xla/reference_util.cc
parent7135d08d4e6067865d7b5f2907013c960a12ae4f (diff)
[XLA] Avoid half of the idivs in ReferenceUtil::ConvArray4DGeneralDimensionsDilated.
It's trivial to avoid half of the idivs in this function; they're just loop induction variables. Change: 146809277
Diffstat (limited to 'tensorflow/compiler/xla/reference_util.cc')
-rw-r--r--tensorflow/compiler/xla/reference_util.cc30
1 files changed, 19 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index bff59454e7..d8ed1ee0f7 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -370,17 +370,22 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
return lhs(index[0], index[1], index[2], index[3]);
};
- // Lambda to access the rhs operand at the given 4D index.
- const auto rhs_element = [&](int64 kernel_output_feature,
- int64 kernel_input_feature, int64 height,
- int64 width) {
- CHECK_EQ(fast_imod64(height, dky), 0);
- CHECK_EQ(fast_imod64(width, dkx), 0);
+ // Lambda to access the rhs operand at the given 4D index. height_over_dky
+ // should be equal to height / dky, and width_over_dkx should be equal to
+ // width / dkx. (This is an optimization to avoid doing divisions.)
+ const auto rhs_element = [&](
+ int64 kernel_output_feature, int64 kernel_input_feature, int64 height,
+ int64 width, int64 height_over_dky, int64 width_over_dkx) {
+ DCHECK_EQ(height % dky, 0);
+ DCHECK_EQ(width % dkx, 0);
+ DCHECK_EQ(height / dky, height_over_dky);
+ DCHECK_EQ(width / dkx, width_over_dkx);
+
std::array<int64, 4> index;
index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
- index[dnums.kernel_spatial_dimensions(0)] = fast_idiv64(height, dky);
- index[dnums.kernel_spatial_dimensions(1)] = fast_idiv64(width, dkx);
+ index[dnums.kernel_spatial_dimensions(0)] = height_over_dky;
+ index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx;
return rhs(index[0], index[1], index[2], index[3]);
};
@@ -400,14 +405,17 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
for (int64 sample = 0; sample < samples; ++sample) {
for (int64 izi = 0; izi < iz; ++izi) {
for (int64 ozi = 0; ozi < oz; ++ozi) {
- for (int64 kyi = 0; kyi < ky; kyi += dky) {
- for (int64 kxi = 0; kxi < kx; kxi += dkx) {
+ for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky;
+ kyi += dky, kyi_over_dky++) {
+ for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx;
+ kxi += dkx, kxi_over_dkx++) {
int64 iyi = istarty + ksy * oyi + kyi;
int64 ixi = istartx + ksx * oxi + kxi;
float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
? 0.0
: lhs_element(sample, izi, iyi, ixi);
- float gain = rhs_element(ozi, izi, kyi, kxi);
+ float gain =
+ rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx);
float addend = input * gain;
result_element(sample, ozi, oyi, oxi) += addend;
}