aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/llvm_compiler_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/llvm_compiler_test.cc143
1 files changed, 37 insertions, 106 deletions
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 70d8b764a3..458258e7ee 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -14,118 +14,49 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
-#include "tensorflow/compiler/xla/service/backend.h"
-#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/platform_util.h"
-#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
-#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
namespace {
-class LLVMCompilerTest : public ::testing::Test {
- public:
- void SetUp() override {
- Platform *platform = FindPlatform();
- ASSERT_NE(platform, nullptr);
-
- BackendOptions backend_options;
- backend_options.set_platform(platform);
- StatusOr<std::unique_ptr<Backend>> backend_or_status =
- Backend::CreateBackend(backend_options);
- ASSERT_IS_OK(backend_or_status.status());
- backend_ = backend_or_status.ConsumeValueOrDie();
- }
-
- ~LLVMCompilerTest() override {}
-
- protected:
- using Platform = ::perftools::gputools::Platform;
-
- explicit LLVMCompilerTest(string platform_name)
- : platform_name_(std::move(platform_name)) {}
-
- void TestCompilerHooks(LLVMCompiler *compiler) {
- int pre_opt_hook_call_count = 0;
- int post_opt_hook_call_count = 0;
-
- auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) {
- ++pre_opt_hook_call_count;
- return Status::OK();
- };
- auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) {
- ++post_opt_hook_call_count;
- return Status::OK();
- };
-
- // Create HLO module, and run the compiler.
- auto builder = HloComputation::Builder(TestName());
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
-
- auto hlo_module = CreateNewModule();
- hlo_module->AddEntryComputation(builder.Build());
-
- compiler->SetPreOptimizationHook(pre_opt_hook);
- compiler->SetPostOptimizationHook(post_opt_hook);
-
- ASSERT_TRUE(compiler
- ->Compile(std::move(hlo_module),
- backend_->default_stream_executor())
- .ok());
-
- // Test that hooks were called.
- EXPECT_EQ(1, pre_opt_hook_call_count);
- EXPECT_EQ(1, post_opt_hook_call_count);
- }
-
- private:
- Platform *FindPlatform() {
- for (Platform *platform :
- PlatformUtil::GetSupportedPlatforms().ConsumeValueOrDie()) {
- if (platform->Name() == platform_name_) {
- return platform;
- }
- }
- return nullptr;
- }
-
- string platform_name_;
- std::unique_ptr<Backend> backend_;
-
- static string TestName() {
- return ::testing::UnitTest::GetInstance()->current_test_info()->name();
- }
-
- static std::unique_ptr<HloModule> CreateNewModule() {
- HloModuleConfig config;
- config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
- config);
- }
-};
-
-class CpuCompilerTest : public LLVMCompilerTest {
- public:
- CpuCompilerTest() : LLVMCompilerTest("Host") {}
-};
-
-class GpuCompilerTest : public LLVMCompilerTest {
- public:
- GpuCompilerTest() : LLVMCompilerTest("CUDA") {}
-};
-
-TEST_F(CpuCompilerTest, HooksTest) {
- cpu::CpuCompiler compiler;
- TestCompilerHooks(&compiler);
-}
-
-TEST_F(GpuCompilerTest, HooksTest) {
- gpu::GpuCompiler compiler;
- TestCompilerHooks(&compiler);
+class LLVMCompilerTest : public HloTestBase {};
+
+XLA_TEST_F(LLVMCompilerTest, CompilerHooks) {
+ int pre_opt_hook_call_count = 0;
+ int post_opt_hook_call_count = 0;
+
+ auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) {
+ ++pre_opt_hook_call_count;
+ return Status::OK();
+ };
+ auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) {
+ ++post_opt_hook_call_count;
+ return Status::OK();
+ };
+
+ // Create HLO module, and run the compiler.
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(builder.Build());
+
+ auto compiler = static_cast<LLVMCompiler *>(backend().compiler());
+ compiler->SetPreOptimizationHook(pre_opt_hook);
+ compiler->SetPostOptimizationHook(post_opt_hook);
+
+ ASSERT_TRUE(
+ compiler
+ ->Compile(std::move(hlo_module), backend().default_stream_executor())
+ .ok());
+
+ // Test that hooks were called.
+ EXPECT_EQ(1, pre_opt_hook_call_count);
+ EXPECT_EQ(1, post_opt_hook_call_count);
}
} // namespace