diff options
Diffstat (limited to 'tensorflow/compiler/xla/literal.h')
-rw-r--r-- | tensorflow/compiler/xla/literal.h | 58 |
1 files changed, 22 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index b928cb6374..1e0a2ad0dd 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -217,31 +217,20 @@ class LiteralBase { // Converts this literal to the given shape. Returns an error is the // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr<std::unique_ptr<Literal>> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const; // Converts this literal to another primitive type using a bitcast // conversion. The to and from primitive types must have the same bit // width. Returns an error if the conversion is not possible. This literal // must be array-shaped. - StatusOr<std::unique_ptr<Literal>> BitcastConvert( - PrimitiveType primitive_dest_type) const; + StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const; // Converts this literal to another primitive type. Returns an error if the // conversion is not possible. This literal must be array-shaped. - StatusOr<std::unique_ptr<Literal>> Convert( - PrimitiveType primitive_dest_type) const; + StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const; - // Clones the underlying buffers into a new Literal, or new - // std::unique_ptr<Literal>. + // Clones the underlying buffers into a new Literal. Literal Clone() const; - std::unique_ptr<Literal> CloneToUnique() const; // TODO(b/67651157): The methods below which perform computation on Literals // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with @@ -259,24 +248,23 @@ class LiteralBase { // Note: this is useful when the client wants to ensure that a value placed in // the XLA allocation tracker has a particular layout; for efficiency // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr<Literal> Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; + Literal Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; // An overload of Relayout which changes the layout of the entire shape rather // than being limited to a single array within the shape. - std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const; + Literal Relayout(const Shape& shape_with_layout) const; // Creates a new literal by reshaping this literal to have the given // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. // This literal must be an array. - StatusOr<std::unique_ptr<Literal>> Reshape( - absl::Span<const int64> dimensions) const; + StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const; // Creates a new literal by broadcasting this literal with `dimensions` to // yield a literal of shape `result_shape`. - StatusOr<std::unique_ptr<Literal>> Broadcast( - const Shape& result_shape, absl::Span<const int64> dimensions) const; + StatusOr<Literal> Broadcast(const Shape& result_shape, + absl::Span<const int64> dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -285,7 +273,7 @@ class LiteralBase { // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. // This literal must be an array. - std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const; + Literal Transpose(absl::Span<const int64> permutation) const; // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the @@ -293,15 +281,15 @@ class LiteralBase { // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. // This literal must be an array. - std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices, - absl::Span<const int64> limit_indices) const; + Literal Slice(absl::Span<const int64> start_indices, + absl::Span<const int64> limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. // This literal must be an array. template <typename NativeT> - std::unique_ptr<Literal> Replicate(int64 times) const; + Literal Replicate(int64 times) const; // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive @@ -312,7 +300,7 @@ class LiteralBase { // initialization, then reinitialization. Conside if a call to // absl::make_unique<Literal>(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. - static std::unique_ptr<Literal> CreateFromShape(const Shape& shape); + static Literal CreateFromShape(const Shape& shape); protected: // A data structure representing a subshape at a particular ShapeIndex within @@ -539,8 +527,8 @@ class LiteralBase { private: template <typename NativeT> - std::unique_ptr<Literal> SliceInternal( - const Shape& result_shape, absl::Span<const int64> start_indices) const; + Literal SliceInternal(const Shape& result_shape, + absl::Span<const int64> start_indices) const; }; // Abstract base class representing a mutable literal in XLA. @@ -687,8 +675,7 @@ class MutableLiteralBase : public LiteralBase { static Literal MoveIntoTuple(absl::Span<Literal> elements); // Serialize from a proto. - static StatusOr<std::unique_ptr<Literal>> CreateFromProto( - const LiteralProto& proto); + static StatusOr<Literal> CreateFromProto(const LiteralProto& proto); protected: // Returns the piece at the given ShapeIndex. @@ -1137,15 +1124,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) { } template <typename NativeT> -std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const { +Literal LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = absl::make_unique<Literal>( - ShapeUtil::MakeShape(shape().element_type(), bounds)); - int64 elements = ShapeUtil::ElementsIn(literal->shape()); + Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds)); + int64 elements = ShapeUtil::ElementsIn(literal.shape()); if (elements == 0) { return literal; } @@ -1157,7 +1143,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const { bool done = false; while (!done) { const auto element = Get<NativeT>(input_indices); - literal->Set<NativeT>(output_indices, element); + literal.Set<NativeT>(output_indices, element); done = true; for (int n = 0; n < output_indices.size(); ++n) { |