author    Eli Bendersky <eliben@google.com>  2017-06-23 08:40:35 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-06-23 08:44:15 -0700
commit    cfe28e09f36e54d55d08e666392d19c5c46c67db (patch)
tree      c80c5cad299a228bf5c1041d9f3ad77c56e815ea
parent    363854b00b6c29d9c3ff3a328f03cb4529ee7594 (diff)
[XLA] Remove unused xla_cpu flag and move another to DebugOptions.
PiperOrigin-RevId: 159952124
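This change deletes the unused --xla_cpu_use_eigen flag outright and moves
--xla_cpu_multi_thread_eigen into the DebugOptions proto, so the CPU backend
now reads it from the module configuration rather than from the process-wide
legacy_flags singleton. A minimal sketch of the old and new lookup paths
(assuming an HloModuleConfig is in scope, as in the emitters below):

    // New path introduced by this change:
    bool multi_threaded_eigen =
        hlo_module_config.debug_options().xla_cpu_multi_thread_eigen();

    // Old path, removed by this change:
    //   legacy_flags::CpuRuntimeFlags* flags =
    //       legacy_flags::GetCpuRuntimeFlags();
    //   bool multi_threaded_eigen = flags->xla_cpu_multi_thread_eigen;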
-rw-r--r--  tensorflow/compiler/aot/BUILD  1
-rw-r--r--  tensorflow/compiler/aot/tfcompile_main.cc  2
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/BUILD  12
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc  71
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h  51
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc  9
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD  5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/compiler_functor.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc  21
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.h  8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc  11
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc  11
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD  12
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/xla.proto  8
16 files changed, 40 insertions(+), 198 deletions(-)
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 77532a32fa..7dfed901fc 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -128,7 +128,6 @@ cc_library(
":tfcompile_lib",
":tfcompile_proto",
"//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags",
"//tensorflow/compiler/xla/legacy_flags:service_flags",
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 4c3167b502..9ec205a126 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
@@ -133,7 +132,6 @@ int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
xla::legacy_flags::AppendBufferAssignmentFlags(&flag_list);
- xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list);
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::legacy_flags::AppendServiceFlags(&flag_list);
diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD
index a54628060f..33dc94fe00 100644
--- a/tensorflow/compiler/xla/legacy_flags/BUILD
+++ b/tensorflow/compiler/xla/legacy_flags/BUILD
@@ -80,18 +80,6 @@ cc_library(
)
cc_library(
- name = "cpu_runtime_flags",
- srcs = ["cpu_runtime_flags.cc"],
- hdrs = ["cpu_runtime_flags.h"],
- deps =
- [
- ":parse_flags_from_env",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "stream_assignment_flags",
srcs = ["stream_assignment_flags.cc"],
hdrs = ["stream_assignment_flags.h"],
diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc
deleted file mode 100644
index d7817c5d54..0000000000
--- a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for XLA's cpu_runtime module.
-
-#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
-#include <vector>
-
-#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace xla {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static CpuRuntimeFlags* flags;
-static std::vector<tensorflow::Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new CpuRuntimeFlags;
- flags->xla_cpu_use_eigen = true;
- flags->xla_cpu_multi_thread_eigen = true;
- flag_list = new std::vector<tensorflow::Flag>({
- tensorflow::Flag(
- "xla_cpu_use_eigen", &flags->xla_cpu_use_eigen,
- "Use Eigen for matrix multiply on the CPU platform. This "
- "is a useful hack for performance comparisons against "
- "XLA's implementation."),
- tensorflow::Flag(
- "xla_cpu_multi_thread_eigen", &flags->xla_cpu_multi_thread_eigen,
- "When generating calls to Eigen for matmul and conv, should "
- "single or multi-threaded eigen be used? "
- "Only used when --xla_cpu_use_eigen is true."),
- });
- ParseFlagsFromEnv(*flag_list);
-}
-
-// Append to *append_to flag definitions associated with XLA's cpu_runtime
-// module.
-void AppendCpuRuntimeFlags(std::vector<tensorflow::Flag>* append_to) {
- std::call_once(flags_init, &AllocateFlags);
- append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the CpuRuntimeFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-CpuRuntimeFlags* GetCpuRuntimeFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace xla
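For reference, the deleted file implemented the standard legacy_flags idiom:
a std::call_once-guarded allocator builds the flag struct with defaults,
registers a tensorflow::Flag per field, and parses environment overrides;
every later accessor returns the same pointer. A stripped-down, self-contained
sketch of that idiom (hypothetical MyFlags/GetMyFlags names, TensorFlow
specifics elided):

    #include <mutex>  // std::once_flag / std::call_once only

    struct MyFlags {
      bool multi_thread;
    };

    static MyFlags* flags;
    static std::once_flag flags_init;

    static void AllocateFlags() {
      flags = new MyFlags;
      flags->multi_thread = true;  // default value
      // The real code also registers a tensorflow::Flag per field here and
      // calls ParseFlagsFromEnv(*flag_list) to pick up environment overrides.
    }

    // Thread-safe lazy init: repeated calls return the same pointer.
    MyFlags* GetMyFlags() {
      std::call_once(flags_init, &AllocateFlags);
      return flags;
    }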
diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h
deleted file mode 100644
index e3ff30da36..0000000000
--- a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_
-#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_
-
-// Legacy flags for the XLA's cpu_runtime module.
-
-#include <vector>
-
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace xla {
-namespace legacy_flags {
-
-// Append to *flag_list flag definitions associated with XLA's cpu_runtime
-// module.
-void AppendCpuRuntimeFlags(std::vector<tensorflow::Flag>* flag_list);
-
-// The values of flags associated with XLA's cpu_runtime module.
-typedef struct {
- // Use Eigen for matrix multiply on the CPU platform. This is a useful hack
- // for performance comparisons against XLA's implementation.
- bool xla_cpu_use_eigen;
- // When generating calls to Eigen for matmul and conv, should single or
- // multi-threaded eigen be used? Only used when --xla_cpu_use_eigen is true.
- bool xla_cpu_multi_thread_eigen;
-} CpuRuntimeFlags;
-
-// Return a pointer to the CpuRuntimeFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-CpuRuntimeFlags* GetCpuRuntimeFlags();
-
-} // namespace legacy_flags
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index bd026a5da1..01e9d4010a 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -32,6 +32,8 @@ struct DebugOptionsFlags {
string xla_dump_ir_to;
string xla_dump_debug_json_to;
+ bool xla_cpu_multi_thread_eigen;
+
string xla_gpu_cuda_data_dir;
bool xla_gpu_ftz;
@@ -55,6 +57,7 @@ void AllocateFlags() {
flag_values->xla_embed_ir_in_executable = false;
flag_values->xla_dump_ir_to = "";
flag_values->xla_dump_debug_json_to = "";
+ flag_values->xla_cpu_multi_thread_eigen = true;
flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib";
flag_values->xla_gpu_ftz = false;
flag_values->xla_backend_extra_options = "";
@@ -82,6 +85,10 @@ void AllocateFlags() {
"Embed the compiler IR as a string in the executable."),
tensorflow::Flag("xla_dump_ir_to", &flag_values->xla_dump_ir_to,
"Dump the compiler IR into this file/path."),
+ tensorflow::Flag("xla_cpu_multi_thread_eigen",
+ &flag_values->xla_cpu_multi_thread_eigen,
+ "When generating calls to Eigen in the CPU backend, "
+ "use multi-threaded Eigen mode."),
tensorflow::Flag("xla_gpu_cuda_data_dir",
&flag_values->xla_gpu_cuda_data_dir,
"If non-empty, speficies a local directory containing "
@@ -129,6 +136,8 @@ xla::DebugOptions GetDebugOptionsFromFlags() {
flag_values->xla_embed_ir_in_executable);
options.set_xla_dump_ir_to(flag_values->xla_dump_ir_to);
options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to);
+ options.set_xla_cpu_multi_thread_eigen(
+ flag_values->xla_cpu_multi_thread_eigen);
options.set_xla_gpu_cuda_data_dir(flag_values->xla_gpu_cuda_data_dir);
options.set_xla_gpu_ftz(flag_values->xla_gpu_ftz);
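With the flag registered here, it flows through the usual DebugOptions
plumbing: AppendDebugOptionsFlags() exposes --xla_cpu_multi_thread_eigen on
the command line, and GetDebugOptionsFromFlags() copies the parsed value into
the proto. A usage sketch (assuming flags have already been parsed via
tensorflow::Flags::Parse, as in tfcompile_main.cc above):

    // After parsing, the flag value rides along in the DebugOptions proto:
    xla::DebugOptions options =
        xla::legacy_flags::GetDebugOptionsFromFlags();
    bool mt = options.xla_cpu_multi_thread_eigen();
    // mt is false when run with --xla_cpu_multi_thread_eigen=false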
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 6fa3acd85d..405057f918 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -197,7 +197,6 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
@@ -228,7 +227,6 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
@@ -289,7 +287,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@llvm//:analysis",
@@ -484,7 +481,6 @@ cc_library(
":cpu_runtime",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:window_util",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/service:hlo",
],
)
@@ -511,7 +507,6 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index d78a54427a..1f78039e46 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -35,7 +35,6 @@ limitations under the License.
#include "external/llvm/include/llvm/Transforms/IPO.h"
#include "external/llvm/include/llvm/Transforms/IPO/AlwaysInliner.h"
#include "external/llvm/include/llvm/Transforms/IPO/PassManagerBuilder.h"
-#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
@@ -145,12 +144,7 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
{"llvm.tanh.f32", runtime::kTanhV8F32, 8},
};
- // Our vectorized library calls are currently implemented by calling into Eigen.
- // As such, only emit calls to these routines if --xla_cpu_use_eigen is
- // enabled.
- legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
- if (flags->xla_cpu_use_eigen &&
- (arch == llvm::Triple::x86 || llvm::Triple::x86_64)) {
+ if (arch == llvm::Triple::x86 || arch == llvm::Triple::x86_64) {
llvm::SmallVector<llvm::StringRef, 32> features;
feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
if (std::find(features.begin(), features.end(), "+sse4.1") !=
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index cdf43587b6..069979c661 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
-#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -30,11 +29,6 @@ namespace xla {
namespace cpu {
StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
- legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
- if (!flags->xla_cpu_use_eigen) {
- return false;
- }
-
bool changed = false;
for (HloInstruction* hlo :
module->entry_computation()->MakeInstructionPostOrder()) {
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 420f9cebc5..c21b8a9add 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -22,9 +22,9 @@ limitations under the License.
#include "external/llvm/include/llvm/IR/Instructions.h"
#include "external/llvm/include/llvm/IR/Module.h"
#include "external/llvm/include/llvm/IR/Value.h"
-#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -44,7 +44,8 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
const llvm_ir::IrArray& lhs_array,
const llvm_ir::IrArray& rhs_array,
llvm::Value* executable_run_options_value,
- llvm::IRBuilder<>* ir_builder)
+ llvm::IRBuilder<>* ir_builder,
+ const HloModuleConfig& hlo_module_config)
: dot_(dot),
transpose_lhs_(transpose_lhs),
transpose_rhs_(transpose_rhs),
@@ -52,18 +53,20 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
lhs_array_(lhs_array),
rhs_array_(rhs_array),
executable_run_options_value_(executable_run_options_value),
- ir_builder_(ir_builder) {}
+ ir_builder_(ir_builder),
+ hlo_module_config_(hlo_module_config) {}
/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation(
const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
const llvm_ir::IrArray& rhs_array,
- llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder) {
+ llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
+ const HloModuleConfig& hlo_module_config) {
PrimitiveType type = target_array.GetShape().element_type();
TF_RET_CHECK(F32 == type || F64 == type);
DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
lhs_array, rhs_array, executable_run_options_value,
- ir_builder);
+ ir_builder, hlo_module_config);
return dot_emitter.Emit();
}
@@ -233,20 +236,20 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
// The two transpose_... parameters are actually booleans, but we use int32
// to avoid target-dependent calling convention details.
- legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
- bool multi_threaded = flags->xla_cpu_multi_thread_eigen;
+ bool multi_threaded_eigen =
+ hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
PrimitiveType type = target_array_.GetShape().element_type();
llvm::Type* float_type;
const char* fn_name;
switch (type) {
case F32:
- fn_name = multi_threaded
+ fn_name = multi_threaded_eigen
? runtime::kEigenMatmulF32SymbolName
: runtime::kEigenSingleThreadedMatmulF32SymbolName;
float_type = ir_builder_->getFloatTy();
break;
case F64:
- fn_name = multi_threaded
+ fn_name = multi_threaded_eigen
? runtime::kEigenMatmulF64SymbolName
: runtime::kEigenSingleThreadedMatmulF64SymbolName;
float_type = ir_builder_->getDoubleTy();
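The emitter change is a plain dependency-injection move: instead of consulting
a process-global flag inside EmitCallToRuntime(), the HloModuleConfig is
passed by const reference at construction and stored as a reference member, so
the threading decision becomes per-module. In miniature (hypothetical
Emitter/Config names, for illustration only):

    struct Config {
      bool multi_thread_eigen;
    };

    class Emitter {
     public:
      explicit Emitter(const Config& config) : config_(config) {}

      const char* MatmulSymbolName() const {
        // Was: GetGlobalFlags()->multi_thread_eigen (pattern removed here).
        return config_.multi_thread_eigen
                   ? "eigen_matmul_f32"
                   : "eigen_single_threaded_matmul_f32";
      }

     private:
      const Config& config_;  // caller guarantees it outlives the emitter
    };

The reference member mirrors hlo_module_config_ above; the config must outlive
the emitter, which holds for DotOpEmitter since both are scoped to a single
compilation.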
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index 44dfe5f2a9..b614716380 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "external/llvm/include/llvm/IR/IRBuilder.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/types.h"
@@ -39,7 +40,8 @@ class DotOpEmitter {
const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
const llvm_ir::IrArray& rhs_array,
- llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder);
+ llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
+ const HloModuleConfig& hlo_module_config);
private:
DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
@@ -47,7 +49,8 @@ class DotOpEmitter {
const llvm_ir::IrArray& lhs_array,
const llvm_ir::IrArray& rhs_array,
llvm::Value* executable_run_options_value,
- llvm::IRBuilder<>* ir_builder);
+ llvm::IRBuilder<>* ir_builder,
+ const HloModuleConfig& hlo_module_config);
// Emits the IR to perform the dot operation.
tensorflow::Status Emit();
@@ -82,6 +85,7 @@ class DotOpEmitter {
const llvm_ir::IrArray& rhs_array_;
llvm::Value* executable_run_options_value_;
llvm::IRBuilder<>* ir_builder_;
+ const HloModuleConfig& hlo_module_config_;
};
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index 2d855d0eb1..859329e2c1 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -26,11 +25,6 @@ namespace cpu {
bool PotentiallyImplementedAsEigenConvolution(
const HloInstruction& convolution) {
- legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
- if (!flags->xla_cpu_use_eigen) {
- return false;
- }
-
// The following conditions are necessary (but not sufficient) for
// implementing `convolution` with Eigen convolution:
// - the input and kernel have a non-zero number of elements.
@@ -82,11 +76,6 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
} // namespace
bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
- legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
- if (!flags->xla_cpu_use_eigen) {
- return false;
- }
-
// For certain types of Dot, we can call Eigen
if (hlo.opcode() == HloOpcode::kDot) {
const Shape& lhs_shape = hlo.operand(0)->shape();
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index fee5fd8830..c81a368992 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -33,7 +33,6 @@ limitations under the License.
#include "external/llvm/include/llvm/IR/Intrinsics.h"
#include "external/llvm/include/llvm/IR/LLVMContext.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
@@ -777,7 +776,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs,
// Dot operation is complicated so we delegate to a helper class.
TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
*dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
- lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_));
+ lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_,
+ hlo_module_config_));
emitted_value_[dot] = target_address;
return Status::OK();
@@ -862,9 +862,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
int64_type, int64_type, int64_type, int64_type,
int64_type, int64_type, int64_type, int64_type},
/*isVarArg=*/false);
- legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
+ bool multi_threaded_eigen =
+ hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
const char* fn_name =
- (flags->xla_cpu_multi_thread_eigen
+ (multi_threaded_eigen
? runtime::kEigenConvF32SymbolName
: runtime::kEigenSingleThreadedConvF32SymbolName);
llvm::Function* conv_func = llvm::cast<llvm::Function>(
@@ -1525,7 +1526,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
*dot, dot->operand(0)->IsRank2Transpose(),
dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array,
- GetExecutableRunOptionsArgument(), &ir_builder_));
+ GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_));
emitted_value_[fusion] = target_address;
return Status::OK();
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index e54f02b7fb..91e353c330 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -495,7 +495,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
@@ -512,10 +511,6 @@ xla_test(
xla_test(
name = "dot_operation_runtime_test",
srcs = ["dot_operation_test.cc"],
- backend_args = {
- "cpu": ["--xla_cpu_use_eigen"],
- "cpu_parallel": ["--xla_cpu_use_eigen"],
- },
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
@@ -523,7 +518,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
@@ -541,11 +535,9 @@ xla_test(
srcs = ["dot_operation_test.cc"],
backend_args = {
"cpu": [
- "--xla_cpu_use_eigen",
"--xla_cpu_multi_thread_eigen=false",
],
"cpu_parallel": [
- "--xla_cpu_use_eigen",
"--xla_cpu_multi_thread_eigen=false",
],
},
@@ -556,7 +548,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
@@ -573,11 +564,9 @@ xla_test(
srcs = ["dot_operation_test.cc"],
backend_args = {
"cpu": [
- "--xla_cpu_use_eigen",
"--xla_default_layout=major2minor",
],
"cpu_parallel": [
- "--xla_cpu_use_eigen",
"--xla_default_layout=major2minor",
],
},
@@ -588,7 +577,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index b06b5c5f47..7abef6a27b 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -461,7 +460,6 @@ int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendLayoutUtilFlags(&flag_list);
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
- xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 76079c44cd..6dd7999b67 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -54,11 +54,15 @@ message DebugOptions {
// Dump compilation artifacts as JSON into this directory.
string xla_dump_debug_json_to = 7;
+ // When generating calls to Eigen in the CPU backend, use multi-threaded Eigen
+ // mode.
+ bool xla_cpu_multi_thread_eigen = 8;
+
// Path to directory with cuda/ptx tools and libraries.
- string xla_gpu_cuda_data_dir = 8;
+ string xla_gpu_cuda_data_dir = 9;
// Enable flush-to-zero semantics in the GPU backend.
- bool xla_gpu_ftz = 9;
+ bool xla_gpu_ftz = 10;
// Extra options to pass to the compilation backend; specific interpretation
// of these values is left to the backend.
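One caveat on the proto change: inserting xla_cpu_multi_thread_eigen at field
number 8 renumbers the existing xla_gpu_cuda_data_dir and xla_gpu_ftz fields.
Field renumbering is wire-incompatible in protobuf in general; it is tolerable
here only because DebugOptions is rebuilt from flags in each process rather
than persisted or exchanged across versions. Setting the option
programmatically looks like this (sketch; attaching it to a module
configuration is assumed to go through the usual HloModuleConfig setter):

    xla::DebugOptions debug_options;
    debug_options.set_xla_cpu_multi_thread_eigen(false);  // single-threaded Eigen
    // e.g. hlo_module_config.set_debug_options(debug_options);  // assumed setter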