diff options
author | 2018-06-27 17:55:00 -0700 | |
---|---|---|
committer | 2018-06-27 17:59:28 -0700 | |
commit | 2b3b5054c7ceff0bc2811cfe0ebc063947801ce0 (patch) | |
tree | 520169f2425bceb4c800ab57d627b5a29c4358ee /tensorflow/compiler | |
parent | 5403e4b8e124d770d0988623879650baf7bba630 (diff) |
[XLA] Add test case for TOKEN constants. Make the test case pass.
PiperOrigin-RevId: 202401460
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/literal_util.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/constants_test.cc | 9 |
3 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 7c6a181b0a..eeabf835ac 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -2142,6 +2142,7 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { } break; case TUPLE: + case TOKEN: // Nothing to do but assign the shape which is done above. return; default: @@ -2294,6 +2295,9 @@ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto( } return Status::OK(); } + if (piece->subshape().element_type() == TOKEN) { + return Status::OK(); + } CHECK(ShapeUtil::IsArray(piece->subshape())); TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4858fe61e0..48fd07371d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -530,6 +530,10 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { constant, BuildTupleConstant(computation_, constant->literal())); } + if (constant->shape().element_type() == TOKEN) { + return Status::OK(); + } + // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 1786cf7359..8bfc19cbcc 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -171,5 +172,13 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); } +TEST_F(ConstantsTest, Token) { + XlaBuilder builder(TestName()); + ConstantLiteral(&builder, *Literal::CreateToken()); + // TODO(b/80000000): tokens cannot be returned from computations. + Tuple(&builder, {}); + TF_ASSERT_OK(Execute(&builder, {}).status()); +} + } // namespace } // namespace xla |