-rw-r--r--  tensorflow/compiler/xla/BUILD                                    2
-rw-r--r--  tensorflow/compiler/xla/index_util.cc                           19
-rw-r--r--  tensorflow/compiler/xla/index_util.h                             8
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc                        111
-rw-r--r--  tensorflow/compiler/xla/literal_util.h                          49
-rw-r--r--  tensorflow/compiler/xla/literal_util_test.cc                    62
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc         72
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc    41
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc                           25
-rw-r--r--  tensorflow/compiler/xla/shape_util.h                            13
-rw-r--r--  tensorflow/compiler/xla/util.h                                  12
11 files changed, 413 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 7576dff0cd..65d4528421 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -256,6 +256,7 @@ cc_library(
":array3d",
":array4d",
":shape_util",
+ ":status_macros",
":types",
":util",
":xla_data_proto",
@@ -274,6 +275,7 @@ cc_test(
":test",
":types",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc
index 92aca3cae9..76c0168f37 100644
--- a/tensorflow/compiler/xla/index_util.cc
+++ b/tensorflow/compiler/xla/index_util.cc
@@ -131,4 +131,23 @@ namespace xla {
return false;
}
+/* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape,
+ int64 dimension) {
+ const Layout& layout = shape.layout();
+ int64 pdim_size = layout.padded_dimensions_size();
+ int64 stride = 1;
+ DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size());
+ for (auto dim : layout.minor_to_major()) {
+ if (dim == dimension) {
+ break;
+ }
+ if (pdim_size == 0) {
+ stride *= shape.dimensions(dim);
+ } else {
+ stride *= layout.padded_dimensions(dim);
+ }
+ }
+ return stride;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h
index e6a26d6220..c9838966a5 100644
--- a/tensorflow/compiler/xla/index_util.h
+++ b/tensorflow/compiler/xla/index_util.h
@@ -61,6 +61,14 @@ class IndexUtil {
static bool BumpIndices(const Shape& shape,
tensorflow::gtl::MutableArraySlice<int64> indices);
+ // Calculates the stride size (in number of elements, not byte size) of a
+ // given logical shape dimension (from 0 to rank-1). If available, padded
+ // dimensions are used.
+ // Example:
+ // GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) ==
+ //   size of dimension 3 * size of dimension 2 == 4 * 10
+ static int64 GetDimensionStride(const Shape& shape, int64 dimension);
+
private:
TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil);
};
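
As an illustration of the new helper (a minimal sketch, not part of the diff; it assumes a standard XLA build with ShapeUtil and IndexUtil available), the stride values for the shape used in the comment above would be:

    Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 8, 10, 4}, {3, 2, 1, 0});
    // minor_to_major is {3, 2, 1, 0}, so dimension 3 is the minor-most one.
    CHECK_EQ(IndexUtil::GetDimensionStride(shape, 3), 1);    // minor-most dimension
    CHECK_EQ(IndexUtil::GetDimensionStride(shape, 1), 40);   // 4 * 10
    CHECK_EQ(IndexUtil::GetDimensionStride(shape, 0), 320);  // 4 * 10 * 8
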
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 0286b0817c..03c9e2c9d7 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -16,12 +16,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include <algorithm>
+#include <functional>
#include <limits>
#include <numeric>
#include <vector>
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -33,6 +35,115 @@ limitations under the License.
namespace xla {
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromShape(
+ const Shape& shape) {
+ auto literal = MakeUnique<Literal>();
+ *literal->mutable_shape() = shape;
+ Reserve(ShapeUtil::ElementsIn(literal->shape()), literal.get());
+ return literal;
+}
+
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
+ PrimitiveType primitive_type,
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
+ return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
+}
+
+template <typename T, typename WT>
+/* static */ Status LiteralUtil::CopyRange(
+ const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+ Literal* dest_literal, tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size) {
+ const Shape& src_shape = src_literal.shape();
+ const Shape& dest_shape = dest_literal->shape();
+ tensorflow::gtl::ArraySlice<T> src_data = GetArraySlice<T>(src_literal);
+ tensorflow::protobuf::RepeatedField<WT>* dest_data =
+ GetMutableRepeatedField<WT>(dest_literal);
+
+ TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size());
+ TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size());
+ if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) {
+ // If either of the two shapes is a scalar, we can call StridedCopy()
+ // directly, and we know we will be copying only one value.
+ TF_RET_CHECK(copy_size.empty());
+ StridedCopy(dest_data, LinearIndex(*dest_literal, dest_base), 0, src_data,
+ LinearIndex(src_literal, src_base), 0, 1);
+ } else if (!ShapeUtil::HasZeroElements(dest_shape)) {
+ TF_RET_CHECK(!ShapeUtil::HasZeroElements(src_shape));
+ TF_RET_CHECK(src_base.size() == dest_base.size());
+ TF_RET_CHECK(src_base.size() == copy_size.size());
+
+ // Scan the source from the minor dimension, stepping in copy_size blocks.
+ // Within the index enumeration functor, do a strided copy that advances the
+ // source index by one (walking through the minor dimension) and the
+ // destination index by the stride of the matching destination dimension.
+ std::vector<int64> src_indexes(src_base.size(), 0);
+ std::vector<int64> dest_indexes(dest_base.size(), 0);
+ std::vector<int64> base(src_base.size(), 0);
+ std::vector<int64> incr(src_base.size(), 1);
+ int64 sdim = src_shape.layout().minor_to_major()[0];
+ int64 dest_stride = IndexUtil::GetDimensionStride(dest_shape, sdim);
+
+ incr[sdim] = copy_size[sdim];
+ auto copy_proc = [&](const std::vector<int64>& indexes) {
+ // Map from multi-dimensional index, to source index.
+ std::copy(indexes.begin(), indexes.end(), src_indexes.begin());
+ std::transform(src_indexes.begin(), src_indexes.end(), src_base.begin(),
+ src_indexes.begin(), std::plus<int64>());
+ // Map from multi-dimensional index, to destination index.
+ std::copy(indexes.begin(), indexes.end(), dest_indexes.begin());
+ std::transform(dest_indexes.begin(), dest_indexes.end(),
+ dest_base.begin(), dest_indexes.begin(),
+ std::plus<int64>());
+
+ int64 src_index = LinearIndex(src_literal, src_indexes);
+ int64 dest_index = LinearIndex(*dest_literal, dest_indexes);
+
+ StridedCopy(dest_data, dest_index, dest_stride, src_data, src_index, 1,
+ copy_size[sdim]);
+ return true;
+ };
+
+ ShapeUtil::ForEachIndex(src_shape, base, copy_size, incr, copy_proc);
+ }
+ return Status::OK();
+}
+
+/* static */ Status LiteralUtil::Copy(
+ const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+ Literal* dest_literal, tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size) {
+ TF_RET_CHECK(
+ ShapeUtil::SameElementType(src_literal.shape(), dest_literal->shape()));
+ switch (src_literal.shape().element_type()) {
+ case U32:
+ return CopyRange<uint32>(src_literal, src_base, dest_literal, dest_base,
+ copy_size);
+ case U64:
+ return CopyRange<uint64, tensorflow::protobuf_uint64>(
+ src_literal, src_base, dest_literal, dest_base, copy_size);
+ case S32:
+ return CopyRange<int32>(src_literal, src_base, dest_literal, dest_base,
+ copy_size);
+ case S64:
+ return CopyRange<int64, tensorflow::protobuf_int64>(
+ src_literal, src_base, dest_literal, dest_base, copy_size);
+ case F32:
+ return CopyRange<float>(src_literal, src_base, dest_literal, dest_base,
+ copy_size);
+ case F64:
+ return CopyRange<double>(src_literal, src_base, dest_literal, dest_base,
+ copy_size);
+ case PRED:
+ return CopyRange<bool>(src_literal, src_base, dest_literal, dest_base,
+ copy_size);
+ default:
+ break;
+ }
+ return Unimplemented("Unhandled primitive type %d",
+ src_literal.shape().element_type());
+}
+
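
To make the blocking idea in CopyRange() concrete, here is a standalone, simplified sketch (not XLA code; the helper name and row-major rank-2 restriction are assumptions for illustration only). The outer loop steps over whole rows while the inner copy walks the contiguous minor dimension, which is what CopyRange() achieves with incr[sdim] = copy_size[sdim] and one StridedCopy() per visited index:

    #include <cstdint>
    #include <vector>

    // Copy a rows x cols window from src (src_cols wide) into dest (dest_cols wide).
    void CopyWindow2D(const std::vector<float>& src, int64_t src_cols,
                      std::vector<float>* dest, int64_t dest_cols,
                      int64_t src_row0, int64_t src_col0,
                      int64_t dest_row0, int64_t dest_col0,
                      int64_t rows, int64_t cols) {
      for (int64_t r = 0; r < rows; ++r) {  // one block per row of the window
        const int64_t src_index = (src_row0 + r) * src_cols + src_col0;
        const int64_t dest_index = (dest_row0 + r) * dest_cols + dest_col0;
        for (int64_t c = 0; c < cols; ++c) {  // contiguous walk of the minor dimension
          (*dest)[dest_index + c] = src[src_index + c];
        }
      }
    }
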
/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index ef78b819e3..ae3d43e56c 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -100,6 +100,31 @@ class LiteralUtil {
values,
const Layout& layout);
+ // Creates a new Literal object with the given shape. The literal values are
+ // initialized to the default value of the shape's primitive type (0 for
+ // numeric types, and false for predicates).
+ static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+
+ // Creates a new Literal object whose values have the given primitive_type,
+ // and whose dimensions are defined by the dimensions parameter. The literal
+ // values are initialized to the default value of the primitive type (0 for
+ // numeric types, and false for predicates).
+ static std::unique_ptr<Literal> CreateFromDimensions(
+ PrimitiveType primitive_type,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Copies the values from src_literal, starting at the src_base shape
+ // indexes, to dest_literal, starting at dest_base, where the copy size in
+ // each dimension is specified by copy_size.
+ // src_literal and dest_literal must have the same primitive type;
+ // src_base+copy_size must fit within the source literal dimensions, and
+ // dest_base+copy_size must fit within the destination literal dimensions.
+ static Status Copy(const Literal& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ Literal* dest_literal,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
+
// Creates a new value that has the equivalent value as literal, but conforms
// to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major
// dimension layout can be re-layed-out as {1, 0} minor-to-major dimension
@@ -398,6 +423,30 @@ class LiteralUtil {
static int64 LinearIndex(const Literal& literal,
tensorflow::gtl::ArraySlice<int64> multi_index);
+ // Internal template helper for the Copy() API, matching its arguments one by
+ // one.
+ //
+ // The double WT template parameter is admittedly ugly, but it is needed for
+ // one of the gcc versions used in testing, which is unable to match the
+ // template types uint64 and int64 with tensorflow::protobuf_uint64 and
+ // tensorflow::protobuf_int64 in the GetArraySlice<>() and
+ // GetMutableRepeatedField<>() APIs.
+ // While in the GetArraySlice<>() case the AsUInt64Slice() and AsInt64Slice()
+ // wrappers handle the pointer parameters via reinterpret_cast<>, the protocol
+ // buffer repeated field accessors return a RepeatedField<> pointer, which is
+ // not trivially remappable (short of a rather ugly API forwarding wrapper).
+ // For that gcc version this creates a mismatch where either APIs like
+ // CopyRange<>() need both types specified, or the Get<>() and Set<>() APIs
+ // have to be called with different types (Get<>() with uint64 and Set<>()
+ // with tensorflow::protobuf_uint64).
+ template <typename T, typename WT = T>
+ static Status CopyRange(const Literal& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ Literal* dest_literal,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
+
TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil);
};
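
A minimal usage sketch of the new Copy() API (not from the diff; the concrete values are illustrative, and the snippet assumes the existing LiteralUtil::CreateR2 helper and an XLA build environment):

    auto src = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
    auto dest = LiteralUtil::CreateFromDimensions(F32, {4, 4});
    // Copy a 2x2 window of src starting at {0, 1} into dest starting at {2, 2}.
    TF_CHECK_OK(LiteralUtil::Copy(*src, {0, 1}, dest.get(), {2, 2}, {2, 2}));
    // dest now holds 2 3 / 5 6 in its bottom-right 2x2 corner; all other
    // elements keep the default value 0.
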
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 91971c3e24..dd4d820bab 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -648,5 +650,65 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) {
EXPECT_TRUE(LiteralUtil::Equal(*output, *expected));
}
+TEST_F(LiteralUtilTest, Copy) {
+ const int64 dimensions[] = {17, 15, 34, 21};
+ const int64 layouts[][4] = {
+ {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};
+ for (const auto& layout : layouts) {
+ Shape shape = ShapeUtil::MakeShapeWithLayout(
+ primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout);
+ auto blank = LiteralUtil::CreateFromShape(shape);
+ auto source = LiteralUtil::CreateFromShape(shape);
+ const int64 sbase[] = {0, 0, 0, 0};
+ const int64 incr[] = {1, 1, 1, 1};
+ uint32 seqnr = 0;
+ auto init_proc = [&](const std::vector<int64>& indexes) {
+ LiteralUtil::Set(source.get(), indexes, ++seqnr);
+ return true;
+ };
+
+ ShapeUtil::ForEachIndex(source->shape(), sbase, dimensions, incr,
+ init_proc);
+
+ const int64 src_base[] = {3, 1, 5, 7};
+ const int64 dest_base[] = {6, 4, 12, 2};
+ const int64 copy_size[] = {7, 8, 11, 9};
+
+ TF_EXPECT_OK(LiteralUtil::Copy(*source, src_base, blank.get(), dest_base,
+ copy_size));
+ std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
+ std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
+ bool matched = true;
+ auto check_proc = [&](const std::vector<int64>& indexes) {
+ std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
+ std::transform(source_indexes.begin(), source_indexes.end(), src_base,
+ source_indexes.begin(), std::plus<int64>());
+ std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
+ std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
+ blank_indexes.begin(), std::plus<int64>());
+ auto bval = LiteralUtil::Get<uint32>(*blank, blank_indexes);
+ matched = (bval != 0 &&
+ bval == LiteralUtil::Get<uint32>(*source, source_indexes));
+ return matched;
+ };
+ ShapeUtil::ForEachIndex(source->shape(), sbase, copy_size, incr,
+ check_proc);
+ EXPECT_TRUE(matched);
+ }
+}
+
+TEST_F(LiteralUtilTest, CopyScalars) {
+ auto zero = LiteralUtil::CreateR0<uint32>(0);
+ auto nine = LiteralUtil::CreateR0<uint32>(9);
+ TF_EXPECT_OK(LiteralUtil::Copy(*nine, {}, zero.get(), {}, {}));
+ EXPECT_TRUE(LiteralUtil::Equal(*zero, *nine));
+
+ auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
+ TF_EXPECT_OK(LiteralUtil::Copy(*vect, {5}, zero.get(), {}, {}));
+ EXPECT_EQ(LiteralUtil::Get<uint32>(*zero, {}), 17);
+ TF_EXPECT_OK(LiteralUtil::Copy(*zero, {}, vect.get(), {4}, {}));
+ EXPECT_EQ(LiteralUtil::Get<uint32>(*vect, {4}), 17);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index ee265c6688..38668856a3 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -209,6 +209,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* operand, HloInstruction* max,
HloInstruction* max_operand);
+ // Tries to constant fold a concatenate operation, returning true if the
+ // folding has been performed. An error status is returned in case of error.
+ StatusOr<bool> TryConcatenateConstantFold(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+
// A Reshape or Broadcast that feeds an element-wise operation with a unique
// non-scalar operand can sink to after the operation.
StatusOr<bool> TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
@@ -301,14 +307,78 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy,
return Status::OK();
}
+StatusOr<bool> AlgebraicSimplifierVisitor::TryConcatenateConstantFold(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ if (operands[0]->opcode() == HloOpcode::kConstant) {
+ // If all the operands of a concatenate are constant, fold them into a
+ // single constant tensor.
+ // The size of the result along the concatenate dimension is the sum of the
+ // operands' sizes along that dimension.
+ int64 concat_dim = concatenate->dimensions()[0];
+ const Shape& reference_shape = operands[0]->shape();
+ if (ShapeUtil::IsTuple(reference_shape)) {
+ VLOG(5) << "Tuples not currently supported by the concatenate constant"
+ " folder";
+ return false;
+ }
+ int64 rank = ShapeUtil::Rank(reference_shape);
+ std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(),
+ reference_shape.dimensions().end());
+ if (concat_dim < 0) {
+ concat_dim += rank;
+ }
+ for (int64 i = 1; i < operands.size(); ++i) {
+ const Shape& operand_shape = operands[i]->shape();
+ if (operands[i]->opcode() != HloOpcode::kConstant ||
+ ShapeUtil::IsTuple(operand_shape)) {
+ return false;
+ }
+ // Accumulate the concat dimension from all tensors taking part in the
+ // operation.
+ concat_dimensions[concat_dim] +=
+ ShapeUtil::GetDimension(operand_shape, concat_dim);
+ }
+
+ auto literal = LiteralUtil::CreateFromDimensions(
+ reference_shape.element_type(), concat_dimensions);
+ std::vector<int64> source_indices(rank, 0);
+ std::vector<int64> dest_indices(concat_dimensions.size(), 0);
+ for (auto operand : operands) {
+ const Shape& operand_shape = operand->shape();
+ Status status = LiteralUtil::Copy(
+ operand->literal(), source_indices, literal.get(), dest_indices,
+ AsInt64Slice(operand_shape.dimensions()));
+ if (!status.ok()) {
+ VLOG(1) << "Error while creating concatenated literal : " << status;
+ return false;
+ }
+ dest_indices[concat_dim] +=
+ ShapeUtil::GetDimension(operand_shape, concat_dim);
+ }
+ TF_CHECK_OK(computation_->ReplaceWithNewInstruction(
+ concatenate, HloInstruction::CreateConstant(std::move(literal))));
+ changed_ = true;
+ return true;
+ }
+ return false;
+}
+
Status AlgebraicSimplifierVisitor::HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
- // Unary concatenates are useless.
if (operands.size() == 1) {
+ // Unary concatenates are useless.
ReplaceInstructionIfSameShape(concatenate, operands[0]);
return Status::OK();
}
+ // If all the concatenate operands are constant, this will get folded into a
+ // new constant literal.
+ TF_ASSIGN_OR_RETURN(bool folded,
+ TryConcatenateConstantFold(concatenate, operands));
+ if (folded) {
+ return Status::OK();
+ }
// Filter out and remove empty operands.
std::vector<HloInstruction*> nonempty_operands;
for (HloInstruction* operand : operands) {
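
To illustrate the bookkeeping in TryConcatenateConstantFold(), here is a hedged sketch of what the fold effectively computes for two hypothetical constants of shapes F32[2,5] and F32[3,5] concatenated along dimension 0 (the shapes are assumptions for illustration; the sketch uses only the LiteralUtil APIs added by this change):

    auto c0 = LiteralUtil::CreateFromDimensions(F32, {2, 5});
    auto c1 = LiteralUtil::CreateFromDimensions(F32, {3, 5});
    // concat_dimensions accumulates the concat dimension: {2 + 3, 5} == {5, 5}.
    auto folded = LiteralUtil::CreateFromDimensions(F32, {5, 5});
    std::vector<int64> dest_indices = {0, 0};
    TF_CHECK_OK(LiteralUtil::Copy(*c0, {0, 0}, folded.get(), dest_indices, {2, 5}));
    dest_indices[0] += 2;  // step past the first operand along the concat dimension
    TF_CHECK_OK(LiteralUtil::Copy(*c1, {0, 0}, folded.get(), dest_indices, {3, 5}));
    // The concatenate instruction is then replaced by a constant holding 'folded'.
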
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 77b8fca1a9..8c28ef30ef 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1620,5 +1620,46 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
}
+TEST_F(AlgebraicSimplifierTest, Concatenate) {
+ const struct TestConfig {
+ int concat_dimension;
+ tensorflow::gtl::ArraySlice<int64> dimensions;
+ tensorflow::gtl::ArraySlice<int64> concat_sizes;
+ } test_configs[] = {
+ {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
+ {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
+ };
+
+ for (auto& test_config : test_configs) {
+ HloComputation::Builder builder(TestName());
+ std::vector<int64> dimensions(test_config.dimensions.begin(),
+ test_config.dimensions.end());
+ int64 concat_size = 0;
+ std::vector<HloInstruction*> operands;
+ for (auto csize : test_config.concat_sizes) {
+ dimensions[test_config.concat_dimension] = csize;
+ concat_size += csize;
+ auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
+ HloInstruction* insn = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+ operands.push_back(insn);
+ }
+ dimensions[test_config.concat_dimension] = concat_size;
+ Shape shape = ShapeUtil::MakeShape(F32, dimensions);
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ shape, operands, test_config.concat_dimension));
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 57d91e4bfc..b558e31ee9 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1047,4 +1047,29 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
return shape;
}
+/* static */ void ShapeUtil::ForEachIndex(
+ const Shape& shape, tensorflow::gtl::ArraySlice<int64> base,
+ tensorflow::gtl::ArraySlice<int64> count,
+ tensorflow::gtl::ArraySlice<int64> incr,
+ const IndexVisitorFunction& visitor_function) {
+ DCHECK_EQ(Rank(shape), base.size());
+ DCHECK_EQ(incr.size(), base.size());
+ DCHECK_EQ(count.size(), base.size());
+ const Layout& layout = shape.layout();
+ int64 rank = layout.minor_to_major_size();
+ int64 n = 0;
+ std::vector<int64> indexes(base.begin(), base.end());
+ while (n < rank && visitor_function(indexes)) {
+ // Increments dimensions in minor to major order.
+ for (n = 0; n < rank; ++n) {
+ int64 dim = layout.minor_to_major(n);
+ indexes[dim] += incr[dim];
+ if (indexes[dim] < base[dim] + count[dim]) {
+ break;
+ }
+ indexes[dim] = base[dim];
+ }
+ }
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 68e138e6ac..a3da36b7c6 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -390,6 +390,19 @@ class ShapeUtil {
static Shape FilterDimensions(const std::function<bool(int64)>& p,
Shape shape);
+ // Iterates through all the shape indexes, in minor-to-major order, starting
+ // from the base indexes, incrementing by the incr steps, up to count
+ // (index[i] < base[i] + count[i]), and calls visitor_function with the
+ // current index.
+ // visitor_function should return true to continue the iteration, or false
+ // to stop it.
+ using IndexVisitorFunction = std::function<bool(const std::vector<int64>&)>;
+ static void ForEachIndex(const Shape& shape,
+ tensorflow::gtl::ArraySlice<int64> base,
+ tensorflow::gtl::ArraySlice<int64> count,
+ tensorflow::gtl::ArraySlice<int64> incr,
+ const IndexVisitorFunction& visitor_function);
+
private:
// Validates all of the non-layout properties of the shape -- this is a helper
// used by both the layout-optional and layout-required public method.
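
A minimal sketch of the ForEachIndex() iteration order (not part of the diff; it assumes a rank-2 shape with a {1, 0} minor-to-major layout and an XLA build environment):

    Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
    ShapeUtil::ForEachIndex(shape, /*base=*/{0, 0}, /*count=*/{2, 3}, /*incr=*/{1, 1},
                            [](const std::vector<int64>& index) {
                              LOG(INFO) << index[0] << "," << index[1];
                              return true;  // keep iterating
                            });
    // Dimension 1 is minor-most, so it varies fastest; the visited order is
    // {0,0} {0,1} {0,2} {1,0} {1,1} {1,2}.
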
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 32b5fbba00..236728f417 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -139,6 +139,18 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2,
std::equal(std::begin(c1), std::end(c1), std::begin(c2), p));
}
+ // Copies count values from src to dest, using possibly different strides for
+ // the source and the destination. The source starting index is src_base, and
+ // the destination starting index is dest_base.
+template <typename D, typename S>
+void StridedCopy(tensorflow::protobuf::RepeatedField<D>* dest, int64 dest_base,
+ int64 dest_stride, tensorflow::gtl::ArraySlice<S> src,
+ int64 src_base, int64 src_stride, int64 count) {
+ for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) {
+ dest->Set(dest_base, static_cast<D>(src[src_base]));
+ }
+}
+
// Adds some context information to the error message in a
// Status. This is useful as Statuses are
// propagated upwards.
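
A small hedged example of the new StridedCopy() helper (not from the diff; the values are illustrative, and dest is assumed to be pre-sized to hold the copied elements):

    tensorflow::protobuf::RepeatedField<float> dest;
    dest.Resize(3, 0.0f);
    std::vector<float> src_storage = {1, 2, 3, 4, 5, 6};
    tensorflow::gtl::ArraySlice<float> src(src_storage);
    // Copy every other element of src into consecutive slots of dest.
    StridedCopy(&dest, /*dest_base=*/0, /*dest_stride=*/1, src,
                /*src_base=*/0, /*src_stride=*/2, /*count=*/3);
    // dest now holds {1, 3, 5}.
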