diff options
author | 2017-10-20 12:47:57 -0700 | |
---|---|---|
committer | 2017-10-20 12:55:39 -0700 | |
commit | 5c331cfd573984287778aab02794dd86ba1f3006 (patch) | |
tree | fb36c812fbd87a51f5ecf6763461daa920aaa5bf | |
parent | aada11e19a1ceb901f490aa3c064f2778cb2acf2 (diff) |
The new array class provides a way to simplify the implementation of
these classes by eliminating a large number of duplicated code.
Removing the old API is non-trivial because of the existing users
outside of tensorflow.
PiperOrigin-RevId: 172920837
-rw-r--r-- | tensorflow/compiler/xla/client/computation_builder.h | 41 | ||||
-rw-r--r-- | tensorflow/compiler/xla/layout_util.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/layout_util.h | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/literal_util.h | 121 |
4 files changed, 85 insertions, 82 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index cdd9c8847f..93c2a80678 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -138,6 +138,11 @@ class ComputationBuilder { ComputationDataHandle ConstantR2( std::initializer_list<std::initializer_list<NativeT>> values); template <typename NativeT> + ComputationDataHandle ConstantFromArrayWithLayout( + const Array<NativeT>& values, const Layout& layout); + template <typename NativeT> + ComputationDataHandle ConstantFromArray(const Array<NativeT>& values); + template <typename NativeT> ComputationDataHandle ConstantR2FromArray2DWithLayout( const Array2D<NativeT>& values, const Layout& layout); template <typename NativeT> @@ -910,48 +915,54 @@ ComputationDataHandle ComputationBuilder::ConstantR2( } template <typename NativeT> -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( - const Array2D<NativeT>& values, const Layout& layout) { +ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( + const Array<NativeT>& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR2FromArray2DWithLayout(values, layout); + literal->PopulateFromArrayWithLayout(values, layout); }); } template <typename NativeT> +ComputationDataHandle ComputationBuilder::ConstantFromArray( + const Array<NativeT>& values) { + return ConstantOp( + [&values](Literal* literal) { literal->PopulateFromArray(values); }); +} + +template <typename NativeT> +ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( + const Array2D<NativeT>& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + +template <typename NativeT> ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D<NativeT>& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); }); + return ConstantFromArray(values); } template <typename NativeT> ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D<NativeT>& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR3FromArray3DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template <typename NativeT> ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( const Array3D<NativeT>& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); }); + return ConstantFromArray(values); } template <typename NativeT> ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D<NativeT>& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR4FromArray4DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template <typename NativeT> ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( const Array4D<NativeT>& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); }); + return ConstantFromArray(values); } } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 011fc3c194..5c2cc2a7a9 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return CreateDefaultLayoutForRank(shape.dimensions_size()); } +/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) { + return CreateDefaultLayoutForRank(rank); +} + /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() { return CreateDefaultLayoutForRank(2); } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 5de0a653f6..bc42e22229 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -40,6 +40,7 @@ class LayoutUtil { static Layout GetDefaultLayoutForShape(const Shape& shape); // Helper functions that create default layouts for various ranks. + static Layout GetDefaultLayoutForRank(int64 rank); static Layout GetDefaultLayoutForR2(); static Layout GetDefaultLayoutForR3(); static Layout GetDefaultLayoutForR4(); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e8cee732d4..4063cb05a9 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -334,6 +334,11 @@ class Literal { // WithLayout use the default XLA layout for the literal's linear // representation in memory. template <typename NativeT> + static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values); + template <typename NativeT> + static std::unique_ptr<Literal> CreateFromArrayWithLayout( + const Array<NativeT>& values, const Layout& layout); + template <typename NativeT> static std::unique_ptr<Literal> CreateR2FromArray2D( const Array2D<NativeT>& values); template <typename NativeT> @@ -481,6 +486,11 @@ class Literal { std::initializer_list<std::initializer_list<NativeT>> values, const Layout& layout); template <typename NativeT> + void PopulateFromArray(const Array<NativeT>& values); + template <typename NativeT> + void PopulateFromArrayWithLayout(const Array<NativeT>& values, + const Layout& layout); + template <typename NativeT> void PopulateR2FromArray2D(const Array2D<NativeT>& values); template <typename NativeT> void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values, @@ -816,33 +826,42 @@ template <typename NativeT> } template <typename NativeT> -/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout( - const Array2D<NativeT>& values, const Layout& layout) { +/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout( + const Array<NativeT>& values, const Layout& layout) { auto literal = MakeUnique<Literal>(); - literal->PopulateR2FromArray2DWithLayout(values, layout); + literal->PopulateFromArrayWithLayout(values, layout); return literal; } template <typename NativeT> +/* static */ std::unique_ptr<Literal> Literal::CreateFromArray( + const Array<NativeT>& values) { + return CreateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + +template <typename NativeT> +/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout( + const Array2D<NativeT>& values, const Layout& layout) { + return CreateFromArrayWithLayout(values, layout); +} + +template <typename NativeT> /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D( const Array2D<NativeT>& values) { - return CreateR2FromArray2DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR2()); + return CreateFromArray(values); } template <typename NativeT> /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout( const Array3D<NativeT>& values, const Layout& layout) { - auto literal = MakeUnique<Literal>(); - literal->PopulateR3FromArray3DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template <typename NativeT> /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D( const Array3D<NativeT>& values) { - return CreateR3FromArray3DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR3()); + return CreateFromArray(values); } template <typename NativeT> @@ -901,16 +920,13 @@ template <typename NativeT> template <typename NativeT> /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D( const Array4D<NativeT>& values) { - return CreateR4FromArray4DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR4()); + return CreateFromArray(values); } template <typename NativeT> /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout( const Array4D<NativeT>& values, const Layout& layout) { - auto literal = MakeUnique<Literal>(); - literal->PopulateR4FromArray4DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template <typename NativeT> @@ -1070,82 +1086,53 @@ void Literal::PopulateR2( } template <typename NativeT> -void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values, - const Layout& layout) { +void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values, + const Layout& layout) { *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType<NativeT>(), - {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); + primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(), + AsInt64Slice(layout.minor_to_major())); + Reserve(values.num_elements()); + values.Each([this](tensorflow::gtl::ArraySlice<int64> indices, + NativeT value) { this->Set(indices, value); }); +} - const int64 dim1_size = values.width(); - const int64 dim0_size = values.height(); - CHECK_EQ(dim0_size, shape().dimensions(0)); - CHECK_EQ(dim1_size, shape().dimensions(1)); - Reserve(dim1_size * dim0_size); - for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { - for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { - Set({dim0, dim1}, values(dim0, dim1)); - } - } +template <typename NativeT> +void Literal::PopulateFromArray(const Array<NativeT>& values) { + PopulateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + +template <typename NativeT> +void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values, + const Layout& layout) { + PopulateFromArrayWithLayout(values, layout); } template <typename NativeT> void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) { - PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); + PopulateFromArray(values); } template <typename NativeT> void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType<NativeT>(), - {values.n1(), values.n2(), values.n3()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - Reserve(values.n1() * values.n2() * values.n3()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - Set({dim0, dim1, dim2}, values(dim0, dim1, dim2)); - } - } - } + PopulateFromArrayWithLayout(values, layout); } template <typename NativeT> void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) { - PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); + PopulateFromArray(values); } template <typename NativeT> void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType<NativeT>(), - {values.planes(), values.depth(), values.height(), values.width()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - CHECK_EQ(values.n4(), shape().dimensions(3)); - Reserve(values.n1() * values.n2() * values.n3() * values.n4()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) { - Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3)); - } - } - } - } + PopulateFromArrayWithLayout(values, layout); } template <typename NativeT> void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) { - PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); + PopulateFromArray(values); } template <typename NativeT, typename FnType> |