aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-06-20 18:33:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-20 18:36:07 -0700
commite8b18a6f0c02d364ff47ba5fa3dc61458d273674 (patch)
tree180e6c6eadbfca62dd7fcdb88d53248a6347f573 /tensorflow/compiler/xla/tests/test_utils.cc
parent740966e69e87eaee37161efc96d8ea04162e1844 (diff)
Fix a bug in test_util when generating index for dynamic slice
dynamic slice's index space should be it's first operand's shape. PiperOrigin-RevId: 201454414
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc18
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index dd7c541733..000535a982 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -270,14 +270,22 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
- if (needs_index != nullptr &&
- !ShapeUtil::Equal(needs_index->shape(), use->shape())) {
- return Unimplemented(
- "Conflicting operand generation slice index constraints\n");
+ if (needs_index != nullptr) {
+ auto needs_index_shape = needs_index->shape();
+ auto use_shape = use->shape();
+ if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
+ needs_index_shape = needs_index->operand(0)->shape();
+ }
+ if (use->opcode() == HloOpcode::kDynamicSlice) {
+ use_shape = use->operand(0)->shape();
+ }
+ if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
+ return Unimplemented(
+ "Conflicting operand generation slice index constraints\n");
+ }
}
needs_index = use;
break;
-
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
needs_constant = use;