about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
author: Michael Kuperstein <mkuper@google.com> 2018-04-05 14:54:36 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-04-05 14:56:57 -0700
commitb723cea002b3e612879291faaedb13eb702a5562 (patch)
treec1ed2fd42837cc0507ecf70e7ea14d719faf7c88 /tensorflow/compiler/xla/tests/test_utils.cc
parent051dd6cf2f805f12f6e4efb5ad91d57de88481fc (diff)
[XLA] Better support for mul reductions in MakeFakeArguments()
Mul reductions want a 1 as their init value, not a 0 or a random value. PiperOrigin-RevId: 191802819
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- tensorflow/compiler/xla/tests/test_utils.cc | 68
1 file changed, 47 insertions(+), 21 deletions(-)
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 821432ef7d..68f75d50cb 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -160,27 +160,38 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
return std::move(literal);
}
-// Matches binary addition computations.
-bool LooksLikeSum(const HloComputation& computation) {
+enum class ConstantType { kUnknown, kZero, kOne };
+
+// Return the constant type required by this computation, if known.
+ConstantType GetInitValue(const HloComputation& computation) {
const HloInstruction* const root = computation.root_instruction();
- return root->opcode() == HloOpcode::kAdd &&
- computation.num_parameters() == 2 &&
- root->operand(0)->opcode() == HloOpcode::kParameter &&
- root->operand(1)->opcode() == HloOpcode::kParameter &&
- root->operand(0) != root->operand(1);
+ if (computation.num_parameters() != 2 ||
+ root->operand(0)->opcode() != HloOpcode::kParameter ||
+ root->operand(1)->opcode() != HloOpcode::kParameter ||
+ root->operand(0) == root->operand(1)) {
+ return ConstantType::kUnknown;
+ }
+
+ switch (root->opcode()) {
+ case HloOpcode::kAdd:
+ return ConstantType::kZero;
+ case HloOpcode::kMultiply:
+ return ConstantType::kOne;
+ default:
+ return ConstantType::kUnknown;
+ }
}
-// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition,
-// which requires an init_value of 0 rather than a random value.
-bool NeedsZeroInitValue(const HloUse& use) {
+// Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
+// initialization value.
+bool NeedsInitValue(const HloUse& use) {
const HloInstruction* const instruction = use.instruction;
const HloOpcode opcode = instruction->opcode();
const int64 op_num = use.operand_number;
return (
((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) &&
- op_num == 1 && LooksLikeSum(*instruction->to_apply())) ||
- (opcode == HloOpcode::kSelectAndScatter && op_num == 2 &&
- LooksLikeSum(*instruction->scatter())));
+ op_num == 1) ||
+ (opcode == HloOpcode::kSelectAndScatter && op_num == 2));
}
// Generate random values that are constrained to the input_shape minus the
@@ -222,7 +233,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
fused_uses.end());
- } else if (NeedsZeroInitValue(use)) {
+ } else if (NeedsInitValue(use)) {
constrained_uses.push_back(instruction);
} else if (opcode == HloOpcode::kConvert ||
opcode == HloOpcode::kReducePrecision) {
@@ -243,7 +254,8 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
HloInstruction* needs_index = nullptr;
- HloInstruction* needs_zero = nullptr;
+ HloInstruction* needs_constant = nullptr;
+ ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
@@ -258,8 +270,13 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
+ needs_constant = use;
+ constant_type = GetInitValue(*use->to_apply());
+ break;
+
case HloOpcode::kSelectAndScatter:
- needs_zero = use;
+ needs_constant = use;
+ constant_type = GetInitValue(*use->scatter());
break;
default:
@@ -268,17 +285,26 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
use->ToString().c_str());
}
}
- if (needs_index != nullptr && needs_zero != nullptr) {
+ if (needs_index != nullptr && needs_constant != nullptr) {
return Unimplemented(
"Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
- "zero: %s\n",
- needs_index->ToString().c_str(), needs_zero->ToString().c_str());
+ "constant: %s\n",
+ needs_index->ToString().c_str(), needs_constant->ToString().c_str());
}
if (needs_index != nullptr) {
return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
needs_index->shape(), engine);
- } else if (needs_zero != nullptr) {
- return Literal::CreateFromShape(param.shape());
+ } else if (needs_constant != nullptr) {
+ switch (constant_type) {
+ case ConstantType::kZero:
+ return Literal::Zero(param.shape().element_type()).CloneToUnique();
+ case ConstantType::kOne:
+ return Literal::One(param.shape().element_type()).CloneToUnique();
+ case ConstantType::kUnknown:
+ // We want the identity element for the computation, but we don't really
+ // know what it is - so any value we generate will be just as wrong.
+ return MakeFakeLiteralInternal(param.shape(), engine);
+ }
} else {
return MakeFakeLiteralInternal(param.shape(), engine);
}