aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal_util.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-11-08 15:35:27 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:37 -0800
commit64d2636e2946772d4b1531ec91b389110a2787b7 (patch)
tree71fdb3d248e9b1f24ab649e1d21eac490e6782bd /tensorflow/compiler/xla/literal_util.cc
parent2ba34173fad0d5b7d986baeb8171bdc6afdcd7bb (diff)
Move MakeFakeLiteral from client/lib/testing.h to tests/test_utils.h. Also remove superfluous literal creation methods in that file, and replace them with the existing ones in the Literal class.
Also, optionally print layout in Literal::ToString. PiperOrigin-RevId: 175076277
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.cc')
-rw-r--r--tensorflow/compiler/xla/literal_util.cc22
1 files changed, 15 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index fda791401d..0cb2223ae5 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -569,9 +569,17 @@ int64 Literal::LinearIndex(
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
}
-string Literal::ToString() const {
+string Literal::ToString(bool print_layout) const {
std::vector<string> pieces;
+ auto shape_to_string = [print_layout](const Shape& shape) {
+ if (print_layout) {
+ return ShapeUtil::HumanStringWithLayout(shape);
+ } else {
+ return ShapeUtil::HumanString(shape);
+ }
+ };
+
auto element_to_string =
[this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
PrimitiveType element_type = shape().element_type();
@@ -585,7 +593,7 @@ string Literal::ToString() const {
// TODO(b/32894291): refactor this code to reduce code duplication.
if (ShapeUtil::IsTuple(shape())) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" (\n");
pieces.push_back(tensorflow::str_util::Join(
tuple_literals(), ",\n", [](string* out, const Literal& element) {
@@ -601,7 +609,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 2) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(" { ");
@@ -613,7 +621,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 3) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(i0 > 0 ? ",\n{" : "{");
@@ -628,7 +636,7 @@ string Literal::ToString() const {
}
pieces.push_back("\n}");
} else if (ShapeUtil::Rank(shape()) == 4) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -649,7 +657,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 5) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -676,7 +684,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {...}");
}