diff options
-rw-r--r-- | tensorflow/compiler/xla/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/index_util.cc | 19 | ||||
-rw-r--r-- | tensorflow/compiler/xla/index_util.h | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util.cc | 111 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util.h | 49 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util_test.cc | 62 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 72 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 41 | ||||
-rw-r--r-- | tensorflow/compiler/xla/shape_util.cc | 25 | ||||
-rw-r--r-- | tensorflow/compiler/xla/shape_util.h | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/util.h | 12 |
11 files changed, 413 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 7576dff0cd..65d4528421 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -256,6 +256,7 @@ cc_library( ":array3d", ":array4d", ":shape_util", + ":status_macros", ":types", ":util", ":xla_data_proto", @@ -274,6 +275,7 @@ cc_test( ":test", ":types", "//tensorflow/core:lib", + "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 92aca3cae9..76c0168f37 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -131,4 +131,23 @@ namespace xla { return false; } +/* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape, + int64 dimension) { + const Layout& layout = shape.layout(); + int64 pdim_size = layout.padded_dimensions_size(); + int64 stride = 1; + DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size()); + for (auto dim : layout.minor_to_major()) { + if (dim == dimension) { + break; + } + if (pdim_size == 0) { + stride *= shape.dimensions(dim); + } else { + stride *= layout.padded_dimensions(dim); + } + } + return stride; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index e6a26d6220..c9838966a5 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -61,6 +61,14 @@ class IndexUtil { static bool BumpIndices(const Shape& shape, tensorflow::gtl::MutableArraySlice<int64> indices); + // Calculates the stride size (in number of elements, not byte size) of a + // given logical shape dimension (from 0 to rank-1). If available, padded + // dimensions are used. 
+ // Example: + // GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) == + // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 + static int64 GetDimensionStride(const Shape& shape, int64 dimension); + private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); }; diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 0286b0817c..03c9e2c9d7 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -16,12 +16,14 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include <algorithm> +#include <functional> #include <limits> #include <numeric> #include <vector> #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -33,6 +35,115 @@ limitations under the License. 
namespace xla { +/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromShape( + const Shape& shape) { + auto literal = MakeUnique<Literal>(); + *literal->mutable_shape() = shape; + Reserve(ShapeUtil::ElementsIn(literal->shape()), literal.get()); + return literal; +} + +/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice<int64> dimensions) { + return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); +} + +template <typename T, typename WT> +/* static */ Status LiteralUtil::CopyRange( + const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base, + Literal* dest_literal, tensorflow::gtl::ArraySlice<int64> dest_base, + tensorflow::gtl::ArraySlice<int64> copy_size) { + const Shape& src_shape = src_literal.shape(); + const Shape& dest_shape = dest_literal->shape(); + tensorflow::gtl::ArraySlice<T> src_data = GetArraySlice<T>(src_literal); + tensorflow::protobuf::RepeatedField<WT>* dest_data = + GetMutableRepeatedField<WT>(dest_literal); + + TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size()); + TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size()); + if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) { + // If any of the two shapes are scalars, we can just call the StridedCopy() + // directly, and we know we will be copying only one value. 
+ TF_RET_CHECK(copy_size.empty()); + StridedCopy(dest_data, LinearIndex(*dest_literal, dest_base), 0, src_data, + LinearIndex(src_literal, src_base), 0, 1); + } else if (!ShapeUtil::HasZeroElements(dest_shape)) { + TF_RET_CHECK(!ShapeUtil::HasZeroElements(src_shape)); + TF_RET_CHECK(src_base.size() == dest_base.size()); + TF_RET_CHECK(src_base.size() == copy_size.size()); + + // Scan the source from minor, stepping in copy size blocks, then within + // the index enumeration functor, do a strided copy advancing source index + // by one (walking through the minor dimension), and destination index by + // proper stride size at the matching dimension. + std::vector<int64> src_indexes(src_base.size(), 0); + std::vector<int64> dest_indexes(dest_base.size(), 0); + std::vector<int64> base(src_base.size(), 0); + std::vector<int64> incr(src_base.size(), 1); + int64 sdim = src_shape.layout().minor_to_major()[0]; + int64 dest_stride = IndexUtil::GetDimensionStride(dest_shape, sdim); + + incr[sdim] = copy_size[sdim]; + auto copy_proc = [&](const std::vector<int64>& indexes) { + // Map from multi-dimensional index, to source index. + std::copy(indexes.begin(), indexes.end(), src_indexes.begin()); + std::transform(src_indexes.begin(), src_indexes.end(), src_base.begin(), + src_indexes.begin(), std::plus<int64>()); + // Map from multi-dimensional index, to destination index. 
+ std::copy(indexes.begin(), indexes.end(), dest_indexes.begin()); + std::transform(dest_indexes.begin(), dest_indexes.end(), + dest_base.begin(), dest_indexes.begin(), + std::plus<int64>()); + + int64 src_index = LinearIndex(src_literal, src_indexes); + int64 dest_index = LinearIndex(*dest_literal, dest_indexes); + + StridedCopy(dest_data, dest_index, dest_stride, src_data, src_index, 1, + copy_size[sdim]); + return true; + }; + + ShapeUtil::ForEachIndex(src_shape, base, copy_size, incr, copy_proc); + } + return Status::OK(); +} + +/* static */ Status LiteralUtil::Copy( + const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base, + Literal* dest_literal, tensorflow::gtl::ArraySlice<int64> dest_base, + tensorflow::gtl::ArraySlice<int64> copy_size) { + TF_RET_CHECK( + ShapeUtil::SameElementType(src_literal.shape(), dest_literal->shape())); + switch (src_literal.shape().element_type()) { + case U32: + return CopyRange<uint32>(src_literal, src_base, dest_literal, dest_base, + copy_size); + case U64: + return CopyRange<uint64, tensorflow::protobuf_uint64>( + src_literal, src_base, dest_literal, dest_base, copy_size); + case S32: + return CopyRange<int32>(src_literal, src_base, dest_literal, dest_base, + copy_size); + case S64: + return CopyRange<int64, tensorflow::protobuf_int64>( + src_literal, src_base, dest_literal, dest_base, copy_size); + case F32: + return CopyRange<float>(src_literal, src_base, dest_literal, dest_base, + copy_size); + case F64: + return CopyRange<double>(src_literal, src_base, dest_literal, dest_base, + copy_size); + case PRED: + return CopyRange<bool>(src_literal, src_base, dest_literal, dest_base, + copy_size); + default: + break; + } + return Unimplemented("Unhandled primitive type %d", + src_literal.shape().element_type()); +} + /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h 
index ef78b819e3..ae3d43e56c 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -100,6 +100,31 @@ class LiteralUtil { values, const Layout& layout); + // Create a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr<Literal> CreateFromShape(const Shape& shape); + + // Create a new Literal object with its values having the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr<Literal> CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice<int64> dimensions); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to dest_literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and dest_literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + static Status Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice<int64> src_base, + Literal* dest_literal, + tensorflow::gtl::ArraySlice<int64> dest_base, + tensorflow::gtl::ArraySlice<int64> copy_size); + + // Creates a new value that has the equivalent value as literal, but conforms // to new_layout; e.g. 
a literal matrix that was in {0, 1} minor-to-major // dimension layout can be re-layed-out as {1, 0} minor-to-major dimension @@ -398,6 +423,30 @@ class LiteralUtil { static int64 LinearIndex(const Literal& literal, tensorflow::gtl::ArraySlice<int64> multi_index); + // Internal template helper for the Copy() API, matching its arguments one by + // one. + // + // The double WT template parameter is pretty ugly, but it comes from one of + // the gcc versions used for tests, which seems unable to match template + // types uint64 and int64 with tensorflow::protobuf_uint64 and + // tensorflow::protobuf_int64, for the GetArraySlice<>() and + // GetMutableRepeatedField<>() APIs. + // While for the GetArraySlice<>() case the AsUInt64Slice() and + // AsInt64Slice() wrappers are taking care via reinterpret_cast<> of the code + // pointer parameters, the protocol buffer repeated fields accessors + // return a RepeatedField<> pointer, which is not trivially remappable + // (unless pretty ugly API forwarder wrapper). + // For that gcc version, this creates a mismatch where either things like the + // CopyRange<>() API needs to have both specified, or the Get<>() and Set<>() + // APIs having to be called with different types (Get<>() with uint64 and + // Set<>() with tensorflow::protobuf_uint64). + template <typename T, typename WT = T> + static Status CopyRange(const Literal& src_literal, + tensorflow::gtl::ArraySlice<int64> src_base, + Literal* dest_literal, + tensorflow::gtl::ArraySlice<int64> dest_base, + tensorflow::gtl::ArraySlice<int64> copy_size); + + TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil); }; diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 91971c3e24..dd4d820bab 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -23,6 +23,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -648,5 +650,65 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) { EXPECT_TRUE(LiteralUtil::Equal(*output, *expected)); } +TEST_F(LiteralUtilTest, Copy) { + const int64 dimensions[] = {17, 15, 34, 21}; + const int64 layouts[][4] = { + {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}}; + for (const auto& layout : layouts) { + Shape shape = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout); + auto blank = LiteralUtil::CreateFromShape(shape); + auto source = LiteralUtil::CreateFromShape(shape); + const int64 sbase[] = {0, 0, 0, 0}; + const int64 incr[] = {1, 1, 1, 1}; + uint32 seqnr = 0; + auto init_proc = [&](const std::vector<int64>& indexes) { + LiteralUtil::Set(source.get(), indexes, ++seqnr); + return true; + }; + + ShapeUtil::ForEachIndex(source->shape(), sbase, dimensions, incr, + init_proc); + + const int64 src_base[] = {3, 1, 5, 7}; + const int64 dest_base[] = {6, 4, 12, 2}; + const int64 copy_size[] = {7, 8, 11, 9}; + + TF_EXPECT_OK(LiteralUtil::Copy(*source, src_base, blank.get(), dest_base, + copy_size)); + std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0); + std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0); + bool matched = true; + auto check_proc = [&](const std::vector<int64>& indexes) { + std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); + std::transform(source_indexes.begin(), source_indexes.end(), src_base, + source_indexes.begin(), std::plus<int64>()); + std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); + std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, + blank_indexes.begin(), std::plus<int64>()); + auto bval = 
LiteralUtil::Get<uint32>(*blank, blank_indexes); + matched = (bval != 0 && + bval == LiteralUtil::Get<uint32>(*source, source_indexes)); + return matched; + }; + ShapeUtil::ForEachIndex(source->shape(), sbase, copy_size, incr, + check_proc); + EXPECT_TRUE(matched); + } +} + +TEST_F(LiteralUtilTest, CopyScalars) { + auto zero = LiteralUtil::CreateR0<uint32>(0); + auto nine = LiteralUtil::CreateR0<uint32>(9); + TF_EXPECT_OK(LiteralUtil::Copy(*nine, {}, zero.get(), {}, {})); + EXPECT_TRUE(LiteralUtil::Equal(*zero, *nine)); + + auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21}); + TF_EXPECT_OK(LiteralUtil::Copy(*vect, {5}, zero.get(), {}, {})); + EXPECT_EQ(LiteralUtil::Get<uint32>(*zero, {}), 17); + TF_EXPECT_OK(LiteralUtil::Copy(*zero, {}, vect.get(), {4}, {})); + EXPECT_EQ(LiteralUtil::Get<uint32>(*vect, {4}), 17); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index ee265c6688..38668856a3 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -209,6 +209,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand); + // Tries to constant fold a concatenate operation, and returns true if the + // operation has been performed. An error status is returned in case of error. + StatusOr<bool> TryConcatenateConstantFold( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice<HloInstruction*> operands); + // A Reshape or Broadcast that feeds an element-wise operation with a unique // non-scalar operand can sink to after the operation. 
StatusOr<bool> TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( @@ -301,14 +307,78 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy, return Status::OK(); } +StatusOr<bool> AlgebraicSimplifierVisitor::TryConcatenateConstantFold( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice<HloInstruction*> operands) { + if (operands[0]->opcode() == HloOpcode::kConstant) { + // If all the operands of a concatenate are constant, fold them into a + // single constant tensor. + // The concatenate dimension is going to be the sum of all the concatenate + // dimensions. + int64 concat_dim = concatenate->dimensions()[0]; + const Shape& reference_shape = operands[0]->shape(); + if (ShapeUtil::IsTuple(reference_shape)) { + VLOG(5) << "Tuples not currently supported by the concatenate constant" + " folder"; + return false; + } + int64 rank = ShapeUtil::Rank(reference_shape); + std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(), + reference_shape.dimensions().end()); + if (concat_dim < 0) { + concat_dim += rank; + } + for (int64 i = 1; i < operands.size(); ++i) { + const Shape& operand_shape = operands[i]->shape(); + if (operands[i]->opcode() != HloOpcode::kConstant || + ShapeUtil::IsTuple(operand_shape)) { + return false; + } + // Accumulate the concat dimension from all tensors taking part in the + // operation. 
+ concat_dimensions[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + auto literal = LiteralUtil::CreateFromDimensions( + reference_shape.element_type(), concat_dimensions); + std::vector<int64> source_indices(rank, 0); + std::vector<int64> dest_indices(concat_dimensions.size(), 0); + for (auto operand : operands) { + const Shape& operand_shape = operand->shape(); + Status status = LiteralUtil::Copy( + operand->literal(), source_indices, literal.get(), dest_indices, + AsInt64Slice(operand_shape.dimensions())); + if (!status.ok()) { + VLOG(1) << "Error while creating concatenated literal : " << status; + return false; + } + dest_indices[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + TF_CHECK_OK(computation_->ReplaceWithNewInstruction( + concatenate, HloInstruction::CreateConstant(std::move(literal)))); + changed_ = true; + return true; + } + return false; +} + Status AlgebraicSimplifierVisitor::HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice<HloInstruction*> operands) { - // Unary concatenates are useless. if (operands.size() == 1) { + // Unary concatenates are useless. ReplaceInstructionIfSameShape(concatenate, operands[0]); return Status::OK(); } + // If all the concatenate operands are constant, this will get folded into a + // new constant literal. + TF_ASSIGN_OR_RETURN(bool folded, + TryConcatenateConstantFold(concatenate, operands)); + if (folded) { + return Status::OK(); + } // Filter out and remove empty operands. 
std::vector<HloInstruction*> nonempty_operands; for (HloInstruction* operand : operands) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 77b8fca1a9..8c28ef30ef 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1620,5 +1620,46 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); } +TEST_F(AlgebraicSimplifierTest, Concatenate) { + const struct TestConfig { + int concat_dimension; + tensorflow::gtl::ArraySlice<int64> dimensions; + tensorflow::gtl::ArraySlice<int64> concat_sizes; + } test_configs[] = { + {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, + {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, + }; + + for (auto& test_config : test_configs) { + HloComputation::Builder builder(TestName()); + std::vector<int64> dimensions(test_config.dimensions.begin(), + test_config.dimensions.end()); + int64 concat_size = 0; + std::vector<HloInstruction*> operands; + for (auto csize : test_config.concat_sizes) { + dimensions[test_config.concat_dimension] = csize; + concat_size += csize; + auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); + HloInstruction* insn = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + operands.push_back(insn); + } + dimensions[test_config.concat_dimension] = concat_size; + Shape shape = ShapeUtil::MakeShape(F32, dimensions); + builder.AddInstruction(HloInstruction::CreateConcatenate( + shape, operands, test_config.concat_dimension)); + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), 
HloOpcode::kConstant); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 57d91e4bfc..b558e31ee9 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1047,4 +1047,29 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return shape; } +/* static */ void ShapeUtil::ForEachIndex( + const Shape& shape, tensorflow::gtl::ArraySlice<int64> base, + tensorflow::gtl::ArraySlice<int64> count, + tensorflow::gtl::ArraySlice<int64> incr, + const IndexVisitorFunction& visitor_function) { + DCHECK_EQ(Rank(shape), base.size()); + DCHECK_EQ(incr.size(), base.size()); + DCHECK_EQ(count.size(), base.size()); + const Layout& layout = shape.layout(); + int64 rank = layout.minor_to_major_size(); + int64 n = 0; + std::vector<int64> indexes(base.begin(), base.end()); + while (n < rank && visitor_function(indexes)) { + // Increments dimensions in minor to major order. + for (n = 0; n < rank; ++n) { + int64 dim = layout.minor_to_major(n); + indexes[dim] += incr[dim]; + if (indexes[dim] < base[dim] + count[dim]) { + break; + } + indexes[dim] = base[dim]; + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 68e138e6ac..a3da36b7c6 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -390,6 +390,19 @@ class ShapeUtil { static Shape FilterDimensions(const std::function<bool(int64)>& p, Shape shape); + // Iterates through all the shape indexes, in minor to major order, starting + // from the base indexes, incrementing by the incr steps, up to count + // (index[i] < base[i] + count[i]), and calls the visitor_function with the + // current index. + // The visitor_function should return true if it wants to + // continue, or false otherwise. 
+ using IndexVisitorFunction = std::function<bool(const std::vector<int64>&)>; + static void ForEachIndex(const Shape& shape, + tensorflow::gtl::ArraySlice<int64> base, + tensorflow::gtl::ArraySlice<int64> count, + tensorflow::gtl::ArraySlice<int64> incr, + const IndexVisitorFunction& visitor_function); + private: // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 32b5fbba00..236728f417 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -139,6 +139,18 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2, std::equal(std::begin(c1), std::end(c1), std::begin(c2), p)); } +// Performs a copy of count values from src to dest, using different strides for +// source and destination. The source starting index is src_base, while the +// destination one is dest_base. +template <typename D, typename S> +void StridedCopy(tensorflow::protobuf::RepeatedField<D>* dest, int64 dest_base, + int64 dest_stride, tensorflow::gtl::ArraySlice<S> src, + int64 src_base, int64 src_stride, int64 count) { + for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) { + dest->Set(dest_base, static_cast<D>(src[src_base])); + } +} + // Adds some context information to the error message in a // Status. This is useful as Statuses are // propagated upwards. |