From 5c331cfd573984287778aab02794dd86ba1f3006 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 12:47:57 -0700 Subject: The new array class provides a way to simplify the implementation of these classes by eliminating a large number of duplicated code. Removing the old API is non-trivial because of the existing users outside of tensorflow. PiperOrigin-RevId: 172920837 --- .../compiler/xla/client/computation_builder.h | 41 ++++--- tensorflow/compiler/xla/layout_util.cc | 4 + tensorflow/compiler/xla/layout_util.h | 1 + tensorflow/compiler/xla/literal_util.h | 121 +++++++++------------ 4 files changed, 85 insertions(+), 82 deletions(-) diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index cdd9c8847f..93c2a80678 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -138,6 +138,11 @@ class ComputationBuilder { ComputationDataHandle ConstantR2( std::initializer_list> values); template + ComputationDataHandle ConstantFromArrayWithLayout( + const Array& values, const Layout& layout); + template + ComputationDataHandle ConstantFromArray(const Array& values); + template ComputationDataHandle ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout); template @@ -910,48 +915,54 @@ ComputationDataHandle ComputationBuilder::ConstantR2( } template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { +ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( + const Array& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR2FromArray2DWithLayout(values, layout); + literal->PopulateFromArrayWithLayout(values, layout); }); } +template +ComputationDataHandle ComputationBuilder::ConstantFromArray( + const Array& values) { + return ConstantOp( + [&values](Literal* literal) { literal->PopulateFromArray(values); }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); }); + return ConstantFromArray(values); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR3FromArray3DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( const Array3D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); }); + return ConstantFromArray(values); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR4FromArray4DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( const Array4D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); }); + return ConstantFromArray(values); } } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 011fc3c194..5c2cc2a7a9 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return CreateDefaultLayoutForRank(shape.dimensions_size()); } +/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) { + return CreateDefaultLayoutForRank(rank); +} + /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() { return CreateDefaultLayoutForRank(2); } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 5de0a653f6..bc42e22229 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -40,6 +40,7 @@ class LayoutUtil { static Layout GetDefaultLayoutForShape(const Shape& shape); // Helper functions that create default layouts for various ranks. + static Layout GetDefaultLayoutForRank(int64 rank); static Layout GetDefaultLayoutForR2(); static Layout GetDefaultLayoutForR3(); static Layout GetDefaultLayoutForR4(); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e8cee732d4..4063cb05a9 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -334,6 +334,11 @@ class Literal { // WithLayout use the default XLA layout for the literal's linear // representation in memory. template + static std::unique_ptr CreateFromArray(const Array& values); + template + static std::unique_ptr CreateFromArrayWithLayout( + const Array& values, const Layout& layout); + template static std::unique_ptr CreateR2FromArray2D( const Array2D& values); template @@ -481,6 +486,11 @@ class Literal { std::initializer_list> values, const Layout& layout); template + void PopulateFromArray(const Array& values); + template + void PopulateFromArrayWithLayout(const Array& values, + const Layout& layout); + template void PopulateR2FromArray2D(const Array2D& values); template void PopulateR2FromArray2DWithLayout(const Array2D& values, @@ -816,33 +826,42 @@ template } template -/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateFromArrayWithLayout( + const Array& values, const Layout& layout) { auto literal = MakeUnique(); - literal->PopulateR2FromArray2DWithLayout(values, layout); + literal->PopulateFromArrayWithLayout(values, layout); return literal; } +template +/* static */ std::unique_ptr Literal::CreateFromArray( + const Array& values) { + return CreateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + +template +/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return CreateFromArrayWithLayout(values, layout); +} + template /* static */ std::unique_ptr Literal::CreateR2FromArray2D( const Array2D& values) { - return CreateR2FromArray2DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR2()); + return CreateFromArray(values); } template /* static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR3FromArray3DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template /* static */ std::unique_ptr Literal::CreateR3FromArray3D( const Array3D& values) { - return CreateR3FromArray3DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR3()); + return CreateFromArray(values); } template @@ -901,16 +920,13 @@ template template /* static */ std::unique_ptr Literal::CreateR4FromArray4D( const Array4D& values) { - return CreateR4FromArray4DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR4()); + return CreateFromArray(values); } template /* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR4FromArray4DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template @@ -1070,82 +1086,53 @@ void Literal::PopulateR2( } template -void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +void Literal::PopulateFromArrayWithLayout(const Array& values, + const Layout& layout) { *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); + primitive_util::NativeToPrimitiveType(), values.dimensions(), + AsInt64Slice(layout.minor_to_major())); + Reserve(values.num_elements()); + values.Each([this](tensorflow::gtl::ArraySlice indices, + NativeT value) { this->Set(indices, value); }); +} - const int64 dim1_size = values.width(); - const int64 dim0_size = values.height(); - CHECK_EQ(dim0_size, shape().dimensions(0)); - CHECK_EQ(dim1_size, shape().dimensions(1)); - Reserve(dim1_size * dim0_size); - for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { - for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { - Set({dim0, dim1}, values(dim0, dim1)); - } - } +template +void Literal::PopulateFromArray(const Array& values) { + PopulateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + +template +void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR2FromArray2D(const Array2D& values) { - PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); + PopulateFromArray(values); } template void Literal::PopulateR3FromArray3DWithLayout(const Array3D& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.n1(), values.n2(), values.n3()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - Reserve(values.n1() * values.n2() * values.n3()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - Set({dim0, dim1, dim2}, values(dim0, dim1, dim2)); - } - } - } + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR3FromArray3D(const Array3D& values) { - PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); + PopulateFromArray(values); } template void Literal::PopulateR4FromArray4DWithLayout(const Array4D& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.planes(), values.depth(), values.height(), values.width()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - CHECK_EQ(values.n4(), shape().dimensions(3)); - Reserve(values.n1() * values.n2() * values.n3() * values.n4()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) { - Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3)); - } - } - } - } + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR4FromArray4D(const Array4D& values) { - PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); + PopulateFromArray(values); } template -- cgit v1.2.3