aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-08-29 16:43:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-29 16:48:33 -0700
commit24af42fce7e123b162899cbcc81769f919b0029a (patch)
tree610c9b77c30c60a9d51abff94ccd0b098af1a0c4 /tensorflow/compiler/xla/tests/test_utils.cc
parent7db159e208ec65824779281277fa3b7fc4cd857b (diff)
Automated rollback of commit 065f9b833ffbb3b2f03d63febb186275674ba133
PiperOrigin-RevId: 210816152
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 776f93d9f7..60ada58b2a 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -203,6 +203,7 @@ enum class ConstantType { kUnknown, kZero, kOne };
// Return the constant type required by this computation, if known.
ConstantType GetInitValue(const HloComputation& computation) {
+ // TODO(b/77635120): Add init values, for min, max, and their arg variants.
const HloInstruction* const root = computation.root_instruction();
if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
root->operand(0)->opcode() != HloOpcode::kParameter ||
@@ -227,10 +228,10 @@ bool NeedsInitValue(const HloUse& use) {
const HloInstruction* const instruction = use.instruction;
const HloOpcode opcode = instruction->opcode();
const int64 op_num = use.operand_number;
- return (
- ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) &&
- op_num == 1) ||
- (opcode == HloOpcode::kSelectAndScatter && op_num == 2));
+ return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
+ (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
+ (opcode == HloOpcode::kReduce &&
+ op_num >= instruction->operand_count() / 2));
}
// Generate random values that are constrained to the input_shape minus the