From 253bcbb71bdd1f9f2609b085dce90fe9b31cbd5a Mon Sep 17 00:00:00 2001
From: Kay Zhu
Date: Thu, 27 Jul 2017 17:59:23 -0700
Subject: [XLA] Use HloEvaluator for convolution in reference_util.

Also speed up HloEvaluator's HandleConvolution in non-opt build, by moving
calls to HloInstruction::shape() out of the inner loop.

PiperOrigin-RevId: 163416183
---
 tensorflow/compiler/xla/BUILD                    |   3 +
 tensorflow/compiler/xla/literal_util.h           |  19 +-
 tensorflow/compiler/xla/reference_util.cc        | 247 +++++++----------
 tensorflow/compiler/xla/service/hlo_evaluator.cc | 160 ++++++++-------
 tensorflow/compiler/xla/shape_util.cc            |  30 ---
 tensorflow/compiler/xla/shape_util.h             |  31 ++-
 6 files changed, 201 insertions(+), 289 deletions(-)

diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index e0a03a78f1..ba90b13b38 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -563,6 +563,9 @@ cc_library(
         ":xla_data_proto",
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:padding",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_evaluator",
+        "//tensorflow/compiler/xla/service:shape_inference",
         "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
         "//tensorflow/core:lib",
     ],
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 125c268573..e02a96ae70 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -470,10 +470,11 @@ class Literal {
   // Populates literal values by calling the generator function for every cell
   // in this literal object.
-  template <typename NativeT>
-  Status Populate(
-      const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
-          generator);
+  //
+  // generator must be a callable of the type
+  // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+  template <typename NativeT, typename FnType>
+  Status Populate(const FnType& generator);
 
   // Creates a Literal of the given dimensions with all elements set to the
   // given value.
@@ -1107,12 +1108,10 @@ void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
   PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
 }
 
-template <typename NativeT>
-Status Literal::Populate(
-    const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
-        generator) {
+template <typename NativeT, typename FnType>
+Status Literal::Populate(const FnType& generator) {
   const Shape& this_shape = shape();
-  int64 rank = ShapeUtil::Rank(this_shape);
+  const int64 rank = ShapeUtil::Rank(this_shape);
   TF_RET_CHECK(this_shape.element_type() ==
                primitive_util::NativeToPrimitiveType<NativeT>());
   tensorflow::gtl::MutableArraySlice<NativeT> data =
@@ -1125,7 +1124,7 @@ Status Literal::Populate(
       ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
 
   auto init_function = [&](const std::vector<int64>& indexes) {
-    int64 index = LinearIndex(indexes);
+    const int64 index = LinearIndex(indexes);
     std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
     for (int64 i = 0; i < minor_dimension_size; ++i) {
       minor_scan_indexes[stride_config.minor_dimension] = i;
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 7ef5c6d916..64e197fac5 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -20,6 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/math/math_util.h" @@ -446,179 +449,85 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( std::pair kernel_stride, Padding padding, std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums) { - std::array lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}}; - std::array rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}}; - - const int64 ksy = kernel_stride.first; - const int64 ksx = kernel_stride.second; - const int64 dy = lhs_dilation.first; - const int64 dx = lhs_dilation.second; - const int64 dky = rhs_dilation.first; - const int64 dkx = rhs_dilation.second; - CHECK_GE(dky, 1); - CHECK_GE(dkx, 1); - CHECK_GE(dy, 1); - CHECK_GE(dx, 1); - - // Get all dimension sizes in lhs and rhs based on the given convolution - // dimension configuration. - const int64 ix = window_util::DilatedBound( - lhs_dimensions[dnums.spatial_dimensions(1)], dx); - const int64 iy = window_util::DilatedBound( - lhs_dimensions[dnums.spatial_dimensions(0)], dy); - const int64 iz = lhs_dimensions[dnums.feature_dimension()]; - const int64 samples = lhs_dimensions[dnums.batch_dimension()]; - const int64 kx = window_util::DilatedBound( - rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx); - const int64 ky = window_util::DilatedBound( - rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky); - const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()]; - { - const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()]; - CHECK_EQ(kiz, iz); + HloComputation::Builder b("ConvArray4DGeneralDimensionDilated"); + auto lhs_literal = Literal::CreateR4FromArray4D(lhs); + auto rhs_literal = Literal::CreateR4FromArray4D(rhs); + + std::array ordered_kernel_strides; + std::array ordered_input_dimensions; + std::array ordered_kernel_dimensions; + if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) { + ordered_kernel_strides[0] = kernel_stride.second; + ordered_kernel_strides[1] = kernel_stride.first; + } else { + ordered_kernel_strides[0] = kernel_stride.first; + ordered_kernel_strides[1] = kernel_stride.second; } - if (padding == Padding::kSame) { - // We reject same padding with kernel striding, since it's somewhat - // nonsensical. We can always follow up to implement this with the desired - // semantics if anybody actually uses it. - CHECK_EQ(1, ksy); - CHECK_EQ(1, ksx); - } - - const int64 ox = - padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx); - const int64 oy = - padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy); - const int64 istartx = - padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2; - const int64 istarty = - padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2; - // Create the output result array and reset the values to 0. 
-  std::array<int64, 4> result_dimensions;
-  result_dimensions[dnums.batch_dimension()] = samples;
-  result_dimensions[dnums.feature_dimension()] = oz;
-  result_dimensions[dnums.spatial_dimensions(0)] = oy;
-  result_dimensions[dnums.spatial_dimensions(1)] = ox;
+  ordered_input_dimensions[0] =
+      lhs_literal->shape().dimensions(dnums.spatial_dimensions(0));
+  ordered_input_dimensions[1] =
+      lhs_literal->shape().dimensions(dnums.spatial_dimensions(1));
+  ordered_kernel_dimensions[0] =
+      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+  ordered_kernel_dimensions[1] =
+      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+
+  std::vector<std::pair<int64, int64>> paddings =
+      MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
+                  ordered_kernel_strides, padding);
+  CHECK_EQ(paddings.size(), 2);
+
+  Window window;
+
+  WindowDimension dim;
+  dim.set_size(
+      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+  dim.set_stride(kernel_stride.first);
+  dim.set_padding_low(paddings[0].first);
+  dim.set_padding_high(paddings[0].second);
+  dim.set_window_dilation(rhs_dilation.first);
+  dim.set_base_dilation(lhs_dilation.first);
+  *window.add_dimensions() = dim;
+
+  WindowDimension dim2;
+  dim2.set_size(
+      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+  dim2.set_stride(kernel_stride.second);
+  dim2.set_padding_low(paddings[1].first);
+  dim2.set_padding_high(paddings[1].second);
+  dim2.set_window_dilation(rhs_dilation.second);
+  dim2.set_base_dilation(lhs_dilation.second);
+  *window.add_dimensions() = dim2;
+
+  const Shape& shape =
+      ShapeInference::InferConvolveShape(lhs_literal->shape(),
+                                         rhs_literal->shape(), window, dnums)
+          .ConsumeValueOrDie();
+
+  HloInstruction* lhs_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
+  HloInstruction* rhs_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
+
+  b.AddInstruction(HloInstruction::CreateConvolve(
+      shape, lhs_instruction, rhs_instruction, window, dnums));
+
+  HloEvaluator evaluator;
+  std::unique_ptr<Literal> result_literal =
+      evaluator.Evaluate(*b.Build(), {}).ConsumeValueOrDie();
+
+  CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
   auto result =
-      MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1],
-                                 result_dimensions[2], result_dimensions[3]);
-  result->Fill(0.0);
-
-  const auto is_int32 = [](int64 x) {
-    return x >= std::numeric_limits<int32>::min() &&
-           x <= std::numeric_limits<int32>::max();
-  };
-
-  // 64-bit idiv/mod are much more expensive x86-64 than 32-bit idiv/imod (at
-  // least on x86-64), so we avoid them where possible.
-  const auto fast_idiv64 = [&](int64 a, int64 b) {
-    if (is_int32(a) && is_int32(b)) {
-      return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b));
-    }
-    return a / b;
-  };
-  const auto fast_imod64 = [&](int64 a, int64 b) {
-    if (is_int32(a) && is_int32(b)) {
-      return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b));
-    }
-    return a % b;
-  };
-
-  // Lambda to access the lhs operand at the given 4D index.
-  const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
-                               int64 width) {
-    if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) {
-      return 0.0f;
-    }
+      MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
+                                 result_literal->shape().dimensions(1),
+                                 result_literal->shape().dimensions(2),
+                                 result_literal->shape().dimensions(3));
+
+  result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+    *value = result_literal->Get<float>(indices);
+  });
 
-    std::array<int64, 4> index;
-    index[dnums.batch_dimension()] = batch;
-    index[dnums.feature_dimension()] = feature;
-    index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy);
-    index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx);
-    return lhs(index[0], index[1], index[2], index[3]);
-  };
-
-  // Lambda to access the rhs operand at the given 4D index. height_over_dky
-  // should be equal to height / dky, and width_over_dkx should be equal to
-  // width / dkx. (This is an optimization to avoid doing divisions.)
-  const auto rhs_element =
-      [&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height,
-          int64 width, int64 height_over_dky, int64 width_over_dkx) {
-        DCHECK_EQ(height % dky, 0);
-        DCHECK_EQ(width % dkx, 0);
-        DCHECK_EQ(height / dky, height_over_dky);
-        DCHECK_EQ(width / dkx, width_over_dkx);
-
-        std::array<int64, 4> index;
-        index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
-        index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
-        index[dnums.kernel_spatial_dimensions(0)] = height_over_dky;
-        index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx;
-        return rhs(index[0], index[1], index[2], index[3]);
-      };
-
-  // Lambda to access the result data at the given 4D index.
-  const auto result_element = [&](int64 batch, int64 kernel_output_feature,
-                                  int64 height, int64 width) -> float& {
-    std::array<int64, 4> index;
-    index[dnums.batch_dimension()] = batch;
-    index[dnums.feature_dimension()] = kernel_output_feature;
-    index[dnums.spatial_dimensions(0)] = height;
-    index[dnums.spatial_dimensions(1)] = width;
-    return (*result)(index[0], index[1], index[2], index[3]);
-  };
-
-  for (int64 oyi = 0; oyi < oy; ++oyi) {
-    for (int64 oxi = 0; oxi < ox; ++oxi) {
-      for (int64 sample = 0; sample < samples; ++sample) {
-        for (int64 izi = 0; izi < iz; ++izi) {
-          for (int64 ozi = 0; ozi < oz; ++ozi) {
-            for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky;
-                 kyi += dky, kyi_over_dky++) {
-              for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx;
-                   kxi += dkx, kxi_over_dkx++) {
-                int64 iyi = istarty + ksy * oyi + kyi;
-                int64 ixi = istartx + ksx * oxi + kxi;
-                float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
-                                  ? 0.0
-                                  : lhs_element(sample, izi, iyi, ixi);
-                float gain =
-                    rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx);
-                float addend = input * gain;
-                result_element(sample, ozi, oyi, oxi) += addend;
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-  if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 ||
-      iz == 0) {
-    LOG(INFO) << "Output will be trivially empty because one of these "
-                 "dimensions is 0: samples: "
-              << samples << " kx: " << kx << " ky: " << ky << " ox: " << ox
-              << " oy: " << oy << " oz: " << oz << " iz: " << iz;
-    return result;
-  }
-  bool trivial = true;
-  auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice<int64> indices,
-                                  float value) {
-    if (value != 0.0) {
-      trivial = false;
-    }
-  };
-  lhs.Each(check_trivial);
-  if (trivial) {
-    LOG(FATAL) << "LHS is all 0.0.";
-  }
-  trivial = true;
-  rhs.Each(check_trivial);
-  if (trivial) {
-    LOG(FATAL) << "RHS is all 0.0.";
-  }
   return result;
 }
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index e2a807595b..55f5504de4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -404,12 +404,16 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
   Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
                            HloInstruction* rhs, const Window& window) override {
-    CHECK(ShapeUtil::IsArray(lhs->shape()));
-    CHECK(ShapeUtil::IsArray(rhs->shape()));
-    CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
-    CHECK(ShapeUtil::SameElementType(lhs->shape(), conv->shape()));
-    TF_CHECK_OK(ShapeUtil::ValidateShape(lhs->shape()));
-    TF_CHECK_OK(ShapeUtil::ValidateShape(rhs->shape()));
+    const Shape& result_shape = conv->shape();
+    const Shape& lhs_shape = lhs->shape();
+    const Shape& rhs_shape = rhs->shape();
+
+    TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
+    TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
+    CHECK(ShapeUtil::IsArray(lhs_shape));
+    CHECK(ShapeUtil::IsArray(rhs_shape));
+    CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
+    CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));
 
     const auto& dnums = conv->convolution_dimension_numbers();
     const int64 num_spatial_dims = dnums.spatial_dimensions_size();
@@ -417,23 +421,23 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     CHECK_GE(num_spatial_dims, 1);
     CHECK_EQ(window.dimensions_size(), num_spatial_dims);
 
-    CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(lhs->shape()));
-    CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(rhs->shape()));
+    const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
+    const auto rhs_rank = ShapeUtil::Rank(rhs_shape);
+
+    CHECK_EQ(num_spatial_dims + 2, lhs_rank);
+    CHECK_EQ(num_spatial_dims + 2, rhs_rank);
 
     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
-                        ShapeInference::InferConvolveShape(
-                            lhs->shape(), rhs->shape(), window, dnums));
-    CHECK(ShapeUtil::Compatible(conv->shape(), inferred_return_shape))
-        << "return shape set to: " << ShapeUtil::HumanString(conv->shape())
+                        ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
+                                                           window, dnums));
+    CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+        << "return shape set to: " << ShapeUtil::HumanString(result_shape)
         << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);
 
     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
 
-    const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
-    const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
-
     // Dimension number applicable for both input (lhs), and output.
     const int64 batch_dim = dnums.batch_dimension();
     const int64 z_dim = dnums.feature_dimension();
@@ -441,78 +445,78 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
     const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
 
-    const int64 z_size = ShapeUtil::GetDimension(lhs->shape(), z_dim);
+    const int64 z_size = ShapeUtil::GetDimension(lhs_shape, z_dim);
 
     std::vector<int64> window_dimension_sizes;
     for (auto i : dnums.kernel_spatial_dimensions()) {
-      window_dimension_sizes.push_back(
-          ShapeUtil::GetDimension(rhs->shape(), i));
+      window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
     }
 
-    const Shape& window_shape = ShapeUtil::MakeShape(
-        rhs->shape().element_type(), window_dimension_sizes);
+    const Shape& window_shape =
+        ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
+
+    DimensionVector lhs_index(lhs_rank);
+    DimensionVector rhs_index(rhs_rank);
+    DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
+
+    auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+      ReturnT result_val = static_cast<ReturnT>(0);
+
+      std::fill(lhs_index.begin(), lhs_index.end(), 0);
+      std::fill(rhs_index.begin(), rhs_index.end(), 0);
+      std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
+
+      lhs_index[batch_dim] = out_index[batch_dim];
+      rhs_index[kernel_output_z_dim] = out_index[z_dim];
+
+      // Convolve input feature with kernel.
+      do {
+        for (int64 iz = 0; iz < z_size; ++iz) {
+          lhs_index[z_dim] = iz;
+          rhs_index[kernel_input_z_dim] = iz;
+
+          // Find corresponding spatial dimension index for input (lhs).
+          for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
+            // Spatial dimension number for input (lhs) and output.
+            const int64 spatial_dim = dnums.spatial_dimensions(ki);
+
+            // Calculate lhs (input) index without taking base dilation into
+            // account.
+            const auto& window_dim = window.dimensions(ki);
+            const int64 undilated_index =
+                out_index[spatial_dim] * window_dim.stride() -
+                window_dim.padding_low() +
+                rhs_spatial_index[ki] * window_dim.window_dilation();
+            // Skip if the lhs (input) index is to be dilated.
+            if (undilated_index % window_dim.base_dilation() != 0) {
+              goto cnt;
+            }
 
-    auto result = Literal::CreateFromShape(conv->shape());
-    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
-        [&](tensorflow::gtl::ArraySlice<int64> out_index) {
-          ReturnT result_val = static_cast<ReturnT>(0);
+            // Calculate the actual lhs (input) index after dilation.
+            lhs_index[spatial_dim] =
+                undilated_index / window_dim.base_dilation();
 
-          std::vector<int64> lhs_index(lhs_rank, 0);
-          std::vector<int64> rhs_index(rhs_rank, 0);
-
-          lhs_index[batch_dim] = out_index[batch_dim];
-          rhs_index[kernel_output_z_dim] = out_index[z_dim];
-
-          std::vector<int64> rhs_spatial_index(
-              dnums.kernel_spatial_dimensions_size(), 0);
-
-          // Convolve input feature with kernel.
-          do {
-            for (int64 iz = 0; iz < z_size; ++iz) {
-              lhs_index[z_dim] = iz;
-              rhs_index[kernel_input_z_dim] = iz;
-
-              // Find corresponding spatial dimension index for input (lhs).
-              for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
-                // Spatial dimension number for input (lhs) and output.
-                const int64 spatial_dim = dnums.spatial_dimensions(ki);
-
-                // Calculate lhs (input) index without taking base dilation into
-                // account.
-                const int64 undilated_index =
-                    out_index[spatial_dim] * window.dimensions(ki).stride() -
-                    window.dimensions(ki).padding_low() +
-                    rhs_spatial_index[ki] *
-                        window.dimensions(ki).window_dilation();
-                // Skip if the lhs (input) index is to be dilated.
-                if (undilated_index % window.dimensions(ki).base_dilation() !=
-                    0) {
-                  goto cnt;
-                }
-
-                // Calculate the actual lhs (input) index after dilation.
-                lhs_index[spatial_dim] =
-                    undilated_index / window.dimensions(ki).base_dilation();
-
-                // Skip if input index is not in bound.
-                if (!(lhs_index[spatial_dim] >= 0 &&
-                      lhs_index[spatial_dim] <
-                          lhs->shape().dimensions(spatial_dim))) {
-                  goto cnt;
-                }
-
-                rhs_index[dnums.kernel_spatial_dimensions(ki)] =
-                    rhs_spatial_index[ki];
-              }
-
-              result_val += lhs_literal.Get<ReturnT>(lhs_index) *
-                            rhs_literal.Get<ReturnT>(rhs_index);
+            // Skip if input index is not in bound.
+            if (!(lhs_index[spatial_dim] >= 0 &&
+                  lhs_index[spatial_dim] < lhs_shape.dimensions(spatial_dim))) {
+              goto cnt;
             }
-          cnt:;
-          } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
-
-          return result_val;
-        }));
+
+            rhs_index[dnums.kernel_spatial_dimensions(ki)] =
+                rhs_spatial_index[ki];
+          }
+
+          result_val += lhs_literal.Get<ReturnT>(lhs_index) *
+                        rhs_literal.Get<ReturnT>(rhs_index);
+        }
+      cnt:;
+      } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
+
+      return result_val;
+    };
+
+    auto result = Literal::CreateFromShape(result_shape);
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
 
     parent_->evaluated_[conv] = std::move(result);
     return Status::OK();
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 4c079c87d4..745491e485 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1218,34 +1218,4 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
   return shape;
 }
 
-/* static */ void ShapeUtil::ForEachIndex(
-    const Shape& shape, tensorflow::gtl::ArraySlice<int64> base,
-    tensorflow::gtl::ArraySlice<int64> count,
-    tensorflow::gtl::ArraySlice<int64> incr,
-    const IndexVisitorFunction& visitor_function) {
-  if (ShapeUtil::HasZeroElements(shape)) {
-    return;
-  }
-  DCHECK_EQ(Rank(shape), base.size());
-  DCHECK_EQ(incr.size(), base.size());
-  DCHECK_EQ(count.size(), base.size());
-  const Layout& layout = shape.layout();
-  int64 rank = layout.minor_to_major_size();
-  // Allows handling R0 arrays, such that the visitor function will be called
-  // once with the proper empty indexes.
-  int64 n = -1;
-  std::vector<int64> indexes(base.begin(), base.end());
-  while (n < rank && visitor_function(indexes)) {
-    // Increments dimensions in minor to major order.
-    for (n = 0; n < rank; ++n) {
-      int64 dim = layout.minor_to_major(n);
-      indexes[dim] += incr[dim];
-      if (indexes[dim] < base[dim] + count[dim]) {
-        break;
-      }
-      indexes[dim] = base[dim];
-    }
-  }
-}
-
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index fd4adbf34c..b79b88581f 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -421,12 +421,39 @@ class ShapeUtil {
   // current index.
   // The visitor_function visitor function should return true if it wants to
   // continue, or false otherwise.
-  using IndexVisitorFunction = std::function<bool(const std::vector<int64>&)>;
+  //
+  // visitor_function must be a callable of type
+  // bool(const std::vector<int64>&) or compatible.
+  template <typename FnType>
   static void ForEachIndex(const Shape& shape,
                            tensorflow::gtl::ArraySlice<int64> base,
                            tensorflow::gtl::ArraySlice<int64> count,
                            tensorflow::gtl::ArraySlice<int64> incr,
-                           const IndexVisitorFunction& visitor_function);
+                           const FnType& visitor_function) {
+    if (ShapeUtil::HasZeroElements(shape)) {
+      return;
+    }
+    CHECK_EQ(Rank(shape), base.size());
+    CHECK_EQ(incr.size(), base.size());
+    CHECK_EQ(count.size(), base.size());
+    const Layout& layout = shape.layout();
+    const int64 rank = layout.minor_to_major_size();
+    // Allows handling R0 arrays, such that the visitor function will be called
+    // once with the proper empty indexes.
+    int64 n = -1;
+    std::vector<int64> indexes(base.begin(), base.end());
+    while (n < rank && visitor_function(indexes)) {
+      // Increments dimensions in minor to major order.
+      for (n = 0; n < rank; ++n) {
+        int64 dim = layout.minor_to_major(n);
+        indexes[dim] += incr[dim];
+        if (indexes[dim] < base[dim] + count[dim]) {
+          break;
+        }
+        indexes[dim] = base[dim];
+      }
+    }
+  }
 
  private:
   // Validates all of the non-layout properties of the shape -- this is a helper
-- 
cgit v1.2.3
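
Note on the hlo_evaluator.cc change: it is essentially the hoisting pattern named in the
commit message. In a non-opt (debug) build, accessors such as HloInstruction::shape() and
window.dimensions(ki) are real function calls on every iteration, so the patch binds them
to locals (lhs_shape, rhs_shape, window_dim) outside the hot loop. A minimal standalone
sketch of that pattern follows; Instruction, SumSlow, and SumFast are hypothetical
illustration names, not XLA code.

#include <cstddef>
#include <cstdio>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an object whose accessor is not inlined in a
// non-opt build (analogous to HloInstruction::shape()).
struct Instruction {
  std::vector<int64_t> dims;
  const std::vector<int64_t>& shape() const { return dims; }
};

// Before: pays an accessor call (twice) per inner-loop iteration.
int64_t SumSlow(const Instruction& inst) {
  int64_t total = 0;
  for (std::size_t i = 0; i < inst.shape().size(); ++i) {
    total += inst.shape()[i];
  }
  return total;
}

// After: the accessor result is bound to a local reference once, the way the
// patch introduces lhs_shape/rhs_shape and `const auto& window_dim`.
int64_t SumFast(const Instruction& inst) {
  const std::vector<int64_t>& shape = inst.shape();  // hoisted out of the loop
  int64_t total = 0;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    total += shape[i];
  }
  return total;
}

int main() {
  Instruction inst{{2, 3, 4, 5}};
  std::printf("%lld %lld\n", static_cast<long long>(SumSlow(inst)),
              static_cast<long long>(SumFast(inst)));
  return 0;
}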
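
The other half of the change, templating Literal::Populate and ShapeUtil::ForEachIndex on
the functor type instead of accepting a std::function, removes type-erased dispatch on
every cell visited, which a non-opt build cannot optimize away. A sketch of the difference
(ApplyErased and ApplyTemplated are hypothetical names, not the patch's code):

#include <cstdio>
#include <functional>

// Type-erased callback, as in the old Populate(const std::function<...>&) and
// ForEachIndex(..., const IndexVisitorFunction&): every call goes through
// std::function's indirection.
long ApplyErased(const std::function<long(long)>& f, long n) {
  long total = 0;
  for (long i = 0; i < n; ++i) total += f(i);
  return total;
}

// Templated on the functor type, as in the new
// template <typename NativeT, typename FnType> Status Populate(const FnType&):
// the concrete callable type is visible to the compiler, so calls can inline.
template <typename FnType>
long ApplyTemplated(const FnType& f, long n) {
  long total = 0;
  for (long i = 0; i < n; ++i) total += f(i);
  return total;
}

int main() {
  auto square = [](long x) { return x * x; };
  std::printf("%ld %ld\n", ApplyErased(square, 100), ApplyTemplated(square, 100));
  return 0;
}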