aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Bjarke Hammersholt Roune <broune@google.com>2017-12-11 22:31:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-11 22:35:08 -0800
commitf379d3369505457928ccda5adf72583f95709026 (patch)
treefdb9c10e2e83318423c3c71f696974134d0a5031 /tensorflow/compiler/xla/tests/test_utils.cc
parent913175c2bd38f6e97de399b29cfe1195bffbaa25 (diff)
* HloTestBase now prints out the HLO parser error message when there is one.
* TestUtils now supports generating random literals with more than one constraint. There is still an error if the constraints conflict. PiperOrigin-RevId: 178720092
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc56
1 files changed, 37 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 780b292d1a..56859542a9 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -138,27 +138,45 @@ std::vector<HloInstruction*> FindConstrainedUses(
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param) {
- const auto count = constrained_uses.size();
- if (count > 1) {
- return Unimplemented("multiple constrained uses not yet supported");
- }
+ HloInstruction* needs_index = nullptr;
+ HloInstruction* needs_zero = nullptr;
+ for (HloInstruction* use : constrained_uses) {
+ switch (use->opcode()) {
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ TF_RET_CHECK(ShapeUtil::Equal(param.shape(), use->operand(0)->shape()));
+ if (needs_index != nullptr &&
+ !ShapeUtil::Equal(needs_index->shape(), use->shape())) {
+ return Unimplemented(
+ "Conflicting operand generation slice index constraints\n");
+ }
+ needs_index = use;
+ break;
- if (count == 0) {
- return MakeFakeLiteral(param.shape());
- }
+ case HloOpcode::kReduce:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kSelectAndScatter:
+ needs_zero = use;
+ break;
- const HloInstruction* const use = constrained_uses[0];
- switch (use->opcode()) {
- case HloOpcode::kDynamicSlice:
- case HloOpcode::kDynamicUpdateSlice:
- return MakeRandomNonwrappingSliceIndex(use->operand(0)->shape(),
- use->shape());
- case HloOpcode::kReduce:
- case HloOpcode::kReduceWindow:
- case HloOpcode::kSelectAndScatter:
- return Literal::CreateFromShape(param.shape());
- default:
- return Unimplemented("constrained use given; no equivalent literal");
+ default:
+ return Unimplemented(
+ "Constrained operand generation not implemented for %s.",
+ use->ToString().c_str());
+ }
+ }
+ if (needs_index != nullptr && needs_zero != nullptr) {
+ return Unimplemented(
+ "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
+ "zero: %s\n",
+ needs_index->ToString().c_str(), needs_zero->ToString().c_str());
+ }
+ if (needs_index != nullptr) {
+ return MakeRandomNonwrappingSliceIndex(param.shape(), needs_index->shape());
+ } else if (needs_zero != nullptr) {
+ return Literal::CreateFromShape(param.shape());
+ } else {
+ return MakeFakeLiteral(param.shape());
}
}