From 5174b67f70645210429db837df3047c7d52637bf Mon Sep 17 00:00:00 2001 From: Kay Zhu Date: Thu, 7 Jun 2018 13:03:54 -0700 Subject: [TF:XLA] Introduce a new HostTensorToBorrowingLiteral path without the memcpy from Tensor to Literal, and use it in xla_helpers. PiperOrigin-RevId: 199682452 --- tensorflow/compiler/tf2xla/literal_util.cc | 31 ++++++++++++++++++++++++++++ tensorflow/compiler/tf2xla/literal_util.h | 12 +++++++++++ tensorflow/compiler/tf2xla/xla_helpers.cc | 11 ++++++---- tensorflow/compiler/xla/literal_util.cc | 22 ++++++++++---------- tensorflow/compiler/xla/literal_util.h | 6 ++++-- tensorflow/compiler/xla/literal_util_test.cc | 4 ++-- 6 files changed, 67 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 43e1c1e9fe..db56b12837 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,6 +40,37 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + *literal = xla::BorrowingLiteral( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); + return Status::OK(); +} + +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal) { + std::vector buf_ptrs; + buf_ptrs.reserve(host_tensors.size()); + std::vector tensor_shapes(host_tensors.size()); + + for (int i = 0; i < host_tensors.size(); i++) { + // Validate runtime shapes and fail if it doesn't match the contract. + const Tensor* tensor = &host_tensors[i]; + buf_ptrs.emplace_back(static_cast(DMAHelper::base(tensor))); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(), + &tensor_shapes[i])); + } + + *literal = xla::BorrowingLiteral( + buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes)); + + return Status::OK(); +} + Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 220bec1553..74685025c1 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -29,6 +30,17 @@ namespace tensorflow { // unsupported type. Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); +// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by +// 'host_tensor'. +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal); + +// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers +// owned by 'host_tensors'. +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal); + // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . // Fails if the literal's primitive type != diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f1594193af..a1da176fe3 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -19,11 +19,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -210,8 +212,9 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, return errors::InvalidArgument("Invalid argument type ", DataTypeString(dtype)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); + *iota = builder->ConstantLiteral(linspace_literal); return Status::OK(); } @@ -245,8 +248,8 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, return errors::InvalidArgument("Invalid argument type ", DataTypeString(index_type)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 61afc311a7..6b29589700 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -2341,28 +2341,28 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal, : LiteralBase(), root_piece_(&literal.piece(view_root)) {} BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsArray(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); CHECK_NE(src_buf_ptr, nullptr); - CHECK(LayoutUtil::HasLayout(shape_)); + CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = Piece(); root_piece_.set_buffer(const_cast(src_buf_ptr)); - root_piece_.set_subshape(&shape_); + root_piece_.set_subshape(shape_.get()); } BorrowingLiteral::BorrowingLiteral( tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsTuple(shape_)); - CHECK(!ShapeUtil::IsNestedTuple(shape_)); - CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); root_piece_ = Piece(); - root_piece_.set_subshape(&shape_); - BuildPieceSubtree(shape_, &root_piece_); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); for (int i = 0; i < src_buf_ptrs.size(); ++i) { - const auto& src_shape = shape_.tuple_shapes(i); + const auto& src_shape = shape_->tuple_shapes(i); CHECK(ShapeUtil::IsArray(src_shape)); root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 1e26eb7ad4..8e4159e360 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -1099,8 +1099,10 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. - const Shape shape_; + // Shape of this literal. Stored as unique_ptr so such that the (default) + // move construction of this class would be trivially correct: the pointer to + // Shape root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; }; template diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index f127cee0fd..53b926163c 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -1431,7 +1431,7 @@ TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { std::vector int64_values = {1, 2, 3}; const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); @@ -1443,7 +1443,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { EXPECT_EQ(literal.Get({2}), 3); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { std::vector one_two_three = {1, 2, 3}; const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); -- cgit v1.2.3