diff options
author | Kay Zhu <kayzhu@google.com> | 2018-08-08 17:16:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 17:24:39 -0700 |
commit | 963ef37203c76e0338ede21b469020e425fb9208 (patch) | |
tree | 3f9773304ba3f3dccdd8189e4962c57c3d283985 /tensorflow/compiler/xla/literal.h | |
parent | 3325275eff98ffddb52a16db932481983a9de9a8 (diff) |
[TF:XLA] Introduce MutableBorrowingLiteral to enable interacting with a (tensor) buffer not owned by XLA/Literal class directly, without having to memcpy the Literal to a (Host)Tensor.
PiperOrigin-RevId: 207972410
Diffstat (limited to 'tensorflow/compiler/xla/literal.h')
-rw-r--r-- | tensorflow/compiler/xla/literal.h | 186 |
1 files changed, 111 insertions, 75 deletions
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index dd67dfa8d4..92c0f903cb 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -310,9 +310,10 @@ class LiteralBase { // type of literal itself (0 for numeric types, and false for predicates). // // Note: It's an antipattern to use this method then immediately call - // Literal::Populate on the result (since that results in zero initialization, - // then reinitialization. Conside if a call to MakeUnique<Literal>(shape), - // followed by the call to Literal::Populate can be used instead. + // MutableLiteralBase::Populate on the result (since that results in zero + // initialization, then reinitialization. Conside if a call to + // MakeUnique<Literal>(shape), followed by the call to + // MutableLiteralBase::Populate can be used instead. static std::unique_ptr<Literal> CreateFromShape(const Shape& shape); protected: @@ -534,7 +535,7 @@ class LiteralBase { virtual const Piece& root_piece() const = 0; // LiteralSlice and Literal must access Pieces of other Literals. - friend class Literal; + friend class MutableLiteralBase; friend class LiteralSlice; friend class BorrowingLiteral; @@ -545,33 +546,10 @@ class LiteralBase { tensorflow::gtl::ArraySlice<int64> start_indices) const; }; -// Class representing literal values in XLA. -// -// The underlying buffer and shape is always owned by this class. -class Literal : public LiteralBase { +// Abstract base class representing a mutable literal in XLA. +class MutableLiteralBase : public LiteralBase { public: - Literal() : Literal(ShapeUtil::MakeNil()) {} - - // Create a literal of the given shape. The literal is allocated sufficient - // memory to hold the shape. Memory is uninitialized. - explicit Literal(const Shape& shape); - virtual ~Literal(); - - // Literals are moveable, but not copyable. To copy a literal use - // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies - // of literals which can be expensive. - Literal(const Literal& other) = delete; - Literal& operator=(const Literal& other) = delete; - Literal(Literal&& other); - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); - Literal& operator=(Literal&& other); - - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return shape_.get(); } + virtual ~MutableLiteralBase() = 0; // Returns a MutableArraySlice view of the array for this literal for the // given NativeT (e.g., float). CHECKs if the subshape of the literal at the @@ -587,6 +565,10 @@ class Literal : public LiteralBase { // is not a sparse array. SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } + // Returns a pointer to the underlying buffer holding the array at the given // shape index. CHECKs if the subshape of the literal at the given ShapeIndex // is not array. @@ -613,21 +595,6 @@ class Literal : public LiteralBase { const ShapeIndex& dest_shape_index = {}, const ShapeIndex& src_shape_index = {}); - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector<Literal> DecomposeTuple(); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - // Copies the values from src_literal, starting at src_base shape indexes, // to this literal, starting at dest_base, where the copy size in each // dimension is specified by copy_size. @@ -730,12 +697,7 @@ class Literal : public LiteralBase { static StatusOr<std::unique_ptr<Literal>> CreateFromProto( const LiteralProto& proto); - private: - // Recursively sets the subshapes and buffers of all subpieces rooted at - // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in - // the shape. - void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - + protected: // Returns the piece at the given ShapeIndex. Piece& piece(const ShapeIndex& shape_index) { return const_cast<Piece&>(LiteralBase::piece(shape_index)); @@ -783,12 +745,83 @@ class Literal : public LiteralBase { template <typename NativeT, typename FnType> Status PopulateInternal(const FnType& generator, bool parallel); + friend class LiteralBase; + friend class MutableBorrowingLiteral; +}; +std::ostream& operator<<(std::ostream& out, const Literal& literal); + +// The underlying buffer and shape is always owned by this class. +class Literal : public MutableLiteralBase { + public: + Literal() : Literal(ShapeUtil::MakeNil()) {} + + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); + + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); + Literal& operator=(Literal&& other); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + virtual Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector<Literal> DecomposeTuple(); + + private: // Deallocate the buffers held by this literal. void DeallocateBuffers(); - friend class LiteralBase; + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); +}; + +// The underlying buffer is not owned by this class and is always owned by +// others. The shape is not owned by this class and not mutable. +class MutableBorrowingLiteral : public MutableLiteralBase { + public: + virtual ~MutableBorrowingLiteral(); + + MutableBorrowingLiteral() : MutableLiteralBase() {} + + MutableBorrowingLiteral(const MutableBorrowingLiteral& literal); + MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal); + + // Implicit conversion constructors. + MutableBorrowingLiteral(const MutableLiteralBase& literal); + MutableBorrowingLiteral(MutableLiteralBase* literal); + MutableBorrowingLiteral(MutableBorrowingLiteral literal, + const ShapeIndex& view_root); + MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + + private: + // Recursively copies the subtree from the `src_piece` at the given child + // index to the `dest_piece`. For buffers only the pointers are copied, but + // not the content. + void CopyPieceSubtree(const Shape& shape, Piece* src_piece, + Piece* dest_piece); }; -std::ostream& operator<<(std::ostream& out, const Literal& literal); // A read-only view of a Literal. A LiteralSlice contains pointers to shape and // literal buffers always owned by others. @@ -831,9 +864,9 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. Stored as unique_ptr so such that the (default) - // move construction of this class would be trivially correct: the pointer to - // Shape root_piece_ stores will still point to the correct address. + // Shape of this literal. Stored as unique_ptr such that the (default) move + // construction of this class would be trivially correct: the pointer to Shape + // root_piece_ stores will still point to the correct address. std::unique_ptr<Shape> shape_; }; @@ -886,7 +919,7 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data( } template <typename NativeT> -tensorflow::gtl::MutableArraySlice<NativeT> Literal::data( +tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data( const ShapeIndex& shape_index) { return piece(shape_index).data<NativeT>(); } @@ -904,14 +937,15 @@ inline NativeT LiteralBase::Get( } template <typename NativeT> -inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, - const ShapeIndex& shape_index, NativeT value) { +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice<int64> multi_index, + const ShapeIndex& shape_index, NativeT value) { return piece(shape_index).Set<NativeT>(multi_index, value); } template <typename NativeT> -inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, - NativeT value) { +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) { return root_piece().Set<NativeT>(multi_index, value); } @@ -929,7 +963,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, } template <typename NativeT> -void Literal::AppendSparseElement( +void MutableLiteralBase::AppendSparseElement( tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value, const ShapeIndex& shape_index) { Piece& p = piece(shape_index); @@ -959,7 +993,8 @@ void LiteralBase::EachCell( } template <typename NativeT> -inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) { +inline void MutableLiteralBase::PopulateR1( + tensorflow::gtl::ArraySlice<NativeT> values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); @@ -971,7 +1006,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) { } template <typename NativeT> -void Literal::PopulateR2( +void MutableLiteralBase::PopulateR2( std::initializer_list<std::initializer_list<NativeT>> values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 2); @@ -996,7 +1031,7 @@ void Literal::PopulateR2( } template <typename NativeT> -void Literal::PopulateFromArray(const Array<NativeT>& values) { +void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType<NativeT>()); @@ -1009,24 +1044,24 @@ void Literal::PopulateFromArray(const Array<NativeT>& values) { } template <typename NativeT> -void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) { +void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) { PopulateFromArray(values); } template <typename NativeT> -void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) { +void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) { PopulateFromArray(values); } template <typename NativeT> -void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) { +void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) { PopulateFromArray(values); } template <typename NativeT> -void Literal::PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice<NativeT> values, - bool sort) { +void MutableLiteralBase::PopulateSparse( + SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values, + bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); int rank = ShapeUtil::Rank(shape()); CHECK_EQ(indices.rank(), rank); @@ -1049,7 +1084,8 @@ void Literal::PopulateSparse(SparseIndexArray indices, } template <typename NativeT, typename FnType> -Status Literal::PopulateInternal(const FnType& generator, bool parallel) { +Status MutableLiteralBase::PopulateInternal(const FnType& generator, + bool parallel) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); @@ -1092,17 +1128,17 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) { return Status::OK(); } template <typename NativeT, typename FnType> -Status Literal::Populate(const FnType& generator) { +Status MutableLiteralBase::Populate(const FnType& generator) { return PopulateInternal<NativeT>(generator, /*parallel=*/false); } template <typename NativeT, typename FnType> -Status Literal::PopulateParallel(const FnType& generator) { +Status MutableLiteralBase::PopulateParallel(const FnType& generator) { return PopulateInternal<NativeT>(generator, /*parallel=*/true); } template <typename NativeT> -void Literal::PopulateWithValue(NativeT value) { +void MutableLiteralBase::PopulateWithValue(NativeT value) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType<NativeT>()); |