diff options
author | 2018-04-05 14:54:36 -0700 | |
---|---|---|
committer | 2018-04-05 14:56:57 -0700 | |
commit | b723cea002b3e612879291faaedb13eb702a5562 (patch) | |
tree | c1ed2fd42837cc0507ecf70e7ea14d719faf7c88 /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | 051dd6cf2f805f12f6e4efb5ad91d57de88481fc (diff) |
[XLA] Better support for mul reductions in MakeFakeArguments()
Mul reductions want a 1 as their init value, not a 0 or a random value.
PiperOrigin-RevId: 191802819
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 68 |
1 files changed, 47 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 821432ef7d..68f75d50cb 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -160,27 +160,38 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( return std::move(literal); } -// Matches binary addition computations. -bool LooksLikeSum(const HloComputation& computation) { +enum class ConstantType { kUnknown, kZero, kOne }; + +// Return the constant type required by this computation, if known. +ConstantType GetInitValue(const HloComputation& computation) { const HloInstruction* const root = computation.root_instruction(); - return root->opcode() == HloOpcode::kAdd && - computation.num_parameters() == 2 && - root->operand(0)->opcode() == HloOpcode::kParameter && - root->operand(1)->opcode() == HloOpcode::kParameter && - root->operand(0) != root->operand(1); + if (computation.num_parameters() != 2 || + root->operand(0)->opcode() != HloOpcode::kParameter || + root->operand(1)->opcode() != HloOpcode::kParameter || + root->operand(0) == root->operand(1)) { + return ConstantType::kUnknown; + } + + switch (root->opcode()) { + case HloOpcode::kAdd: + return ConstantType::kZero; + case HloOpcode::kMultiply: + return ConstantType::kOne; + default: + return ConstantType::kUnknown; + } } -// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition, -// which requires an init_value of 0 rather than a random value. -bool NeedsZeroInitValue(const HloUse& use) { +// Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random +// initialization value. +bool NeedsInitValue(const HloUse& use) { const HloInstruction* const instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; return ( ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && - op_num == 1 && LooksLikeSum(*instruction->to_apply())) || - (opcode == HloOpcode::kSelectAndScatter && op_num == 2 && - LooksLikeSum(*instruction->scatter()))); + op_num == 1) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2)); } // Generate random values that are constrained to the input_shape minus the @@ -222,7 +233,7 @@ std::vector<HloInstruction*> FindConstrainedUses( auto fused_uses = FindConstrainedUses(dataflow, *to_analyze); constrained_uses.insert(constrained_uses.end(), fused_uses.begin(), fused_uses.end()); - } else if (NeedsZeroInitValue(use)) { + } else if (NeedsInitValue(use)) { constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kConvert || opcode == HloOpcode::kReducePrecision) { @@ -243,7 +254,8 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { HloInstruction* needs_index = nullptr; - HloInstruction* needs_zero = nullptr; + HloInstruction* needs_constant = nullptr; + ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { case HloOpcode::kDynamicSlice: @@ -258,8 +270,13 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + needs_constant = use; + constant_type = GetInitValue(*use->to_apply()); + break; + case HloOpcode::kSelectAndScatter: - needs_zero = use; + needs_constant = use; + constant_type = GetInitValue(*use->scatter()); break; default: @@ -268,17 +285,26 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( use->ToString().c_str()); } } - if (needs_index != nullptr && needs_zero != nullptr) { + if (needs_index != nullptr && needs_constant != nullptr) { return Unimplemented( "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " - "zero: %s\n", - needs_index->ToString().c_str(), needs_zero->ToString().c_str()); + "constant: %s\n", + needs_index->ToString().c_str(), needs_constant->ToString().c_str()); } if (needs_index != nullptr) { return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), needs_index->shape(), engine); - } else if (needs_zero != nullptr) { - return Literal::CreateFromShape(param.shape()); + } else if (needs_constant != nullptr) { + switch (constant_type) { + case ConstantType::kZero: + return Literal::Zero(param.shape().element_type()).CloneToUnique(); + case ConstantType::kOne: + return Literal::One(param.shape().element_type()).CloneToUnique(); + case ConstantType::kUnknown: + // We want the identity element for the computation, but we don't really + // know what it is - so any value we generate will be just as wrong. + return MakeFakeLiteralInternal(param.shape(), engine); + } } else { return MakeFakeLiteralInternal(param.shape(), engine); } |