aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal.h
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2018-08-08 17:16:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 17:24:39 -0700
commit963ef37203c76e0338ede21b469020e425fb9208 (patch)
tree3f9773304ba3f3dccdd8189e4962c57c3d283985 /tensorflow/compiler/xla/literal.h
parent3325275eff98ffddb52a16db932481983a9de9a8 (diff)
[TF:XLA] Introduce MutableBorrowingLiteral to enable interacting with a (tensor) buffer not owned by XLA/Literal class directly, without having to memcpy the Literal to a (Host)Tensor.
PiperOrigin-RevId: 207972410
Diffstat (limited to 'tensorflow/compiler/xla/literal.h')
-rw-r--r--tensorflow/compiler/xla/literal.h186
1 files changed, 111 insertions, 75 deletions
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index dd67dfa8d4..92c0f903cb 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -310,9 +310,10 @@ class LiteralBase {
// type of literal itself (0 for numeric types, and false for predicates).
//
// Note: It's an antipattern to use this method then immediately call
- // Literal::Populate on the result (since that results in zero initialization,
- // then reinitialization. Conside if a call to MakeUnique<Literal>(shape),
- // followed by the call to Literal::Populate can be used instead.
+ // MutableLiteralBase::Populate on the result (since that results in zero
+ // initialization, then reinitialization. Conside if a call to
+ // MakeUnique<Literal>(shape), followed by the call to
+ // MutableLiteralBase::Populate can be used instead.
static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
protected:
@@ -534,7 +535,7 @@ class LiteralBase {
virtual const Piece& root_piece() const = 0;
// LiteralSlice and Literal must access Pieces of other Literals.
- friend class Literal;
+ friend class MutableLiteralBase;
friend class LiteralSlice;
friend class BorrowingLiteral;
@@ -545,33 +546,10 @@ class LiteralBase {
tensorflow::gtl::ArraySlice<int64> start_indices) const;
};
-// Class representing literal values in XLA.
-//
-// The underlying buffer and shape is always owned by this class.
-class Literal : public LiteralBase {
+// Abstract base class representing a mutable literal in XLA.
+class MutableLiteralBase : public LiteralBase {
public:
- Literal() : Literal(ShapeUtil::MakeNil()) {}
-
- // Create a literal of the given shape. The literal is allocated sufficient
- // memory to hold the shape. Memory is uninitialized.
- explicit Literal(const Shape& shape);
- virtual ~Literal();
-
- // Literals are moveable, but not copyable. To copy a literal use
- // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
- // of literals which can be expensive.
- Literal(const Literal& other) = delete;
- Literal& operator=(const Literal& other) = delete;
- Literal(Literal&& other);
- // 'allocate_arrays' indicates whether to allocate memory for the arrays in
- // the shape. If false, buffer pointers inside of the Literal::Pieces are set
- // to nullptr.
- Literal(const Shape& shape, bool allocate_arrays);
- Literal& operator=(Literal&& other);
-
- // TODO(b/67651157): Remove this accessor. Literal users should not be able to
- // mutate the shape as this can produce malformed Literals.
- Shape* mutable_shape_do_not_use() { return shape_.get(); }
+ virtual ~MutableLiteralBase() = 0;
// Returns a MutableArraySlice view of the array for this literal for the
// given NativeT (e.g., float). CHECKs if the subshape of the literal at the
@@ -587,6 +565,10 @@ class Literal : public LiteralBase {
// is not a sparse array.
SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
+ // TODO(b/67651157): Remove this accessor. Literal users should not be able to
+ // mutate the shape as this can produce malformed Literals.
+ Shape* mutable_shape_do_not_use() { return shape_.get(); }
+
// Returns a pointer to the underlying buffer holding the array at the given
// shape index. CHECKs if the subshape of the literal at the given ShapeIndex
// is not array.
@@ -613,21 +595,6 @@ class Literal : public LiteralBase {
const ShapeIndex& dest_shape_index = {},
const ShapeIndex& src_shape_index = {});
- // Returns a vector containing the tuple elements of this Literal as separate
- // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
- // elements are moved into the new Literals; no data is copied. Upon return
- // this Literal is set to a nil shape (empty tuple)
- std::vector<Literal> DecomposeTuple();
-
- // Similar to CopyFrom, but with move semantincs. The subshape of this literal
- // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
- // (layouts and shapes must match), but need not be arrays. The memory
- // allocated in this literal for the subshape at dest_shape_index is
- // deallocated, and the respective buffers are replaced with those in
- // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
- Status MoveFrom(Literal&& src_literal,
- const ShapeIndex& dest_shape_index = {});
-
// Copies the values from src_literal, starting at src_base shape indexes,
// to this literal, starting at dest_base, where the copy size in each
// dimension is specified by copy_size.
@@ -730,12 +697,7 @@ class Literal : public LiteralBase {
static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
const LiteralProto& proto);
- private:
- // Recursively sets the subshapes and buffers of all subpieces rooted at
- // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
- // the shape.
- void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
-
+ protected:
// Returns the piece at the given ShapeIndex.
Piece& piece(const ShapeIndex& shape_index) {
return const_cast<Piece&>(LiteralBase::piece(shape_index));
@@ -783,12 +745,83 @@ class Literal : public LiteralBase {
template <typename NativeT, typename FnType>
Status PopulateInternal(const FnType& generator, bool parallel);
+ friend class LiteralBase;
+ friend class MutableBorrowingLiteral;
+};
+std::ostream& operator<<(std::ostream& out, const Literal& literal);
+
+// The underlying buffer and shape is always owned by this class.
+class Literal : public MutableLiteralBase {
+ public:
+ Literal() : Literal(ShapeUtil::MakeNil()) {}
+
+ // Create a literal of the given shape. The literal is allocated sufficient
+ // memory to hold the shape. Memory is uninitialized.
+ explicit Literal(const Shape& shape);
+ virtual ~Literal();
+
+ // Literals are moveable, but not copyable. To copy a literal use
+ // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
+ // of literals which can be expensive.
+ Literal(const Literal& other) = delete;
+ Literal& operator=(const Literal& other) = delete;
+ Literal(Literal&& other);
+ // 'allocate_arrays' indicates whether to allocate memory for the arrays in
+ // the shape. If false, buffer pointers inside of the Literal::Pieces are set
+ // to nullptr.
+ Literal(const Shape& shape, bool allocate_arrays);
+ Literal& operator=(Literal&& other);
+
+ // Similar to CopyFrom, but with move semantincs. The subshape of this literal
+ // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
+ // (layouts and shapes must match), but need not be arrays. The memory
+ // allocated in this literal for the subshape at dest_shape_index is
+ // deallocated, and the respective buffers are replaced with those in
+ // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
+ virtual Status MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index = {});
+
+ // Returns a vector containing the tuple elements of this Literal as separate
+ // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
+ // elements are moved into the new Literals; no data is copied. Upon return
+ // this Literal is set to a nil shape (empty tuple)
+ std::vector<Literal> DecomposeTuple();
+
+ private:
// Deallocate the buffers held by this literal.
void DeallocateBuffers();
- friend class LiteralBase;
+ // Recursively sets the subshapes and buffers of all subpieces rooted at
+ // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
+ // the shape.
+ void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
+};
+
+// The underlying buffer is not owned by this class and is always owned by
+// others. The shape is not owned by this class and not mutable.
+class MutableBorrowingLiteral : public MutableLiteralBase {
+ public:
+ virtual ~MutableBorrowingLiteral();
+
+ MutableBorrowingLiteral() : MutableLiteralBase() {}
+
+ MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
+ MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
+
+ // Implicit conversion constructors.
+ MutableBorrowingLiteral(const MutableLiteralBase& literal);
+ MutableBorrowingLiteral(MutableLiteralBase* literal);
+ MutableBorrowingLiteral(MutableBorrowingLiteral literal,
+ const ShapeIndex& view_root);
+ MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
+
+ private:
+ // Recursively copies the subtree from the `src_piece` at the given child
+ // index to the `dest_piece`. For buffers only the pointers are copied, but
+ // not the content.
+ void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
+ Piece* dest_piece);
};
-std::ostream& operator<<(std::ostream& out, const Literal& literal);
// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
// literal buffers always owned by others.
@@ -831,9 +864,9 @@ class BorrowingLiteral : public LiteralBase {
const Piece& root_piece() const override { return root_piece_; };
Piece root_piece_;
- // Shape of this literal. Stored as unique_ptr so such that the (default)
- // move construction of this class would be trivially correct: the pointer to
- // Shape root_piece_ stores will still point to the correct address.
+ // Shape of this literal. Stored as unique_ptr such that the (default) move
+ // construction of this class would be trivially correct: the pointer to Shape
+ // root_piece_ stores will still point to the correct address.
std::unique_ptr<Shape> shape_;
};
@@ -886,7 +919,7 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
}
template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
+tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data(
const ShapeIndex& shape_index) {
return piece(shape_index).data<NativeT>();
}
@@ -904,14 +937,15 @@ inline NativeT LiteralBase::Get(
}
template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value) {
+inline void MutableLiteralBase::Set(
+ tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value) {
return piece(shape_index).Set<NativeT>(multi_index, value);
}
template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value) {
+inline void MutableLiteralBase::Set(
+ tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) {
return root_piece().Set<NativeT>(multi_index, value);
}
@@ -929,7 +963,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
}
template <typename NativeT>
-void Literal::AppendSparseElement(
+void MutableLiteralBase::AppendSparseElement(
tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
const ShapeIndex& shape_index) {
Piece& p = piece(shape_index);
@@ -959,7 +993,8 @@ void LiteralBase::EachCell(
}
template <typename NativeT>
-inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
+inline void MutableLiteralBase::PopulateR1(
+ tensorflow::gtl::ArraySlice<NativeT> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
@@ -971,7 +1006,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
}
template <typename NativeT>
-void Literal::PopulateR2(
+void MutableLiteralBase::PopulateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 2);
@@ -996,7 +1031,7 @@ void Literal::PopulateR2(
}
template <typename NativeT>
-void Literal::PopulateFromArray(const Array<NativeT>& values) {
+void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(shape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>());
@@ -1009,24 +1044,24 @@ void Literal::PopulateFromArray(const Array<NativeT>& values) {
}
template <typename NativeT>
-void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
+void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
-void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
+void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
-void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
+void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
-void Literal::PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort) {
+void MutableLiteralBase::PopulateSparse(
+ SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values,
+ bool sort) {
CHECK(LayoutUtil::IsSparseArray(shape()));
int rank = ShapeUtil::Rank(shape());
CHECK_EQ(indices.rank(), rank);
@@ -1049,7 +1084,8 @@ void Literal::PopulateSparse(SparseIndexArray indices,
}
template <typename NativeT, typename FnType>
-Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
+Status MutableLiteralBase::PopulateInternal(const FnType& generator,
+ bool parallel) {
const Shape& this_shape = shape();
const int64 rank = ShapeUtil::Rank(this_shape);
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
@@ -1092,17 +1128,17 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
return Status::OK();
}
template <typename NativeT, typename FnType>
-Status Literal::Populate(const FnType& generator) {
+Status MutableLiteralBase::Populate(const FnType& generator) {
return PopulateInternal<NativeT>(generator, /*parallel=*/false);
}
template <typename NativeT, typename FnType>
-Status Literal::PopulateParallel(const FnType& generator) {
+Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
return PopulateInternal<NativeT>(generator, /*parallel=*/true);
}
template <typename NativeT>
-void Literal::PopulateWithValue(NativeT value) {
+void MutableLiteralBase::PopulateWithValue(NativeT value) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(shape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>());