diff options
author | 2017-11-15 12:36:51 -0800 | |
---|---|---|
committer | 2017-11-15 12:41:10 -0800 | |
commit | 50b1bc79f640b08633ed970719ee46c17509af98 (patch) | |
tree | af064eaf85535c597375401f82bc131646d456e7 /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | b0a49cd0f46cbc4d326ee87ab92c28b4b7b9ead7 (diff) |
Add test util for setting init_value of SumReduce, ReduceWindow, and SelectAndScatter ops to a Constant 0.0f.
PiperOrigin-RevId: 175864310
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 69 |
1 files changed, 68 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index cdd3d66bbb..0d56c9f483 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -14,8 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/tests/test_utils.h" - #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" namespace xla { @@ -46,6 +47,44 @@ void PopulateWithRandomIntegralData(Literal* literal) { })); } +bool LooksLikeSum(const HloInstruction& instruction) { + return instruction.opcode() == HloOpcode::kAdd && + instruction.operand(0)->opcode() == HloOpcode::kParameter && + instruction.operand(1)->opcode() == HloOpcode::kParameter && + instruction.operand(0) != instruction.operand(1); +} + +// Given an instruction and operand number, replace the given operand with +// a Literal Constant Zero. Handle the case of a fusion instruction by +// replacing the fusion's parent's parameter with a Literal Constant Zero, +// unless the fusion's parent is itself a fusion. +Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction, + const int64 operand_number) { + CHECK_LT(operand_number, instruction->operand_count()); + if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) { + return Status::OK(); + } + + HloComputation* const computation = instruction->parent(); + std::unique_ptr<HloInstruction> zero = HloInstruction::CreateConstant( + MakeUnique<Literal>(Literal::Zero(instruction->shape().element_type()))); + + if (computation->IsFusionComputation()) { + HloInstruction* const fusion_instruction = computation->FusionInstruction(); + if (fusion_instruction->IsFused()) { + return Unimplemented( + "Unable to replace fused parameter of fusion instruction"); + } + TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith( + instruction->operand(operand_number)->parameter_number(), + fusion_instruction->parent()->AddInstruction(std::move(zero)))); + } else { + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith( + operand_number, computation->AddInstruction(std::move(zero)))); + } + return Status::OK(); +} + } // namespace StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) { @@ -117,4 +156,32 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( return std::move(arguments); } +Status ReplaceInitsWithConstants(HloModule* const module) { + for (HloComputation* const computation : module->computations()) { + for (HloInstruction* const instruction : computation->instructions()) { + const HloOpcode opcode = instruction->opcode(); + if ((opcode == HloOpcode::kReduce || + opcode == HloOpcode::kReduceWindow) && + LooksLikeSum(*instruction->to_apply()->root_instruction())) { + TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1)); + } else if (opcode == HloOpcode::kSelectAndScatter && + LooksLikeSum(*instruction->scatter()->root_instruction())) { + TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2)); + } + } + } + return Status::OK(); +} + +Status VerifyHloModule(const perftools::gputools::Platform& platform, + HloModule* const module) { + return HloVerifier( + std::bind( + &TransferManager::GetByteSizeRequirement, + TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(), + std::placeholders::_1)) + .Run(module) + .status(); +} + } // namespace xla |