aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2018-06-07 13:03:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 13:06:34 -0700
commit5174b67f70645210429db837df3047c7d52637bf (patch)
tree28ee7e0e57e96da63ea7206c7be815f084aa037c
parent09c25a87cf321f317662f67d1b08deb3585e9abe (diff)
[TF:XLA] Introduce a new HostTensorToBorrowingLiteral path without the memcpy from Tensor to Literal, and use it in xla_helpers.
PiperOrigin-RevId: 199682452
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.cc31
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.h12
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc11
-rw-r--r--tensorflow/compiler/xla/literal_util.cc22
-rw-r--r--tensorflow/compiler/xla/literal_util.h6
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc4
6 files changed, 67 insertions, 19 deletions
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
index 43e1c1e9fe..db56b12837 100644
--- a/tensorflow/compiler/tf2xla/literal_util.cc
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -40,6 +40,37 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
return Status::OK();
}
+Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
+ xla::BorrowingLiteral* literal) {
+ xla::Shape xla_shape;
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
+ host_tensor.shape(), &xla_shape));
+ *literal = xla::BorrowingLiteral(
+ static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
+ return Status::OK();
+}
+
+Status HostTensorsToBorrowingLiteralTuple(
+ tensorflow::gtl::ArraySlice<Tensor> host_tensors,
+ xla::BorrowingLiteral* literal) {
+ std::vector<const char*> buf_ptrs;
+ buf_ptrs.reserve(host_tensors.size());
+ std::vector<xla::Shape> tensor_shapes(host_tensors.size());
+
+ for (int i = 0; i < host_tensors.size(); i++) {
+ // Validate runtime shapes and fail if it doesn't match the contract.
+ const Tensor* tensor = &host_tensors[i];
+ buf_ptrs.emplace_back(static_cast<const char*>(DMAHelper::base(tensor)));
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(),
+ &tensor_shapes[i]));
+ }
+
+ *literal = xla::BorrowingLiteral(
+ buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes));
+
+ return Status::OK();
+}
+
Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
Tensor* host_tensor) {
TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) &&
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
index 220bec1553..74685025c1 100644
--- a/tensorflow/compiler/tf2xla/literal_util.h
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -29,6 +30,17 @@ namespace tensorflow {
// unsupported type.
Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
+// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by
+// 'host_tensor'.
+Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
+ xla::BorrowingLiteral* literal);
+
+// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
+// owned by 'host_tensors'.
+Status HostTensorsToBorrowingLiteralTuple(
+ tensorflow::gtl::ArraySlice<Tensor> host_tensors,
+ xla::BorrowingLiteral* literal);
+
// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of
// type <target_type>.
// Fails if the literal's primitive type !=
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index f1594193af..a1da176fe3 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -19,11 +19,13 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -210,8 +212,9 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
return errors::InvalidArgument("Invalid argument type ",
DataTypeString(dtype));
}
- xla::Literal linspace_literal;
- TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
+ xla::BorrowingLiteral linspace_literal;
+ TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
+
*iota = builder->ConstantLiteral(linspace_literal);
return Status::OK();
}
@@ -245,8 +248,8 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
return errors::InvalidArgument("Invalid argument type ",
DataTypeString(index_type));
}
- xla::Literal linspace_literal;
- TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
+ xla::BorrowingLiteral linspace_literal;
+ TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
// Broadcast the linspace constant across the indices along the new axis,
// and test equality at each position.
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 61afc311a7..6b29589700 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -2341,28 +2341,28 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal,
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
- : LiteralBase(), shape_(shape) {
- CHECK(ShapeUtil::IsArray(shape_));
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsArray(*shape_));
CHECK_NE(src_buf_ptr, nullptr);
- CHECK(LayoutUtil::HasLayout(shape_));
+ CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = Piece();
root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
- root_piece_.set_subshape(&shape_);
+ root_piece_.set_subshape(shape_.get());
}
BorrowingLiteral::BorrowingLiteral(
tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
- : LiteralBase(), shape_(shape) {
- CHECK(ShapeUtil::IsTuple(shape_));
- CHECK(!ShapeUtil::IsNestedTuple(shape_));
- CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_));
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsTuple(*shape_));
+ CHECK(!ShapeUtil::IsNestedTuple(*shape_));
+ CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
root_piece_ = Piece();
- root_piece_.set_subshape(&shape_);
- BuildPieceSubtree(shape_, &root_piece_);
+ root_piece_.set_subshape(shape_.get());
+ BuildPieceSubtree(*shape_, &root_piece_);
for (int i = 0; i < src_buf_ptrs.size(); ++i) {
- const auto& src_shape = shape_.tuple_shapes(i);
+ const auto& src_shape = shape_->tuple_shapes(i);
CHECK(ShapeUtil::IsArray(src_shape));
root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
}
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 1e26eb7ad4..8e4159e360 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -1099,8 +1099,10 @@ class BorrowingLiteral : public LiteralBase {
const Piece& root_piece() const override { return root_piece_; };
Piece root_piece_;
- // Shape of this literal.
- const Shape shape_;
+ // Shape of this literal. Stored as unique_ptr so such that the (default)
+ // move construction of this class would be trivially correct: the pointer to
+ // Shape root_piece_ stores will still point to the correct address.
+ std::unique_ptr<Shape> shape_;
};
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index f127cee0fd..53b926163c 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -1431,7 +1431,7 @@ TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
-TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) {
+TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
std::vector<int64> int64_values = {1, 2, 3};
const Shape literal_shape = ShapeUtil::MakeShape(S64, {3});
@@ -1443,7 +1443,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) {
EXPECT_EQ(literal.Get<int64>({2}), 3);
}
-TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) {
+TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
std::vector<int64> one_two_three = {1, 2, 3};
const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3});