aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Nick Desaulniers <ndesaulniers@google.com>2017-11-15 12:36:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-15 12:41:10 -0800
commit50b1bc79f640b08633ed970719ee46c17509af98 (patch)
treeaf064eaf85535c597375401f82bc131646d456e7 /tensorflow/compiler/xla/tests/test_utils.cc
parentb0a49cd0f46cbc4d326ee87ab92c28b4b7b9ead7 (diff)
Add test util for setting init_value of SumReduce, ReduceWindow, and SelectAndScatter ops to a Constant 0.0f.
PiperOrigin-RevId: 175864310
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc69
1 files changed, 68 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index cdd3d66bbb..0d56c9f483 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -14,8 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/test_utils.h"
-
#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
namespace xla {
@@ -46,6 +47,44 @@ void PopulateWithRandomIntegralData(Literal* literal) {
}));
}
+bool LooksLikeSum(const HloInstruction& instruction) {
+ return instruction.opcode() == HloOpcode::kAdd &&
+ instruction.operand(0)->opcode() == HloOpcode::kParameter &&
+ instruction.operand(1)->opcode() == HloOpcode::kParameter &&
+ instruction.operand(0) != instruction.operand(1);
+}
+
+// Given an instruction and operand number, replace the given operand with
+// a Literal Constant Zero. Handle the case of a fusion instruction by
+// replacing the fusion's parent's parameter with a Literal Constant Zero,
+// unless the fusion's parent is itself a fusion.
+Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction,
+ const int64 operand_number) {
+ CHECK_LT(operand_number, instruction->operand_count());
+ if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) {
+ return Status::OK();
+ }
+
+ HloComputation* const computation = instruction->parent();
+ std::unique_ptr<HloInstruction> zero = HloInstruction::CreateConstant(
+ MakeUnique<Literal>(Literal::Zero(instruction->shape().element_type())));
+
+ if (computation->IsFusionComputation()) {
+ HloInstruction* const fusion_instruction = computation->FusionInstruction();
+ if (fusion_instruction->IsFused()) {
+ return Unimplemented(
+ "Unable to replace fused parameter of fusion instruction");
+ }
+ TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith(
+ instruction->operand(operand_number)->parameter_number(),
+ fusion_instruction->parent()->AddInstruction(std::move(zero))));
+ } else {
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(
+ operand_number, computation->AddInstruction(std::move(zero))));
+ }
+ return Status::OK();
+}
+
} // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
@@ -117,4 +156,32 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
return std::move(arguments);
}
+Status ReplaceInitsWithConstants(HloModule* const module) {
+ for (HloComputation* const computation : module->computations()) {
+ for (HloInstruction* const instruction : computation->instructions()) {
+ const HloOpcode opcode = instruction->opcode();
+ if ((opcode == HloOpcode::kReduce ||
+ opcode == HloOpcode::kReduceWindow) &&
+ LooksLikeSum(*instruction->to_apply()->root_instruction())) {
+ TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1));
+ } else if (opcode == HloOpcode::kSelectAndScatter &&
+ LooksLikeSum(*instruction->scatter()->root_instruction())) {
+ TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyHloModule(const perftools::gputools::Platform& platform,
+ HloModule* const module) {
+ return HloVerifier(
+ std::bind(
+ &TransferManager::GetByteSizeRequirement,
+ TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(),
+ std::placeholders::_1))
+ .Run(module)
+ .status();
+}
+
} // namespace xla