aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-20 12:47:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-20 12:55:39 -0700
commit5c331cfd573984287778aab02794dd86ba1f3006 (patch)
treefb36c812fbd87a51f5ecf6763461daa920aaa5bf
parentaada11e19a1ceb901f490aa3c064f2778cb2acf2 (diff)
The new array class provides a way to simplify the implementation of
these classes by eliminating a large number of duplicated code. Removing the old API is non-trivial because of the existing users outside of tensorflow. PiperOrigin-RevId: 172920837
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h41
-rw-r--r--tensorflow/compiler/xla/layout_util.cc4
-rw-r--r--tensorflow/compiler/xla/layout_util.h1
-rw-r--r--tensorflow/compiler/xla/literal_util.h121
4 files changed, 85 insertions, 82 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index cdd9c8847f..93c2a80678 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -138,6 +138,11 @@ class ComputationBuilder {
ComputationDataHandle ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
+ ComputationDataHandle ConstantFromArrayWithLayout(
+ const Array<NativeT>& values, const Layout& layout);
+ template <typename NativeT>
+ ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
+ template <typename NativeT>
ComputationDataHandle ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout);
template <typename NativeT>
@@ -910,48 +915,54 @@ ComputationDataHandle ComputationBuilder::ConstantR2(
}
template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout) {
+ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
+ const Array<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
- literal->PopulateR2FromArray2DWithLayout(values, layout);
+ literal->PopulateFromArrayWithLayout(values, layout);
});
}
template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantFromArray(
+ const Array<NativeT>& values) {
+ return ConstantOp(
+ [&values](Literal* literal) { literal->PopulateFromArray(values); });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout) {
+ return ConstantFromArrayWithLayout(values, layout);
+}
+
+template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
const Array2D<NativeT>& values) {
- return ConstantOp(
- [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); });
+ return ConstantFromArray(values);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
- return ConstantOp([&values, &layout](Literal* literal) {
- literal->PopulateR3FromArray3DWithLayout(values, layout);
- });
+ return ConstantFromArrayWithLayout(values, layout);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
const Array3D<NativeT>& values) {
- return ConstantOp(
- [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); });
+ return ConstantFromArray(values);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
- return ConstantOp([&values, &layout](Literal* literal) {
- literal->PopulateR4FromArray4DWithLayout(values, layout);
- });
+ return ConstantFromArrayWithLayout(values, layout);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
const Array4D<NativeT>& values) {
- return ConstantOp(
- [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); });
+ return ConstantFromArray(values);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index 011fc3c194..5c2cc2a7a9 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return CreateDefaultLayoutForRank(shape.dimensions_size());
}
+/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
+ return CreateDefaultLayoutForRank(rank);
+}
+
/* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
return CreateDefaultLayoutForRank(2);
}
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index 5de0a653f6..bc42e22229 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -40,6 +40,7 @@ class LayoutUtil {
static Layout GetDefaultLayoutForShape(const Shape& shape);
// Helper functions that create default layouts for various ranks.
+ static Layout GetDefaultLayoutForRank(int64 rank);
static Layout GetDefaultLayoutForR2();
static Layout GetDefaultLayoutForR3();
static Layout GetDefaultLayoutForR4();
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index e8cee732d4..4063cb05a9 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -334,6 +334,11 @@ class Literal {
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
+ static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateFromArrayWithLayout(
+ const Array<NativeT>& values, const Layout& layout);
+ template <typename NativeT>
static std::unique_ptr<Literal> CreateR2FromArray2D(
const Array2D<NativeT>& values);
template <typename NativeT>
@@ -481,6 +486,11 @@ class Literal {
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout);
template <typename NativeT>
+ void PopulateFromArray(const Array<NativeT>& values);
+ template <typename NativeT>
+ void PopulateFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
void PopulateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
@@ -816,33 +826,42 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout) {
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
+ const Array<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>();
- literal->PopulateR2FromArray2DWithLayout(values, layout);
+ literal->PopulateFromArrayWithLayout(values, layout);
return literal;
}
template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
+ const Array<NativeT>& values) {
+ return CreateFromArrayWithLayout(
+ values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout) {
+ return CreateFromArrayWithLayout(values, layout);
+}
+
+template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
- return CreateR2FromArray2DWithLayout(values,
- LayoutUtil::GetDefaultLayoutForR2());
+ return CreateFromArray(values);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
- auto literal = MakeUnique<Literal>();
- literal->PopulateR3FromArray3DWithLayout(values, layout);
- return literal;
+ return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
- return CreateR3FromArray3DWithLayout(values,
- LayoutUtil::GetDefaultLayoutForR3());
+ return CreateFromArray(values);
}
template <typename NativeT>
@@ -901,16 +920,13 @@ template <typename NativeT>
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
- return CreateR4FromArray4DWithLayout(values,
- LayoutUtil::GetDefaultLayoutForR4());
+ return CreateFromArray(values);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
- auto literal = MakeUnique<Literal>();
- literal->PopulateR4FromArray4DWithLayout(values, layout);
- return literal;
+ return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
@@ -1070,82 +1086,53 @@ void Literal::PopulateR2(
}
template <typename NativeT>
-void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout) {
+void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout) {
*mutable_shape() = ShapeUtil::MakeShapeWithLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(),
- {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major()));
+ primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
+ AsInt64Slice(layout.minor_to_major()));
+ Reserve(values.num_elements());
+ values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value) { this->Set(indices, value); });
+}
- const int64 dim1_size = values.width();
- const int64 dim0_size = values.height();
- CHECK_EQ(dim0_size, shape().dimensions(0));
- CHECK_EQ(dim1_size, shape().dimensions(1));
- Reserve(dim1_size * dim0_size);
- for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
- for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
- Set({dim0, dim1}, values(dim0, dim1));
- }
- }
+template <typename NativeT>
+void Literal::PopulateFromArray(const Array<NativeT>& values) {
+ PopulateFromArrayWithLayout(
+ values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+
+template <typename NativeT>
+void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout) {
+ PopulateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
- PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
+ PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
const Layout& layout) {
- *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(),
- {values.n1(), values.n2(), values.n3()},
- AsInt64Slice(layout.minor_to_major()));
-
- CHECK_EQ(values.n1(), shape().dimensions(0));
- CHECK_EQ(values.n2(), shape().dimensions(1));
- CHECK_EQ(values.n3(), shape().dimensions(2));
- Reserve(values.n1() * values.n2() * values.n3());
- for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
- for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
- for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
- Set({dim0, dim1, dim2}, values(dim0, dim1, dim2));
- }
- }
- }
+ PopulateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
- PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
+ PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
const Layout& layout) {
- *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(),
- {values.planes(), values.depth(), values.height(), values.width()},
- AsInt64Slice(layout.minor_to_major()));
-
- CHECK_EQ(values.n1(), shape().dimensions(0));
- CHECK_EQ(values.n2(), shape().dimensions(1));
- CHECK_EQ(values.n3(), shape().dimensions(2));
- CHECK_EQ(values.n4(), shape().dimensions(3));
- Reserve(values.n1() * values.n2() * values.n3() * values.n4());
- for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
- for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
- for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
- for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) {
- Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3));
- }
- }
- }
- }
+ PopulateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
- PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
+ PopulateFromArray(values);
}
template <typename NativeT, typename FnType>