aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-06-27 17:55:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-27 17:59:28 -0700
commit2b3b5054c7ceff0bc2811cfe0ebc063947801ce0 (patch)
tree520169f2425bceb4c800ab57d627b5a29c4358ee /tensorflow/compiler/xla/service/algebraic_simplifier.cc
parent5403e4b8e124d770d0988623879650baf7bba630 (diff)
[XLA] Add test case for TOKEN constants. Make the test case pass.
PiperOrigin-RevId: 202401460
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc4
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4858fe61e0..48fd07371d 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -530,6 +530,10 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
constant, BuildTupleConstant(computation_, constant->literal()));
}
+ if (constant->shape().element_type() == TOKEN) {
+ return Status::OK();
+ }
+
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {