diff options
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.h')
-rw-r--r-- | tensorflow/compiler/xla/literal_util.h | 228 |
1 files changed, 104 insertions, 124 deletions
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 2d6084a67a..2b181621ed 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -69,36 +69,34 @@ class LiteralUtil { // The variants not ending with WithLayout use the default XLA layout for the // literal's linear representation in memory. template <typename NativeT> - static std::unique_ptr<Literal> CreateR0(NativeT value); + static Literal CreateR0(NativeT value); template <typename NativeT> - static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values); - static std::unique_ptr<Literal> CreateR1( - const tensorflow::core::Bitmap& values); + static Literal CreateR1(absl::Span<const NativeT> values); + static Literal CreateR1(const tensorflow::core::Bitmap& values); template <typename NativeT> - static std::unique_ptr<Literal> CreateR2( + static Literal CreateR2( std::initializer_list<std::initializer_list<NativeT>> values); template <typename NativeT> - static std::unique_ptr<Literal> CreateR2WithLayout( + static Literal CreateR2WithLayout( std::initializer_list<std::initializer_list<NativeT>> values, const Layout& layout); template <typename NativeT> - static std::unique_ptr<Literal> CreateR3( - std::initializer_list< - std::initializer_list<std::initializer_list<NativeT>>> - values); + static Literal CreateR3(std::initializer_list< + std::initializer_list<std::initializer_list<NativeT>>> + values); template <typename NativeT> - static std::unique_ptr<Literal> CreateR3WithLayout( + static Literal CreateR3WithLayout( std::initializer_list< std::initializer_list<std::initializer_list<NativeT>>> values, const Layout& layout); template <typename NativeT> - static std::unique_ptr<Literal> CreateR4( + static Literal CreateR4( std::initializer_list<std::initializer_list< std::initializer_list<std::initializer_list<NativeT>>>> values); template <typename NativeT> - static std::unique_ptr<Literal> CreateR4WithLayout( + static Literal CreateR4WithLayout( std::initializer_list<std::initializer_list< std::initializer_list<std::initializer_list<NativeT>>>> values, @@ -139,9 +137,10 @@ class LiteralUtil { // [9, 10, 11]: 4.0 // template <typename NativeT> - static std::unique_ptr<Literal> CreateSparse( - absl::Span<const int64> dimensions, SparseIndexArray indices, - absl::Span<const NativeT> values, bool sort = true); + static Literal CreateSparse(absl::Span<const int64> dimensions, + SparseIndexArray indices, + absl::Span<const NativeT> values, + bool sort = true); // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -155,130 +154,120 @@ class LiteralUtil { static Literal MaxValue(PrimitiveType primitive_type); // Creates a literal of the given shape where each element is `value`. template <typename NativeT> - static std::unique_ptr<Literal> CreateFullWithDescendingLayout( + static Literal CreateFullWithDescendingLayout( absl::Span<const int64> dimensions, NativeT value); // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear // representation in memory. template <typename NativeT> - static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values); + static Literal CreateFromArray(const Array<NativeT>& values); template <typename NativeT> - static std::unique_ptr<Literal> CreateFromArrayWithLayout( - const Array<NativeT>& values, const Layout& layout); + static Literal CreateFromArrayWithLayout(const Array<NativeT>& values, + const Layout& layout); template <typename NativeT> - static std::unique_ptr<Literal> CreateR2FromArray2D( - const Array2D<NativeT>& values); + static Literal CreateR2FromArray2D(const Array2D<NativeT>& values); template <typename NativeT> - static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout( - const Array2D<NativeT>& values, const Layout& layout); + static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values, + const Layout& layout); template <typename NativeT> - static std::unique_ptr<Literal> CreateR3FromArray3D( - const Array3D<NativeT>& values); + static Literal CreateR3FromArray3D(const Array3D<NativeT>& values); template <typename NativeT> - static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout( - const Array3D<NativeT>& values, const Layout& layout); + static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values, + const Layout& layout); template <typename NativeT> - static std::unique_ptr<Literal> CreateR4FromArray4D( - const Array4D<NativeT>& values); + static Literal CreateR4FromArray4D(const Array4D<NativeT>& values); template <typename NativeT> - static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout( - const Array4D<NativeT>& values, const Layout& layout); + static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values, + const Layout& layout); // Creates a new vector of U8s literal value from a string. - static std::unique_ptr<Literal> CreateR1U8(absl::string_view value); + static Literal CreateR1U8(absl::string_view value); // Creates a linspace-populated literal with the given number of rows and // columns. - static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to, - int64 rows, int64 cols); + static Literal CreateR2F32Linspace(float from, float to, int64 rows, + int64 cols); // Creates a literal that projects the (x, y) dimensions given in values into // the z dimension given by "projection". template <typename NativeT> - static std::unique_ptr<Literal> CreateR3Projected( + static Literal CreateR3Projected( std::initializer_list<std::initializer_list<NativeT>> values, int64 projection); // Creates a literal that projects the (x, y) dimensions given in values into // the z and p dimensions given. template <typename NativeT> - static std::unique_ptr<Literal> CreateR4Projected( + static Literal CreateR4Projected( std::initializer_list<std::initializer_list<NativeT>> values, int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template <typename NativeT> - static std::unique_ptr<Literal> MakeIdentityR2(int64 size); + static Literal MakeIdentityR2(int64 size); // Returns a tuple literal composed of given literals. Data is copied from the // given elements into the returned literal. - static std::unique_ptr<Literal> MakeTuple( - absl::Span<const Literal* const> elements); + static Literal MakeTuple(absl::Span<const Literal* const> elements); - static std::unique_ptr<Literal> MakeTupleFromSlices( - absl::Span<const LiteralSlice> elements); + static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements); // As above, but intended to be invoked with move semantics; i.e. // - // std::vector<std::unique_ptr<Literal>> elements = ...; + // std::vector<Literal> elements = ...; // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); // // This would have been declared as an overload, but there is ambiguity // in invocation between the above signature and this one. - static std::unique_ptr<Literal> MakeTupleOwned( - std::vector<std::unique_ptr<Literal>> elements); + static Literal MakeTupleOwned(std::vector<Literal> elements); - // This overload lets you pass a braced list of unique_ptr<Literal>s to + // This overload lets you pass a braced list of Literals to // MakeTupleOwned: // // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). // - // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>) + // Simply relying on the MakeTupleOwned(std::vector<Literal>) // overload doesn't work because std::initializer_list's elements are always // const. // - // The arguments to this function must all be unique_ptr<Literal>. + // The arguments to this function must all be Literal. template <typename... Ts> - static std::unique_ptr<Literal> MakeTupleOwned( - std::unique_ptr<Ts>... elements) { - std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{ - std::move(elements)...}; - std::vector<std::unique_ptr<Literal>> v; + static Literal MakeTupleOwned(Ts... elements) { + std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...}; + std::vector<Literal> v; v.insert(v.begin(), std::make_move_iterator(arr.begin()), std::make_move_iterator(arr.end())); return MakeTupleOwned(std::move(v)); } // Create a constant token literal. Token types have no value. - static std::unique_ptr<Literal> CreateToken(); + static Literal CreateToken(); // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr<Literal> CreateFromDimensions( - PrimitiveType primitive_type, absl::Span<const int64> dimensions); + static Literal CreateFromDimensions(PrimitiveType primitive_type, + absl::Span<const int64> dimensions); // If the given literal's data type is bfloat16, converts it to a float // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr<Literal> ConvertBF16ToF32( - const LiteralSlice& bf16_literal); + static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); // If the given literal's data type is float, converts it to a bfloat16 // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr<Literal> ConvertF32ToBF16( - const LiteralSlice& f32_literal); + static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); // Creates a literal with a new shape with the given new dimensions using the // data in the given input literal. For reshaping purposes the (flat) data // buffer of the input literal is assumed to have the given minor_to_major // layout order. - static std::unique_ptr<Literal> ReshapeSlice( - absl::Span<const int64> new_dimensions, - absl::Span<const int64> minor_to_major, const LiteralSlice& literal); + static Literal ReshapeSlice(absl::Span<const int64> new_dimensions, + absl::Span<const int64> minor_to_major, + const LiteralSlice& literal); // Creates a literal with the supplied shape, and uses the provided value // generator to populate the literal's values. @@ -286,7 +275,7 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> - static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral( + static StatusOr<Literal> CreateRandomLiteral( const Shape& shape, const std::function<T(absl::Span<const int64>)>& generator); @@ -297,8 +286,8 @@ class LiteralUtil { template < PrimitiveType type, typename E, typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> - static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); + static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine, + T mean, T stddev); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -307,8 +296,8 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> - static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); + static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean, + T stddev); // // End of factory methods. @@ -322,44 +311,43 @@ class LiteralUtil { std::ostream& operator<<(std::ostream& out, const Literal& literal); template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) { - auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape( +/* static */ Literal LiteralUtil::CreateR0(NativeT value) { + Literal literal(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType<NativeT>(), {})); - literal->Set({}, value); + literal.Set({}, value); return literal; } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1( - absl::Span<const NativeT> values) { - auto literal = absl::make_unique<Literal>( +/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) { + Literal literal( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(), {static_cast<int64>(values.size())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout( +/* static */ Literal LiteralUtil::CreateR2WithLayout( std::initializer_list<std::initializer_list<NativeT>> values, const Layout& layout) { - auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType<NativeT>(), {static_cast<int64>(values.size()), static_cast<int64>(values.begin()->size())}, AsInt64Slice(layout.minor_to_major()))); - literal->PopulateR2(values); + literal.PopulateR2(values); return literal; } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2( +/* static */ Literal LiteralUtil::CreateR2( std::initializer_list<std::initializer_list<NativeT>> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout( +/* static */ Literal LiteralUtil::CreateR3WithLayout( std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values, const Layout& layout) { @@ -384,14 +372,14 @@ template <typename NativeT> } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3( +/* static */ Literal LiteralUtil::CreateR3( std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout( +/* static */ Literal LiteralUtil::CreateR4WithLayout( std::initializer_list<std::initializer_list< std::initializer_list<std::initializer_list<NativeT>>>> values, @@ -422,23 +410,22 @@ template <typename NativeT> } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse( +/* static */ Literal LiteralUtil::CreateSparse( absl::Span<const int64> dimensions, SparseIndexArray indices, absl::Span<const NativeT> values, bool sort) { int64 num_elements = values.size(); int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); CHECK_EQ(rank, indices.rank()); - auto literal = - absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout( - primitive_util::NativeToPrimitiveType<NativeT>(), dimensions, - indices.max_indices())); - literal->PopulateSparse(indices, values, sort); + Literal literal(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType<NativeT>(), dimensions, + indices.max_indices())); + literal.PopulateSparse(indices, values, sort); return literal; } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4( +/* static */ Literal LiteralUtil::CreateR4( std::initializer_list<std::initializer_list< std::initializer_list<std::initializer_list<NativeT>>>> values) { @@ -446,50 +433,48 @@ template <typename NativeT> } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout( +/* static */ Literal LiteralUtil::CreateFromArrayWithLayout( const Array<NativeT>& values, const Layout& layout) { - auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(), AsInt64Slice(layout.minor_to_major()))); - literal->PopulateFromArray(values); + literal.PopulateFromArray(values); return literal; } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray( +/* static */ Literal LiteralUtil::CreateFromArray( const Array<NativeT>& values) { return CreateFromArrayWithLayout( values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); } template <typename NativeT> -/* static */ std::unique_ptr<Literal> -LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout( + const Array2D<NativeT>& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D( +/* static */ Literal LiteralUtil::CreateR2FromArray2D( const Array2D<NativeT>& values) { return CreateFromArray(values); } template <typename NativeT> -/* static */ std::unique_ptr<Literal> -LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout( + const Array3D<NativeT>& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D( +/* static */ Literal LiteralUtil::CreateR3FromArray3D( const Array3D<NativeT>& values) { return CreateFromArray(values); } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected( +/* static */ Literal LiteralUtil::CreateR3Projected( std::initializer_list<std::initializer_list<NativeT>> values, int64 projection) { int64 dim0_size = projection; @@ -514,7 +499,7 @@ template <typename NativeT> } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected( +/* static */ Literal LiteralUtil::CreateR4Projected( std::initializer_list<std::initializer_list<NativeT>> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -542,21 +527,20 @@ template <typename NativeT> } template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D( +/* static */ Literal LiteralUtil::CreateR4FromArray4D( const Array4D<NativeT>& values) { return CreateFromArray(values); } template <typename NativeT> -/* static */ std::unique_ptr<Literal> -LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout( + const Array4D<NativeT>& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } // Returns an identity matrix (rank 2) with the given row and column count. template <typename NativeT> -/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) { +/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) { Array2D<NativeT> array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -565,33 +549,29 @@ template <typename NativeT> } template <typename NativeT> -/* static */ std::unique_ptr<Literal> -LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions, - NativeT value) { - auto literal = - absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout( - primitive_util::NativeToPrimitiveType<NativeT>(), dimensions)); - literal->PopulateWithValue(value); +/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout( + absl::Span<const int64> dimensions, NativeT value) { + Literal literal(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType<NativeT>(), dimensions)); + literal.PopulateWithValue(value); return literal; } template <PrimitiveType type, typename T> -/* static */ StatusOr<std::unique_ptr<Literal>> -LiteralUtil::CreateRandomLiteral( +/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral( const Shape& shape, const std::function<T(absl::Span<const int64>)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type; TF_RET_CHECK(shape.element_type() == type); - auto literal = absl::make_unique<Literal>(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>( + Literal literal(shape); + TF_RETURN_IF_ERROR(literal.Populate<NativeT>( [&](absl::Span<const int64> indexes) { return generator(indexes); })); return std::move(literal); } template <PrimitiveType type, typename E, typename T> -/* static */ StatusOr<std::unique_ptr<Literal>> -LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { +/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type; std::normal_distribution<NativeT> generator(mean, stddev); return CreateRandomLiteral<type, NativeT>( @@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, } template <PrimitiveType type, typename T> -/* static */ StatusOr<std::unique_ptr<Literal>> -LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { +/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral<type>(shape, &engine, mean, stddev); } |