diff options
author | 2017-11-08 15:35:27 -0800 | |
---|---|---|
committer | 2017-11-10 16:14:37 -0800 | |
commit | 64d2636e2946772d4b1531ec91b389110a2787b7 (patch) | |
tree | 71fdb3d248e9b1f24ab649e1d21eac490e6782bd /tensorflow/compiler/xla/literal_util.cc | |
parent | 2ba34173fad0d5b7d986baeb8171bdc6afdcd7bb (diff) |
Move MakeFakeLiteral from client/lib/testing.h to tests/test_utils.h. Also remove superfluous literal creation methods in that file, and replace them with the existing ones in the Literal class.
Also, optionally print layout in Literal::ToString.
PiperOrigin-RevId: 175076277
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal_util.cc | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index fda791401d..0cb2223ae5 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -569,9 +569,17 @@ int64 Literal::LinearIndex( return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); } -string Literal::ToString() const { +string Literal::ToString(bool print_layout) const { std::vector<string> pieces; + auto shape_to_string = [print_layout](const Shape& shape) { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(shape); + } else { + return ShapeUtil::HumanString(shape); + } + }; + auto element_to_string = [this](tensorflow::gtl::ArraySlice<int64> indices) -> string { PrimitiveType element_type = shape().element_type(); @@ -585,7 +593,7 @@ string Literal::ToString() const { // TODO(b/32894291): refactor this code to reduce code duplication. if (ShapeUtil::IsTuple(shape())) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" (\n"); pieces.push_back(tensorflow::str_util::Join( tuple_literals(), ",\n", [](string* out, const Literal& element) { @@ -601,7 +609,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 2) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(" { "); @@ -613,7 +621,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 3) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(i0 > 0 ? ",\n{" : "{"); @@ -628,7 +636,7 @@ string Literal::ToString() const { } pieces.push_back("\n}"); } else if (ShapeUtil::Rank(shape()) == 4) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); @@ -649,7 +657,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 5) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); @@ -676,7 +684,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {...}"); } |