aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/literal.h')
-rw-r--r--tensorflow/compiler/xla/literal.h58
1 files changed, 22 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index b928cb6374..1e0a2ad0dd 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -217,31 +217,20 @@ class LiteralBase {
// Converts this literal to the given shape. Returns an error is the
// conversion is not possible.
- //
- // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
- // instead of truncation; otherwise, truncation is used.
- //
- // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
- // the default behavior.
- StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+ StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
// Converts this literal to another primitive type using a bitcast
// conversion. The to and from primitive types must have the same bit
// width. Returns an error if the conversion is not possible. This literal
// must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> BitcastConvert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
// Converts this literal to another primitive type. Returns an error if the
// conversion is not possible. This literal must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> Convert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
- // Clones the underlying buffers into a new Literal, or new
- // std::unique_ptr<Literal>.
+ // Clones the underlying buffers into a new Literal.
Literal Clone() const;
- std::unique_ptr<Literal> CloneToUnique() const;
// TODO(b/67651157): The methods below which perform computation on Literals
// (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
@@ -259,24 +248,23 @@ class LiteralBase {
// Note: this is useful when the client wants to ensure that a value placed in
// the XLA allocation tracker has a particular layout; for efficiency
// purposes or avoiding unimplemented operation/layout combinations.
- std::unique_ptr<Literal> Relayout(const Layout& new_layout,
- const ShapeIndex& shape_index = {}) const;
+ Literal Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index = {}) const;
// An overload of Relayout which changes the layout of the entire shape rather
// than being limited to a single array within the shape.
- std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+ Literal Relayout(const Shape& shape_with_layout) const;
// Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; The
// implementation currently only supports monotonic dim0-major layouts.
// This literal must be an array.
- StatusOr<std::unique_ptr<Literal>> Reshape(
- absl::Span<const int64> dimensions) const;
+ StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
// Creates a new literal by broadcasting this literal with `dimensions` to
// yield a literal of shape `result_shape`.
- StatusOr<std::unique_ptr<Literal>> Broadcast(
- const Shape& result_shape, absl::Span<const int64> dimensions) const;
+ StatusOr<Literal> Broadcast(const Shape& result_shape,
+ absl::Span<const int64> dimensions) const;
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
@@ -285,7 +273,7 @@ class LiteralBase {
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
// This literal must be an array.
- std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
+ Literal Transpose(absl::Span<const int64> permutation) const;
// Creates a sub-array from this literal by extracting the indices
// [start_index, limit_index) of each dimension. The result literal has the
@@ -293,15 +281,15 @@ class LiteralBase {
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
// This literal must be an array.
- std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
- absl::Span<const int64> limit_indices) const;
+ Literal Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const;
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
// literal replicated four times.
// This literal must be an array.
template <typename NativeT>
- std::unique_ptr<Literal> Replicate(int64 times) const;
+ Literal Replicate(int64 times) const;
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
@@ -312,7 +300,7 @@ class LiteralBase {
// initialization, then reinitialization. Conside if a call to
// absl::make_unique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
- static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+ static Literal CreateFromShape(const Shape& shape);
protected:
// A data structure representing a subshape at a particular ShapeIndex within
@@ -539,8 +527,8 @@ class LiteralBase {
private:
template <typename NativeT>
- std::unique_ptr<Literal> SliceInternal(
- const Shape& result_shape, absl::Span<const int64> start_indices) const;
+ Literal SliceInternal(const Shape& result_shape,
+ absl::Span<const int64> start_indices) const;
};
// Abstract base class representing a mutable literal in XLA.
@@ -687,8 +675,7 @@ class MutableLiteralBase : public LiteralBase {
static Literal MoveIntoTuple(absl::Span<Literal> elements);
// Serialize from a proto.
- static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
- const LiteralProto& proto);
+ static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
protected:
// Returns the piece at the given ShapeIndex.
@@ -1137,15 +1124,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) {
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
+Literal LiteralBase::Replicate(int64 times) const {
DimensionVector bounds = {times};
bounds.reserve(shape().dimensions_size() + 1);
for (int64 bound : shape().dimensions()) {
bounds.push_back(bound);
}
- auto literal = absl::make_unique<Literal>(
- ShapeUtil::MakeShape(shape().element_type(), bounds));
- int64 elements = ShapeUtil::ElementsIn(literal->shape());
+ Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
+ int64 elements = ShapeUtil::ElementsIn(literal.shape());
if (elements == 0) {
return literal;
}
@@ -1157,7 +1143,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
bool done = false;
while (!done) {
const auto element = Get<NativeT>(input_indices);
- literal->Set<NativeT>(output_indices, element);
+ literal.Set<NativeT>(output_indices, element);
done = true;
for (int n = 0; n < output_indices.size(); ++n) {