diff options
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index dd7c541733..000535a982 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -270,14 +270,22 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( switch (use->opcode()) { case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr && - !ShapeUtil::Equal(needs_index->shape(), use->shape())) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + if (needs_index != nullptr) { + auto needs_index_shape = needs_index->shape(); + auto use_shape = use->shape(); + if (needs_index->opcode() == HloOpcode::kDynamicSlice) { + needs_index_shape = needs_index->operand(0)->shape(); + } + if (use->opcode() == HloOpcode::kDynamicSlice) { + use_shape = use->operand(0)->shape(); + } + if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } } needs_index = use; break; - case HloOpcode::kReduce: case HloOpcode::kReduceWindow: needs_constant = use; |