aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2017-07-27 17:59:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-27 18:03:18 -0700
commit253bcbb71bdd1f9f2609b085dce90fe9b31cbd5a (patch)
tree691994a59c0ed9618f2764919ca8dbcc2dc3284f /tensorflow
parent569a00e681f9d73f820f96e88632dcc034f5a757 (diff)
[XLA] Use HloEvaluator for convolution in reference_util.
Also Speed up HloEvaluator's HandleConvolution in non-opt build, by moving calls to HloInstruction::shape() out of the inner loop. PiperOrigin-RevId: 163416183
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/BUILD3
-rw-r--r--tensorflow/compiler/xla/literal_util.h19
-rw-r--r--tensorflow/compiler/xla/reference_util.cc247
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc160
-rw-r--r--tensorflow/compiler/xla/shape_util.cc30
-rw-r--r--tensorflow/compiler/xla/shape_util.h31
6 files changed, 201 insertions, 289 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index e0a03a78f1..ba90b13b38 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -563,6 +563,9 @@ cc_library(
":xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_evaluator",
+ "//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/core:lib",
],
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 125c268573..e02a96ae70 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -470,10 +470,11 @@ class Literal {
// Populates literal values by calling the generator function for every cell
// in this literal object.
- template <typename NativeT>
- Status Populate(
- const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
- generator);
+ //
+ // generator must be a callable of the type
+ // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+ template <typename NativeT, typename FnType>
+ Status Populate(const FnType& generator);
// Creates a Literal of the given dimensions with all elements set to the
// given value.
@@ -1107,12 +1108,10 @@ void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
}
-template <typename NativeT>
-Status Literal::Populate(
- const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
- generator) {
+template <typename NativeT, typename FnType>
+Status Literal::Populate(const FnType& generator) {
const Shape& this_shape = shape();
- int64 rank = ShapeUtil::Rank(this_shape);
+ const int64 rank = ShapeUtil::Rank(this_shape);
TF_RET_CHECK(this_shape.element_type() ==
primitive_util::NativeToPrimitiveType<NativeT>());
tensorflow::gtl::MutableArraySlice<NativeT> data =
@@ -1125,7 +1124,7 @@ Status Literal::Populate(
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
auto init_function = [&](const std::vector<int64>& indexes) {
- int64 index = LinearIndex(indexes);
+ const int64 index = LinearIndex(indexes);
std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
for (int64 i = 0; i < minor_dimension_size; ++i) {
minor_scan_indexes[stride_config.minor_dimension] = i;
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 7ef5c6d916..64e197fac5 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -20,6 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -446,179 +449,85 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
std::pair<int64, int64> kernel_stride, Padding padding,
std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
ConvolutionDimensionNumbers dnums) {
- std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}};
- std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}};
-
- const int64 ksy = kernel_stride.first;
- const int64 ksx = kernel_stride.second;
- const int64 dy = lhs_dilation.first;
- const int64 dx = lhs_dilation.second;
- const int64 dky = rhs_dilation.first;
- const int64 dkx = rhs_dilation.second;
- CHECK_GE(dky, 1);
- CHECK_GE(dkx, 1);
- CHECK_GE(dy, 1);
- CHECK_GE(dx, 1);
-
- // Get all dimension sizes in lhs and rhs based on the given convolution
- // dimension configuration.
- const int64 ix = window_util::DilatedBound(
- lhs_dimensions[dnums.spatial_dimensions(1)], dx);
- const int64 iy = window_util::DilatedBound(
- lhs_dimensions[dnums.spatial_dimensions(0)], dy);
- const int64 iz = lhs_dimensions[dnums.feature_dimension()];
- const int64 samples = lhs_dimensions[dnums.batch_dimension()];
- const int64 kx = window_util::DilatedBound(
- rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx);
- const int64 ky = window_util::DilatedBound(
- rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky);
- const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()];
- {
- const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()];
- CHECK_EQ(kiz, iz);
+ HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
+ auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
+ auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);
+
+ std::array<int64, 2> ordered_kernel_strides;
+ std::array<int64, 2> ordered_input_dimensions;
+ std::array<int64, 2> ordered_kernel_dimensions;
+ if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
+ ordered_kernel_strides[0] = kernel_stride.second;
+ ordered_kernel_strides[1] = kernel_stride.first;
+ } else {
+ ordered_kernel_strides[0] = kernel_stride.first;
+ ordered_kernel_strides[1] = kernel_stride.second;
}
- if (padding == Padding::kSame) {
- // We reject same padding with kernel striding, since it's somewhat
- // nonsensical. We can always follow up to implement this with the desired
- // semantics if anybody actually uses it.
- CHECK_EQ(1, ksy);
- CHECK_EQ(1, ksx);
- }
-
- const int64 ox =
- padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx);
- const int64 oy =
- padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy);
- const int64 istartx =
- padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2;
- const int64 istarty =
- padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2;
- // Create the output result array and reset the values to 0.
- std::array<int64, 4> result_dimensions;
- result_dimensions[dnums.batch_dimension()] = samples;
- result_dimensions[dnums.feature_dimension()] = oz;
- result_dimensions[dnums.spatial_dimensions(0)] = oy;
- result_dimensions[dnums.spatial_dimensions(1)] = ox;
+ ordered_input_dimensions[0] =
+ lhs_literal->shape().dimensions(dnums.spatial_dimensions(0));
+ ordered_input_dimensions[1] =
+ lhs_literal->shape().dimensions(dnums.spatial_dimensions(1));
+ ordered_kernel_dimensions[0] =
+ rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+ ordered_kernel_dimensions[1] =
+ rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+
+ std::vector<std::pair<int64, int64>> paddings =
+ MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
+ ordered_kernel_strides, padding);
+ CHECK_EQ(paddings.size(), 2);
+
+ Window window;
+
+ WindowDimension dim;
+ dim.set_size(
+ rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+ dim.set_stride(kernel_stride.first);
+ dim.set_padding_low(paddings[0].first);
+ dim.set_padding_high(paddings[0].second);
+ dim.set_window_dilation(rhs_dilation.first);
+ dim.set_base_dilation(lhs_dilation.first);
+ *window.add_dimensions() = dim;
+
+ WindowDimension dim2;
+ dim2.set_size(
+ rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+ dim2.set_stride(kernel_stride.second);
+ dim2.set_padding_low(paddings[1].first);
+ dim2.set_padding_high(paddings[1].second);
+ dim2.set_window_dilation(rhs_dilation.second);
+ dim2.set_base_dilation(lhs_dilation.second);
+ *window.add_dimensions() = dim2;
+
+ const Shape& shape =
+ ShapeInference::InferConvolveShape(lhs_literal->shape(),
+ rhs_literal->shape(), window, dnums)
+ .ConsumeValueOrDie();
+
+ HloInstruction* lhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
+ HloInstruction* rhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
+
+ b.AddInstruction(HloInstruction::CreateConvolve(
+ shape, lhs_instruction, rhs_instruction, window, dnums));
+
+ HloEvaluator evaluator;
+ std::unique_ptr<Literal> result_literal =
+ evaluator.Evaluate(*b.Build(), {}).ConsumeValueOrDie();
+
+ CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
auto result =
- MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1],
- result_dimensions[2], result_dimensions[3]);
- result->Fill(0.0);
-
- const auto is_int32 = [](int64 x) {
- return x >= std::numeric_limits<int32>::min() &&
- x <= std::numeric_limits<int32>::max();
- };
-
- // 64-bit idiv/mod are much more expensive x86-64 than 32-bit idiv/imod (at
- // least on x86-64), so we avoid them where possible.
- const auto fast_idiv64 = [&](int64 a, int64 b) {
- if (is_int32(a) && is_int32(b)) {
- return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b));
- }
- return a / b;
- };
- const auto fast_imod64 = [&](int64 a, int64 b) {
- if (is_int32(a) && is_int32(b)) {
- return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b));
- }
- return a % b;
- };
-
- // Lambda to access the lhs operand at the given 4D index.
- const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
- int64 width) {
- if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) {
- return 0.0f;
- }
+ MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
+ result_literal->shape().dimensions(1),
+ result_literal->shape().dimensions(2),
+ result_literal->shape().dimensions(3));
+
+ result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+ *value = result_literal->Get<float>(indices);
+ });
- std::array<int64, 4> index;
- index[dnums.batch_dimension()] = batch;
- index[dnums.feature_dimension()] = feature;
- index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy);
- index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx);
- return lhs(index[0], index[1], index[2], index[3]);
- };
-
- // Lambda to access the rhs operand at the given 4D index. height_over_dky
- // should be equal to height / dky, and width_over_dkx should be equal to
- // width / dkx. (This is an optimization to avoid doing divisions.)
- const auto rhs_element =
- [&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height,
- int64 width, int64 height_over_dky, int64 width_over_dkx) {
- DCHECK_EQ(height % dky, 0);
- DCHECK_EQ(width % dkx, 0);
- DCHECK_EQ(height / dky, height_over_dky);
- DCHECK_EQ(width / dkx, width_over_dkx);
-
- std::array<int64, 4> index;
- index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
- index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
- index[dnums.kernel_spatial_dimensions(0)] = height_over_dky;
- index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx;
- return rhs(index[0], index[1], index[2], index[3]);
- };
-
- // Lambda to access the result data at the given 4D index.
- const auto result_element = [&](int64 batch, int64 kernel_output_feature,
- int64 height, int64 width) -> float& {
- std::array<int64, 4> index;
- index[dnums.batch_dimension()] = batch;
- index[dnums.feature_dimension()] = kernel_output_feature;
- index[dnums.spatial_dimensions(0)] = height;
- index[dnums.spatial_dimensions(1)] = width;
- return (*result)(index[0], index[1], index[2], index[3]);
- };
-
- for (int64 oyi = 0; oyi < oy; ++oyi) {
- for (int64 oxi = 0; oxi < ox; ++oxi) {
- for (int64 sample = 0; sample < samples; ++sample) {
- for (int64 izi = 0; izi < iz; ++izi) {
- for (int64 ozi = 0; ozi < oz; ++ozi) {
- for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky;
- kyi += dky, kyi_over_dky++) {
- for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx;
- kxi += dkx, kxi_over_dkx++) {
- int64 iyi = istarty + ksy * oyi + kyi;
- int64 ixi = istartx + ksx * oxi + kxi;
- float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
- ? 0.0
- : lhs_element(sample, izi, iyi, ixi);
- float gain =
- rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx);
- float addend = input * gain;
- result_element(sample, ozi, oyi, oxi) += addend;
- }
- }
- }
- }
- }
- }
- }
- if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 ||
- iz == 0) {
- LOG(INFO) << "Output will be trivially empty because one of these "
- "dimensions is 0: samples: "
- << samples << " kx: " << kx << " ky: " << ky << " ox: " << ox
- << " oy: " << oy << " oz: " << oz << " iz: " << iz;
- return result;
- }
- bool trivial = true;
- auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice<int64> indices,
- float value) {
- if (value != 0.0) {
- trivial = false;
- }
- };
- lhs.Each(check_trivial);
- if (trivial) {
- LOG(FATAL) << "LHS is all 0.0.";
- }
- trivial = true;
- rhs.Each(check_trivial);
- if (trivial) {
- LOG(FATAL) << "RHS is all 0.0.";
- }
return result;
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index e2a807595b..55f5504de4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -404,12 +404,16 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
HloInstruction* rhs, const Window& window) override {
- CHECK(ShapeUtil::IsArray(lhs->shape()));
- CHECK(ShapeUtil::IsArray(rhs->shape()));
- CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
- CHECK(ShapeUtil::SameElementType(lhs->shape(), conv->shape()));
- TF_CHECK_OK(ShapeUtil::ValidateShape(lhs->shape()));
- TF_CHECK_OK(ShapeUtil::ValidateShape(rhs->shape()));
+ const Shape& result_shape = conv->shape();
+ const Shape& lhs_shape = lhs->shape();
+ const Shape& rhs_shape = rhs->shape();
+
+ TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
+ TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
+ CHECK(ShapeUtil::IsArray(lhs_shape));
+ CHECK(ShapeUtil::IsArray(rhs_shape));
+ CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
+ CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));
const auto& dnums = conv->convolution_dimension_numbers();
const int64 num_spatial_dims = dnums.spatial_dimensions_size();
@@ -417,23 +421,23 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
CHECK_GE(num_spatial_dims, 1);
CHECK_EQ(window.dimensions_size(), num_spatial_dims);
- CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(lhs->shape()));
- CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(rhs->shape()));
+ const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
+ const auto rhs_rank = ShapeUtil::Rank(rhs_shape);
+
+ CHECK_EQ(num_spatial_dims + 2, lhs_rank);
+ CHECK_EQ(num_spatial_dims + 2, rhs_rank);
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferConvolveShape(
- lhs->shape(), rhs->shape(), window, dnums));
- CHECK(ShapeUtil::Compatible(conv->shape(), inferred_return_shape))
- << "return shape set to: " << ShapeUtil::HumanString(conv->shape())
+ ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
+ window, dnums));
+ CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+ << "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
- const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
-
// Dimension number applicable for both input (lhs), and output.
const int64 batch_dim = dnums.batch_dimension();
const int64 z_dim = dnums.feature_dimension();
@@ -441,78 +445,78 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
- const int64 z_size = ShapeUtil::GetDimension(lhs->shape(), z_dim);
+ const int64 z_size = ShapeUtil::GetDimension(lhs_shape, z_dim);
std::vector<int64> window_dimension_sizes;
for (auto i : dnums.kernel_spatial_dimensions()) {
- window_dimension_sizes.push_back(
- ShapeUtil::GetDimension(rhs->shape(), i));
+ window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
}
- const Shape& window_shape = ShapeUtil::MakeShape(
- rhs->shape().element_type(), window_dimension_sizes);
+ const Shape& window_shape =
+ ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
+
+ DimensionVector lhs_index(lhs_rank);
+ DimensionVector rhs_index(rhs_rank);
+ DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
+
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+ ReturnT result_val = static_cast<ReturnT>(0);
+
+ std::fill(lhs_index.begin(), lhs_index.end(), 0);
+ std::fill(rhs_index.begin(), rhs_index.end(), 0);
+ std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
+
+ lhs_index[batch_dim] = out_index[batch_dim];
+ rhs_index[kernel_output_z_dim] = out_index[z_dim];
+
+ // Convolve input feature with kernel.
+ do {
+ for (int64 iz = 0; iz < z_size; ++iz) {
+ lhs_index[z_dim] = iz;
+ rhs_index[kernel_input_z_dim] = iz;
+
+ // Find corresponding spatial dimension index for input (lhs).
+ for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
+ // Spatial dimension number for input (lhs) and output.
+ const int64 spatial_dim = dnums.spatial_dimensions(ki);
+
+ // Calculate lhs (input) index without taking base dilation into
+ // account.
+ const auto& window_dim = window.dimensions(ki);
+ const int64 undilated_index =
+ out_index[spatial_dim] * window_dim.stride() -
+ window_dim.padding_low() +
+ rhs_spatial_index[ki] * window_dim.window_dilation();
+ // Skip if the lhs (input) index is to be dilated.
+ if (undilated_index % window_dim.base_dilation() != 0) {
+ goto cnt;
+ }
- auto result = Literal::CreateFromShape(conv->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> out_index) {
- ReturnT result_val = static_cast<ReturnT>(0);
+ // Calculate the actual lhs (input) index after dilation.
+ lhs_index[spatial_dim] =
+ undilated_index / window_dim.base_dilation();
- std::vector<int64> lhs_index(lhs_rank, 0);
- std::vector<int64> rhs_index(rhs_rank, 0);
-
- lhs_index[batch_dim] = out_index[batch_dim];
- rhs_index[kernel_output_z_dim] = out_index[z_dim];
-
- std::vector<int64> rhs_spatial_index(
- dnums.kernel_spatial_dimensions_size(), 0);
-
- // Convolve input feature with kernel.
- do {
- for (int64 iz = 0; iz < z_size; ++iz) {
- lhs_index[z_dim] = iz;
- rhs_index[kernel_input_z_dim] = iz;
-
- // Find corresponding spatial dimension index for input (lhs).
- for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
- // Spatial dimension number for input (lhs) and output.
- const int64 spatial_dim = dnums.spatial_dimensions(ki);
-
- // Calculate lhs (input) index without taking base dilation into
- // account.
- const int64 undilated_index =
- out_index[spatial_dim] * window.dimensions(ki).stride() -
- window.dimensions(ki).padding_low() +
- rhs_spatial_index[ki] *
- window.dimensions(ki).window_dilation();
- // Skip if the lhs (input) index is to be dilated.
- if (undilated_index % window.dimensions(ki).base_dilation() !=
- 0) {
- goto cnt;
- }
-
- // Calculate the actual lhs (input) index after dilation.
- lhs_index[spatial_dim] =
- undilated_index / window.dimensions(ki).base_dilation();
-
- // Skip if input index is not in bound.
- if (!(lhs_index[spatial_dim] >= 0 &&
- lhs_index[spatial_dim] <
- lhs->shape().dimensions(spatial_dim))) {
- goto cnt;
- }
-
- rhs_index[dnums.kernel_spatial_dimensions(ki)] =
- rhs_spatial_index[ki];
- }
-
- result_val += lhs_literal.Get<ReturnT>(lhs_index) *
- rhs_literal.Get<ReturnT>(rhs_index);
+ // Skip if input index is not in bound.
+ if (!(lhs_index[spatial_dim] >= 0 &&
+ lhs_index[spatial_dim] < lhs_shape.dimensions(spatial_dim))) {
+ goto cnt;
}
- cnt:;
- } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
- return result_val;
- }));
+ rhs_index[dnums.kernel_spatial_dimensions(ki)] =
+ rhs_spatial_index[ki];
+ }
+
+ result_val += lhs_literal.Get<ReturnT>(lhs_index) *
+ rhs_literal.Get<ReturnT>(rhs_index);
+ }
+ cnt:;
+ } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
+
+ return result_val;
+ };
+
+ auto result = Literal::CreateFromShape(result_shape);
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 4c079c87d4..745491e485 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1218,34 +1218,4 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
return shape;
}
-/* static */ void ShapeUtil::ForEachIndex(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
- const IndexVisitorFunction& visitor_function) {
- if (ShapeUtil::HasZeroElements(shape)) {
- return;
- }
- DCHECK_EQ(Rank(shape), base.size());
- DCHECK_EQ(incr.size(), base.size());
- DCHECK_EQ(count.size(), base.size());
- const Layout& layout = shape.layout();
- int64 rank = layout.minor_to_major_size();
- // Allows handling R0 arrays, such that the visitor function will be called
- // once with the proper empty indexes.
- int64 n = -1;
- std::vector<int64> indexes(base.begin(), base.end());
- while (n < rank && visitor_function(indexes)) {
- // Increments dimensions in minor to major order.
- for (n = 0; n < rank; ++n) {
- int64 dim = layout.minor_to_major(n);
- indexes[dim] += incr[dim];
- if (indexes[dim] < base[dim] + count[dim]) {
- break;
- }
- indexes[dim] = base[dim];
- }
- }
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index fd4adbf34c..b79b88581f 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -421,12 +421,39 @@ class ShapeUtil {
// current index.
// The visitor_function visitor function should return true if it wants to
// continue, or false otherwise.
- using IndexVisitorFunction = std::function<bool(const std::vector<int64>&)>;
+ //
+ // visitor_function must be a callable of type bool(const std::vector<int64>&)
+ // or compatible.
+ template <typename FnType>
static void ForEachIndex(const Shape& shape,
tensorflow::gtl::ArraySlice<int64> base,
tensorflow::gtl::ArraySlice<int64> count,
tensorflow::gtl::ArraySlice<int64> incr,
- const IndexVisitorFunction& visitor_function);
+ const FnType& visitor_function) {
+ if (ShapeUtil::HasZeroElements(shape)) {
+ return;
+ }
+ CHECK_EQ(Rank(shape), base.size());
+ CHECK_EQ(incr.size(), base.size());
+ CHECK_EQ(count.size(), base.size());
+ const Layout& layout = shape.layout();
+ const int64 rank = layout.minor_to_major_size();
+ // Allows handling R0 arrays, such that the visitor function will be called
+ // once with the proper empty indexes.
+ int64 n = -1;
+ std::vector<int64> indexes(base.begin(), base.end());
+ while (n < rank && visitor_function(indexes)) {
+ // Increments dimensions in minor to major order.
+ for (n = 0; n < rank; ++n) {
+ int64 dim = layout.minor_to_major(n);
+ indexes[dim] += incr[dim];
+ if (indexes[dim] < base[dim] + count[dim]) {
+ break;
+ }
+ indexes[dim] = base[dim];
+ }
+ }
+ }
private:
// Validates all of the non-layout properties of the shape -- this is a helper