about · summary · refs · log · tree · commit · diff · homepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/tests/BUILD3
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc69
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.h11
3 files changed, 82 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 3e62481629..63c3541e14 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -69,7 +69,10 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_headers_lib",
],
)
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index cdd3d66bbb..0d56c9f483 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -14,8 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/test_utils.h"
-
#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
namespace xla {
@@ -46,6 +47,44 @@ void PopulateWithRandomIntegralData(Literal* literal) {
}));
}
+bool LooksLikeSum(const HloInstruction& instruction) {
+ return instruction.opcode() == HloOpcode::kAdd &&
+ instruction.operand(0)->opcode() == HloOpcode::kParameter &&
+ instruction.operand(1)->opcode() == HloOpcode::kParameter &&
+ instruction.operand(0) != instruction.operand(1);
+}
+
+// Given an instruction and operand number, replace the given operand with
+// a Literal Constant Zero. Handle the case of a fusion instruction by
+// replacing the fusion's parent's parameter with a Literal Constant Zero,
+// unless the fusion's parent is itself a fusion.
+Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction,
+ const int64 operand_number) {
+ CHECK_LT(operand_number, instruction->operand_count());
+ if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) {
+ return Status::OK();
+ }
+
+ HloComputation* const computation = instruction->parent();
+ std::unique_ptr<HloInstruction> zero = HloInstruction::CreateConstant(
+ MakeUnique<Literal>(Literal::Zero(instruction->shape().element_type())));
+
+ if (computation->IsFusionComputation()) {
+ HloInstruction* const fusion_instruction = computation->FusionInstruction();
+ if (fusion_instruction->IsFused()) {
+ return Unimplemented(
+ "Unable to replace fused parameter of fusion instruction");
+ }
+ TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith(
+ instruction->operand(operand_number)->parameter_number(),
+ fusion_instruction->parent()->AddInstruction(std::move(zero))));
+ } else {
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(
+ operand_number, computation->AddInstruction(std::move(zero))));
+ }
+ return Status::OK();
+}
+
} // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
@@ -117,4 +156,32 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
return std::move(arguments);
}
+Status ReplaceInitsWithConstants(HloModule* const module) {
+ for (HloComputation* const computation : module->computations()) {
+ for (HloInstruction* const instruction : computation->instructions()) {
+ const HloOpcode opcode = instruction->opcode();
+ if ((opcode == HloOpcode::kReduce ||
+ opcode == HloOpcode::kReduceWindow) &&
+ LooksLikeSum(*instruction->to_apply()->root_instruction())) {
+ TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1));
+ } else if (opcode == HloOpcode::kSelectAndScatter &&
+ LooksLikeSum(*instruction->scatter()->root_instruction())) {
+ TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyHloModule(const perftools::gputools::Platform& platform,
+ HloModule* const module) {
+ return HloVerifier(
+ std::bind(
+ &TransferManager::GetByteSizeRequirement,
+ TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(),
+ std::placeholders::_1))
+ .Run(module)
+ .status();
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 12d5255fce..9aca162a18 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/platform.h"
namespace xla {
@@ -62,6 +63,16 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
const HloModule& module);
+// Reductions using Adds, ReduceWindow, and SelectAndScatter require their
+// init_value to be replaced with the constant 0.0f when testing, otherwise we
+// may generate a bad init_value when looking at the op in isolation.
+Status ReplaceInitsWithConstants(HloModule* const module);
+
+// Check that a given module satisfies various constraints before trying to
+// execute it.
+Status VerifyHloModule(const perftools::gputools::Platform& platform,
+ HloModule* const module);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_