aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc18
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index dd7c541733..000535a982 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -270,14 +270,22 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
- if (needs_index != nullptr &&
- !ShapeUtil::Equal(needs_index->shape(), use->shape())) {
- return Unimplemented(
- "Conflicting operand generation slice index constraints\n");
+ if (needs_index != nullptr) {
+ auto needs_index_shape = needs_index->shape();
+ auto use_shape = use->shape();
+ if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
+ needs_index_shape = needs_index->operand(0)->shape();
+ }
+ if (use->opcode() == HloOpcode::kDynamicSlice) {
+ use_shape = use->operand(0)->shape();
+ }
+ if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
+ return Unimplemented(
+ "Conflicting operand generation slice index constraints\n");
+ }
}
needs_index = use;
break;
-
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
needs_constant = use;