aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.h')
-rw-r--r--tensorflow/compiler/xla/literal_util.h228
1 files changed, 104 insertions, 124 deletions
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 2d6084a67a..2b181621ed 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -69,36 +69,34 @@ class LiteralUtil {
// The variants not ending with WithLayout use the default XLA layout for the
// literal's linear representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR0(NativeT value);
+ static Literal CreateR0(NativeT value);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
- static std::unique_ptr<Literal> CreateR1(
- const tensorflow::core::Bitmap& values);
+ static Literal CreateR1(absl::Span<const NativeT> values);
+ static Literal CreateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2(
+ static Literal CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2WithLayout(
+ static Literal CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3(
- std::initializer_list<
- std::initializer_list<std::initializer_list<NativeT>>>
- values);
+ static Literal CreateR3(std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3WithLayout(
+ static Literal CreateR3WithLayout(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4(
+ static Literal CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4WithLayout(
+ static Literal CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -139,9 +137,10 @@ class LiteralUtil {
// [9, 10, 11]: 4.0
//
template <typename NativeT>
- static std::unique_ptr<Literal> CreateSparse(
- absl::Span<const int64> dimensions, SparseIndexArray indices,
- absl::Span<const NativeT> values, bool sort = true);
+ static Literal CreateSparse(absl::Span<const int64> dimensions,
+ SparseIndexArray indices,
+ absl::Span<const NativeT> values,
+ bool sort = true);
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
@@ -155,130 +154,120 @@ class LiteralUtil {
static Literal MaxValue(PrimitiveType primitive_type);
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
+ static Literal CreateFullWithDescendingLayout(
absl::Span<const int64> dimensions, NativeT value);
// Creates a new literal from an Array type. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+ static Literal CreateFromArray(const Array<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArrayWithLayout(
- const Array<NativeT>& values, const Layout& layout);
+ static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2D(
- const Array2D<NativeT>& values);
+ static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout);
+ static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3D(
- const Array3D<NativeT>& values);
+ static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout);
+ static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4D(
- const Array4D<NativeT>& values);
+ static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout);
+ static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout);
// Creates a new vector of U8s literal value from a string.
- static std::unique_ptr<Literal> CreateR1U8(absl::string_view value);
+ static Literal CreateR1U8(absl::string_view value);
// Creates a linspace-populated literal with the given number of rows and
// columns.
- static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
- int64 rows, int64 cols);
+ static Literal CreateR2F32Linspace(float from, float to, int64 rows,
+ int64 cols);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z dimension given by "projection".
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3Projected(
+ static Literal CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z and p dimensions given.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4Projected(
+ static Literal CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z);
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
- static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
+ static Literal MakeIdentityR2(int64 size);
// Returns a tuple literal composed of given literals. Data is copied from the
// given elements into the returned literal.
- static std::unique_ptr<Literal> MakeTuple(
- absl::Span<const Literal* const> elements);
+ static Literal MakeTuple(absl::Span<const Literal* const> elements);
- static std::unique_ptr<Literal> MakeTupleFromSlices(
- absl::Span<const LiteralSlice> elements);
+ static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
// As above, but intended to be invoked with move semantics; i.e.
//
- // std::vector<std::unique_ptr<Literal>> elements = ...;
+ // std::vector<Literal> elements = ...;
// auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
//
// This would have been declared as an overload, but there is ambiguity
// in invocation between the above signature and this one.
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::vector<std::unique_ptr<Literal>> elements);
+ static Literal MakeTupleOwned(std::vector<Literal> elements);
- // This overload lets you pass a braced list of unique_ptr<Literal>s to
+ // This overload lets you pass a braced list of Literals to
// MakeTupleOwned:
//
// LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
//
- // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
+ // Simply relying on the MakeTupleOwned(std::vector<Literal>)
// overload doesn't work because std::initializer_list's elements are always
// const.
//
- // The arguments to this function must all be unique_ptr<Literal>.
+ // The arguments to this function must all be Literal.
template <typename... Ts>
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::unique_ptr<Ts>... elements) {
- std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
- std::move(elements)...};
- std::vector<std::unique_ptr<Literal>> v;
+ static Literal MakeTupleOwned(Ts... elements) {
+ std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
+ std::vector<Literal> v;
v.insert(v.begin(), std::make_move_iterator(arr.begin()),
std::make_move_iterator(arr.end()));
return MakeTupleOwned(std::move(v));
}
// Create a constant token literal. Token types have no value.
- static std::unique_ptr<Literal> CreateToken();
+ static Literal CreateToken();
// Creates a new Literal object with its values havings the primitive_type
// type, and with dimensions defined by the dimensions parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
- static std::unique_ptr<Literal> CreateFromDimensions(
- PrimitiveType primitive_type, absl::Span<const int64> dimensions);
+ static Literal CreateFromDimensions(PrimitiveType primitive_type,
+ absl::Span<const int64> dimensions);
// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertBF16ToF32(
- const LiteralSlice& bf16_literal);
+ static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
// If the given literal's data type is float, converts it to a bfloat16
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertF32ToBF16(
- const LiteralSlice& f32_literal);
+ static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
// Creates a literal with a new shape with the given new dimensions using the
// data in the given input literal. For reshaping purposes the (flat) data
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
- static std::unique_ptr<Literal> ReshapeSlice(
- absl::Span<const int64> new_dimensions,
- absl::Span<const int64> minor_to_major, const LiteralSlice& literal);
+ static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
+ absl::Span<const int64> minor_to_major,
+ const LiteralSlice& literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
@@ -286,7 +275,7 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ static StatusOr<Literal> CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator);
@@ -297,8 +286,8 @@ class LiteralUtil {
template <
PrimitiveType type, typename E,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, E* engine, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
+ T mean, T stddev);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
@@ -307,8 +296,8 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
+ T stddev);
//
// End of factory methods.
@@ -322,44 +311,43 @@ class LiteralUtil {
std::ostream& operator<<(std::ostream& out, const Literal& literal);
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape(
+/* static */ Literal LiteralUtil::CreateR0(NativeT value) {
+ Literal literal(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
- literal->Set({}, value);
+ literal.Set({}, value);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
- absl::Span<const NativeT> values) {
- auto literal = absl::make_unique<Literal>(
+/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
+ Literal literal(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size())}));
- literal->PopulateR1(values);
+ literal.PopulateR1(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
+/* static */ Literal LiteralUtil::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size()),
static_cast<int64>(values.begin()->size())},
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateR2(values);
+ literal.PopulateR2(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
+/* static */ Literal LiteralUtil::CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
+/* static */ Literal LiteralUtil::CreateR3WithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout) {
@@ -384,14 +372,14 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
+/* static */ Literal LiteralUtil::CreateR3(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values) {
return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
+/* static */ Literal LiteralUtil::CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -422,23 +410,22 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
+/* static */ Literal LiteralUtil::CreateSparse(
absl::Span<const int64> dimensions, SparseIndexArray indices,
absl::Span<const NativeT> values, bool sort) {
int64 num_elements = values.size();
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
CHECK_EQ(rank, indices.rank());
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
- indices.max_indices()));
- literal->PopulateSparse(indices, values, sort);
+ Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
+ indices.max_indices()));
+ literal.PopulateSparse(indices, values, sort);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
+/* static */ Literal LiteralUtil::CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values) {
@@ -446,50 +433,48 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
+/* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateFromArray(values);
+ literal.PopulateFromArray(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
+/* static */ Literal LiteralUtil::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
+/* static */ Literal LiteralUtil::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
+ const Array3D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
+/* static */ Literal LiteralUtil::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
+/* static */ Literal LiteralUtil::CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection) {
int64 dim0_size = projection;
@@ -514,7 +499,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
+/* static */ Literal LiteralUtil::CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
int64 dim0_size = projection_p;
@@ -542,21 +527,20 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
+/* static */ Literal LiteralUtil::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
+ const Array4D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
+/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
Array2D<NativeT> array(size, size, 0);
for (int64 i = 0; i < size; ++i) {
array(i, i) = 1;
@@ -565,33 +549,29 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
- NativeT value) {
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
- literal->PopulateWithValue(value);
+/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
+ absl::Span<const int64> dimensions, NativeT value) {
+ Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
+ literal.PopulateWithValue(value);
return literal;
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
- auto literal = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
+ Literal literal(shape);
+ TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
[&](absl::Span<const int64> indexes) { return generator(indexes); }));
return std::move(literal);
}
template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
- T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
@@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}