diff options
author | 2017-12-11 22:31:27 -0800 | |
---|---|---|
committer | 2017-12-11 22:35:08 -0800 | |
commit | f379d3369505457928ccda5adf72583f95709026 (patch) | |
tree | fdb9c10e2e83318423c3c71f696974134d0a5031 /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | 913175c2bd38f6e97de399b29cfe1195bffbaa25 (diff) |
* HloTestBase now prints out the HLO parser error message when there is one.
* TestUtils now supports generating random literals with more than one constraint.
There is still an error if the constraints conflict.
PiperOrigin-RevId: 178720092
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 56 |
1 files changed, 37 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 780b292d1a..56859542a9 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -138,27 +138,45 @@ std::vector<HloInstruction*> FindConstrainedUses( StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses, const HloInstruction& param) { - const auto count = constrained_uses.size(); - if (count > 1) { - return Unimplemented("multiple constrained uses not yet supported"); - } + HloInstruction* needs_index = nullptr; + HloInstruction* needs_zero = nullptr; + for (HloInstruction* use : constrained_uses) { + switch (use->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + TF_RET_CHECK(ShapeUtil::Equal(param.shape(), use->operand(0)->shape())); + if (needs_index != nullptr && + !ShapeUtil::Equal(needs_index->shape(), use->shape())) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } + needs_index = use; + break; - if (count == 0) { - return MakeFakeLiteral(param.shape()); - } + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + needs_zero = use; + break; - const HloInstruction* const use = constrained_uses[0]; - switch (use->opcode()) { - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - return MakeRandomNonwrappingSliceIndex(use->operand(0)->shape(), - use->shape()); - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - return Literal::CreateFromShape(param.shape()); - default: - return Unimplemented("constrained use given; no equivalent literal"); + default: + return Unimplemented( + "Constrained operand generation not implemented for %s.", + use->ToString().c_str()); + } + } + if (needs_index != nullptr && needs_zero != nullptr) { + return Unimplemented( + "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " + "zero: %s\n", + needs_index->ToString().c_str(), needs_zero->ToString().c_str()); + } + if (needs_index != nullptr) { + return MakeRandomNonwrappingSliceIndex(param.shape(), needs_index->shape()); + } else if (needs_zero != nullptr) { + return Literal::CreateFromShape(param.shape()); + } else { + return MakeFakeLiteral(param.shape()); } } |