aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/llvm_compiler.h6
-rw-r--r--tensorflow/compiler/xla/tests/BUILD34
-rw-r--r--tensorflow/compiler/xla/tests/codegen_test_base.cc93
-rw-r--r--tensorflow/compiler/xla/tests/codegen_test_base.h33
-rw-r--r--tensorflow/compiler/xla/tests/filecheck.cc77
-rw-r--r--tensorflow/compiler/xla/tests/filecheck.h32
-rw-r--r--tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc82
-rw-r--r--tensorflow/compiler/xla/tests/llvm_irgen_test_base.h67
8 files changed, 314 insertions, 110 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h
index 9d9cb69148..b2e72871c1 100644
--- a/tensorflow/compiler/xla/service/llvm_compiler.h
+++ b/tensorflow/compiler/xla/service/llvm_compiler.h
@@ -42,15 +42,21 @@ class LLVMCompiler : public Compiler {
void SetPreOptimizationHook(ModuleHook hook) {
CHECK(!user_pre_optimization_hook_)
<< "Pre-optimization hook is already set";
+ CHECK(hook) << "hook cannot be null";
user_pre_optimization_hook_ = hook;
}
+ void RemovePreOptimizationHook() { user_pre_optimization_hook_ = nullptr; }
+
void SetPostOptimizationHook(ModuleHook hook) {
CHECK(!user_post_optimization_hook_)
<< "Post-optimization hook is already set";
+ CHECK(hook) << "hook cannot be null";
user_post_optimization_hook_ = hook;
}
+ void RemovePostOptimizationHook() { user_post_optimization_hook_ = nullptr; }
+
protected:
ModuleHook user_pre_optimization_hook_;
ModuleHook user_post_optimization_hook_;
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 2b35cbaaa8..0a2d337752 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -179,22 +179,43 @@ cc_library(
)
cc_library(
+ name = "llvm_irgen_test_base",
+ testonly = True,
+ srcs = ["llvm_irgen_test_base.cc"],
+ hdrs = ["llvm_irgen_test_base.h"],
+ deps = [
+ ":codegen_test_base",
+ ":filecheck",
+ "//tensorflow/compiler/xla/service:llvm_compiler",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
name = "codegen_test_base",
testonly = True,
srcs = ["codegen_test_base.cc"],
hdrs = ["codegen_test_base.h"],
+ deps = [
+ ":hlo_test_base",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ ],
+)
+
+cc_library(
+ name = "filecheck",
+ testonly = True,
+ srcs = ["filecheck.cc"],
+ hdrs = ["filecheck.h"],
data = [
"@llvm//:FileCheck",
],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
- "//tensorflow/compiler/xla/service:backend",
- "//tensorflow/compiler/xla/service:compiler",
- "//tensorflow/compiler/xla/service:executable",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -1327,6 +1348,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@llvm//:core",
],
)
diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc
index 90767c4a17..a52be3ffd1 100644
--- a/tensorflow/compiler/xla/tests/codegen_test_base.cc
+++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc
@@ -15,90 +15,25 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/codegen_test_base.h"
-#include <stdlib.h>
-#include <utility>
-
-#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/service/backend.h"
-#include "tensorflow/compiler/xla/service/compiler.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/subprocess.h"
-#include "tensorflow/core/platform/test.h"
-
namespace xla {
-std::unique_ptr<HloModule> CodegenTestBase::CreateNewModuleWithEmbeddedIr(
- bool ftz) {
- HloModuleConfig config;
- auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
- debug_options.set_xla_embed_ir_in_executable(true);
- debug_options.set_xla_gpu_ftz(ftz);
- // TODO(b/38354253): Change tests to use Parameters instead of Constants.
- debug_options.add_xla_disable_hlo_passes("constant_folding");
- config.set_debug_options(debug_options);
-
- return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
- config);
-}
-
-void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
- const string& pattern) {
- std::unique_ptr<Executable> executable =
- CompileToExecutable(std::move(hlo_module));
- string ir_module_string = GetIrFromExecutable(*executable);
- RunFileCheck(ir_module_string, pattern);
-}
-
-std::unique_ptr<Executable> CodegenTestBase::CompileToExecutable(
+StatusOr<std::unique_ptr<Executable>> CodegenTestBase::CompileToExecutable(
std::unique_ptr<HloModule> hlo_module) {
- return backend_->compiler()
- ->Compile(std::move(hlo_module), backend_->default_stream_executor())
- .ConsumeValueOrDie();
+ return backend_->compiler()->Compile(std::move(hlo_module),
+ backend_->default_stream_executor());
}
-void CodegenTestBase::RunFileCheck(const string& input, const string& pattern) {
- using tensorflow::io::JoinPath;
-
- // Write input to a temporary file.
- char tempdir_template[] = "/tmp/ir_testXXXXXX";
- char* tempdir_name = mkdtemp(tempdir_template);
- CHECK_NOTNULL(tempdir_name);
- string pattern_path = JoinPath(tempdir_name, "xla_hlo_test_ir_pattern");
- TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
- pattern_path, pattern));
-
- // Invoke FileCheck to check whether input matches `pattern`.
- const char* file_check_path_suffix = "external/llvm/FileCheck";
- string file_check_path;
- if (const char* test_srcdir = getenv("TEST_SRCDIR")) {
- file_check_path = JoinPath(test_srcdir, file_check_path_suffix);
- } else {
- file_check_path = file_check_path_suffix;
- }
-
- tensorflow::SubProcess file_check_process;
- file_check_process.SetProgram(file_check_path,
- {file_check_path, pattern_path});
- file_check_process.SetChannelAction(tensorflow::CHAN_STDIN,
- tensorflow::ACTION_PIPE);
- file_check_process.SetChannelAction(tensorflow::CHAN_STDERR,
- tensorflow::ACTION_PIPE);
- CHECK(file_check_process.Start());
- string standard_error;
- int exit_status = file_check_process.Communicate(
- /*stdin_input=*/&input, /*stdout_output=*/nullptr,
- /*stderr_output=*/&standard_error);
-
- // FileCheck returns 0 when the inputs match. If matching failed, we output
- // the error message generated by FileCheck.
- SCOPED_TRACE(tensorflow::strings::StrCat("Input to FileCheck:\n", input));
- EXPECT_EQ(0, exit_status) << standard_error;
+StatusOr<std::unique_ptr<AotCompilationResult>>
+CodegenTestBase::CompileToAotCompilationResult(
+ std::unique_ptr<HloModule> hlo_module,
+ const AotCompilationOptions& options) {
+ std::vector<std::unique_ptr<HloModule>> hlo_modules;
+ hlo_modules.push_back(std::move(hlo_module));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<std::unique_ptr<AotCompilationResult>> results,
+ backend_->compiler()->CompileAheadOfTime(std::move(hlo_modules),
+ options));
+ return std::move(results.front());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.h b/tensorflow/compiler/xla/tests/codegen_test_base.h
index fa073cd91e..441fcd6890 100644
--- a/tensorflow/compiler/xla/tests/codegen_test_base.h
+++ b/tensorflow/compiler/xla/tests/codegen_test_base.h
@@ -17,42 +17,25 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_
#include <memory>
-#include <string>
+#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
-// Tests that verify IR emitted by the CPU/GPU backend is as expected.
+// Provides access to both the JIT and the AOT compiler for testing.
class CodegenTestBase : public HloTestBase {
protected:
- // Like HloTestBase::CreateNewModule, but also sets the "embed ir in
- // executable" flag to true, since this is needed for codegen tests.
- // The optional ftz flags configures whether these modules have their ftz
- // option turned on.
- std::unique_ptr<HloModule> CreateNewModuleWithEmbeddedIr(bool ftz = false);
-
- // Returns the embedded LLVM IR from the given executable. Codegen tests must
- // override this method, but execution tests do not have to because they do
- // not examine the embedded IR.
- virtual string GetIrFromExecutable(const Executable& executable) = 0;
-
- // Compiles the given HLO module to LLVM IR and verifies the IR matches the
- // given pattern. `pattern` is in the FileCheck pattern matching syntax
- // (http://llvm.org/docs/CommandGuide/FileCheck.html).
- void CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
- const string& pattern);
-
- protected:
- // Compiles hlo_module to an executable, CHECK-failing if this fails.
- std::unique_ptr<Executable> CompileToExecutable(
+ // Compiles hlo_module with the JIT compiler.
+ StatusOr<std::unique_ptr<Executable>> CompileToExecutable(
std::unique_ptr<HloModule> hlo_module);
- // Runs FileCheck with the given pattern over the given string and EXPECTs
- // that FileCheck succeeded in matching the input.
- void RunFileCheck(const string& input, const string& pattern);
+ // Compiles hlo_module with the AOT compiler.
+ StatusOr<std::unique_ptr<AotCompilationResult>> CompileToAotCompilationResult(
+ std::unique_ptr<HloModule> hlo_module,
+ const AotCompilationOptions& options);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc
new file mode 100644
index 0000000000..407b5f4ada
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/filecheck.cc
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+
+#include <cstdlib>
+
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/subprocess.h"
+
+namespace xla {
+
+StatusOr<bool> RunFileCheck(const string& input, const string& pattern) {
+ using tensorflow::io::JoinPath;
+
+ // Generate an input file for the FileCheck pattern.
+ string pattern_path;
+ auto env = tensorflow::Env::Default();
+ if (!env->LocalTempFilename(&pattern_path)) {
+ return tensorflow::errors::Internal("couldn't get a pattern file name");
+ }
+ TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, pattern_path, pattern));
+
+ // Invoke FileCheck to check whether input matches `pattern`.
+ const char* file_check_path_suffix = "external/llvm/FileCheck";
+ string file_check_path;
+ if (const char* test_srcdir = getenv("TEST_SRCDIR")) {
+ file_check_path = JoinPath(test_srcdir, file_check_path_suffix);
+ } else {
+ file_check_path = file_check_path_suffix;
+ }
+
+ tensorflow::SubProcess file_check_process;
+ file_check_process.SetProgram(file_check_path,
+ {file_check_path, pattern_path});
+ file_check_process.SetChannelAction(tensorflow::CHAN_STDIN,
+ tensorflow::ACTION_PIPE);
+ file_check_process.SetChannelAction(tensorflow::CHAN_STDERR,
+ tensorflow::ACTION_PIPE);
+ if (!file_check_process.Start()) {
+ return tensorflow::errors::Internal("couldn't start FileCheck");
+ }
+
+ string standard_error;
+ int exit_status = file_check_process.Communicate(
+ /*stdin_input=*/&input, /*stdout_output=*/nullptr,
+ /*stderr_output=*/&standard_error);
+
+ // FileCheck returns 0 when the inputs match. If matching failed, log
+ // the error message generated by FileCheck and the inputs.
+ bool succeeded = (exit_status == 0);
+ if (!succeeded) {
+ VLOG(1) << "FileCheck error: " << standard_error;
+ VLOG(1) << "FileCheck input was:";
+ XLA_VLOG_LINES(1, input);
+ VLOG(1) << "FileCheck pattern was:";
+ XLA_VLOG_LINES(1, pattern);
+ }
+ return succeeded;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/filecheck.h b/tensorflow/compiler/xla/tests/filecheck.h
new file mode 100644
index 0000000000..599bf57ad3
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/filecheck.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// Runs FileCheck with the given pattern over given input string. Provided that
+// FileCheck can execute, returns true if and only if FileCheck succeeded in
+// matching the input.
+StatusOr<bool> RunFileCheck(const string& input, const string& pattern);
+
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
new file mode 100644
index 0000000000..98dd9613a7
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h"
+
+#include <functional>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+void LLVMIRGenTestBase::SetIrHook(bool match_optimized_ir) {
+ auto llvm_compiler = GetLLVMCompiler();
+ using std::placeholders::_1;
+
+ // Add the IR inspection hook to the LLVM compiler.
+ if (match_optimized_ir) {
+ llvm_compiler->SetPostOptimizationHook(
+ std::bind(&LLVMIRGenTestBase::IrHook, this, _1));
+ } else {
+ llvm_compiler->SetPreOptimizationHook(
+ std::bind(&LLVMIRGenTestBase::IrHook, this, _1));
+ }
+}
+
+void LLVMIRGenTestBase::ResetIrHook() {
+ auto llvm_compiler = GetLLVMCompiler();
+
+ llvm_compiler->RemovePreOptimizationHook();
+ llvm_compiler->RemovePostOptimizationHook();
+}
+
+void LLVMIRGenTestBase::CompileAndVerifyIr(
+ std::unique_ptr<HloModule> hlo_module, const string& pattern,
+ bool match_optimized_ir) {
+ SetIrHook(match_optimized_ir);
+ ASSERT_TRUE(CompileToExecutable(std::move(hlo_module)).ok());
+ ResetIrHook();
+
+ StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
+ ASSERT_TRUE(filecheck_result.ok());
+ EXPECT_TRUE(filecheck_result.ValueOrDie());
+}
+
+void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr(
+ std::unique_ptr<HloModule> hlo_module, const AotCompilationOptions& options,
+ const string& pattern, bool match_optimized_ir) {
+ SetIrHook(match_optimized_ir);
+ ASSERT_TRUE(
+ CompileToAotCompilationResult(std::move(hlo_module), options).ok());
+ ResetIrHook();
+
+ StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
+ ASSERT_TRUE(filecheck_result.ok());
+ EXPECT_TRUE(filecheck_result.ValueOrDie());
+}
+
+LLVMCompiler* LLVMIRGenTestBase::GetLLVMCompiler() const {
+ return static_cast<LLVMCompiler*>(backend_->compiler());
+}
+
+Status LLVMIRGenTestBase::IrHook(const llvm::Module& module) {
+ ir_ = llvm_ir::DumpModuleToString(module);
+ return Status::OK();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h
new file mode 100644
index 0000000000..f0a0df76ac
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h
@@ -0,0 +1,67 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LLVM_IRGEN_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_LLVM_IRGEN_TEST_BASE_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "tensorflow/compiler/xla/tests/codegen_test_base.h"
+
+namespace xla {
+
+// Tests that verify IR emitted by the CPU/GPU backend is as expected.
+class LLVMIRGenTestBase : public CodegenTestBase {
+ protected:
+ // Compiles the given HLO module to LLVM IR and verifies the IR matches the
+ // given pattern. `pattern` is in the FileCheck pattern matching syntax
+ // (http://llvm.org/docs/CommandGuide/FileCheck.html).
+ //
+ // This function invokes the JIT compiler.
+ //
+ // If `match_optimized_ir` is true, match the version of the IR after internal
+ // optimizations are applied; otherwise, the IR before optimizations is
+ // matched.
+ void CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
+ const string& pattern, bool match_optimized_ir);
+
+ // Compiles the given HLO module to LLVM IR and verifies the IR matches the
+ // given pattern. `pattern` is in the FileCheck pattern matching syntax
+ // (http://llvm.org/docs/CommandGuide/FileCheck.html).
+ //
+ // This function invokes the AOT compiler, with options in `options`.
+ //
+ // If `match_optimized_ir` is true, match the version of the IR after internal
+ // optimizations are applied; otherwise, the IR before optimizations is
+ // matched.
+ void CompileAheadOfTimeAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
+ const AotCompilationOptions& options,
+ const string& pattern,
+ bool match_optimized_ir);
+
+ private:
+ LLVMCompiler* GetLLVMCompiler() const;
+
+ void SetIrHook(bool match_optimized_ir);
+ void ResetIrHook();
+
+ string ir_;
+ Status IrHook(const llvm::Module& module);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_LLVM_IRGEN_TEST_BASE_H_