diff options
author | Michael Kuperstein <mkuper@google.com> | 2018-08-29 10:34:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-29 10:38:34 -0700 |
commit | 150dee25d82589ca109957cc996efbd2a236e044 (patch) | |
tree | ef608bb703027596ed47ea4f30fa6eab93389005 /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | aca93368a979419360c1fd84b53b1766b19ba81a (diff) |
[XLA] Implement variadic reduce in the evaluator. It is currently supported only for the case where all of the inputs have the same element type.
PiperOrigin-RevId: 210746149
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 776f93d9f7..60ada58b2a 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -203,6 +203,7 @@ enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. ConstantType GetInitValue(const HloComputation& computation) { + // TODO(b/77635120): Add init values, for min, max, and their arg variants. const HloInstruction* const root = computation.root_instruction(); if (computation.num_parameters() != 2 || root->operand_count() != 2 || root->operand(0)->opcode() != HloOpcode::kParameter || @@ -227,10 +228,10 @@ bool NeedsInitValue(const HloUse& use) { const HloInstruction* const instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; - return ( - ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && - op_num == 1) || - (opcode == HloOpcode::kSelectAndScatter && op_num == 2)); + return ((opcode == HloOpcode::kReduceWindow && op_num == 1) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2) || + (opcode == HloOpcode::kReduce && + op_num >= instruction->operand_count() / 2)); } // Generate random values that are constrained to the input_shape minus the |