diff options
-rw-r--r-- | tensorflow/compiler/xla/reference_util.cc | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index bff59454e7..d8ed1ee0f7 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -370,17 +370,22 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( return lhs(index[0], index[1], index[2], index[3]); }; - // Lambda to access the rhs operand at the given 4D index. - const auto rhs_element = [&](int64 kernel_output_feature, - int64 kernel_input_feature, int64 height, - int64 width) { - CHECK_EQ(fast_imod64(height, dky), 0); - CHECK_EQ(fast_imod64(width, dkx), 0); + // Lambda to access the rhs operand at the given 4D index. height_over_dky + // should be equal to height / dky, and width_over_dkx should be equal to + // width / dkx. (This is an optimization to avoid doing divisions.) + const auto rhs_element = [&]( + int64 kernel_output_feature, int64 kernel_input_feature, int64 height, + int64 width, int64 height_over_dky, int64 width_over_dkx) { + DCHECK_EQ(height % dky, 0); + DCHECK_EQ(width % dkx, 0); + DCHECK_EQ(height / dky, height_over_dky); + DCHECK_EQ(width / dkx, width_over_dkx); + std::array<int64, 4> index; index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; - index[dnums.kernel_spatial_dimensions(0)] = fast_idiv64(height, dky); - index[dnums.kernel_spatial_dimensions(1)] = fast_idiv64(width, dkx); + index[dnums.kernel_spatial_dimensions(0)] = height_over_dky; + index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx; return rhs(index[0], index[1], index[2], index[3]); }; @@ -400,14 +405,17 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( for (int64 sample = 0; sample < samples; ++sample) { for (int64 izi = 0; izi < iz; ++izi) { for (int64 ozi = 0; ozi < oz; ++ozi) { - for (int64 kyi = 0; kyi < ky; kyi += dky) { - for (int64 kxi = 0; kxi < kx; kxi += dkx) { + for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky; + kyi += dky, kyi_over_dky++) { + for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx; + kxi += dkx, kxi_over_dkx++) { int64 iyi = istarty + ksy * oyi + kyi; int64 ixi = istartx + ksx * oxi + kxi; float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0) ? 0.0 : lhs_element(sample, izi, iyi, ixi); - float gain = rhs_element(ozi, izi, kyi, kxi); + float gain = + rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx); float addend = input * gain; result_element(sample, ozi, oyi, oxi) += addend; } |