aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-06-27 17:55:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-27 17:59:28 -0700
commit2b3b5054c7ceff0bc2811cfe0ebc063947801ce0 (patch)
tree520169f2425bceb4c800ab57d627b5a29c4358ee /tensorflow/compiler
parent5403e4b8e124d770d0988623879650baf7bba630 (diff)
[XLA] Add test case for TOKEN constants. Make the test case pass.
PiperOrigin-RevId: 202401460
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/literal_util.cc4
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc4
-rw-r--r--tensorflow/compiler/xla/tests/constants_test.cc9
3 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 7c6a181b0a..eeabf835ac 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -2142,6 +2142,7 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
}
break;
case TUPLE:
+ case TOKEN:
// Nothing to do but assign the shape which is done above.
return;
default:
@@ -2294,6 +2295,9 @@ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
}
return Status::OK();
}
+ if (piece->subshape().element_type() == TOKEN) {
+ return Status::OK();
+ }
CHECK(ShapeUtil::IsArray(piece->subshape()));
TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4858fe61e0..48fd07371d 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -530,6 +530,10 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
constant, BuildTupleConstant(computation_, constant->literal()));
}
+ if (constant->shape().element_type() == TOKEN) {
+ return Status::OK();
+ }
+
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 1786cf7359..8bfc19cbcc 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -171,5 +172,13 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) {
{2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_);
}
+TEST_F(ConstantsTest, Token) {
+ XlaBuilder builder(TestName());
+ ConstantLiteral(&builder, *Literal::CreateToken());
+ // TODO(b/80000000): tokens cannot be returned from computations.
+ Tuple(&builder, {});
+ TF_ASSERT_OK(Execute(&builder, {}).status());
+}
+
} // namespace
} // namespace xla