aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/BUILD
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-02-09 14:27:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-09 14:35:39 -0800
commit3590c452ea8485d063874138eec92411297a9abb (patch)
treefae3755a1858a7c31011e3b866108fb96c7ff779 /tensorflow/c/BUILD
parented5f003cc2c542c3c545369f71d4b57429da33fc (diff)
Enabled XLA for TF C API.
Summary of changes: 1. Set MarkForCompilationPassFlags::tf_xla_cpu_global_jit default to true in C_API unit test env when XLA-execute is intended. Together with setting session config config.graph_options.optimizer_options.global_jit_level to > 0, this turns on XLA for the entire graph (eligible nodes only, with _Arg and _RetVal nodes excluded). We decided against defaulting MarkForCompilationPassFlags::tf_xla_cpu_global_jit to true, due to performance concerns with the single-threaded nature of the XLA CPU backend (see https://www.tensorflow.org/performance/xla/jit#turning_on_jit_compilation). 2. In FindCompilationCandidates() during MarkForCompilationPass, skip compiling any '_Arg'-typed nodes. This is necessary to avoid hitting a "Invalid argument number" error during MarkForCompilationPass. 3. Extended C API based build rules to link in XLA libraries, and added unit test "CAPI.Session_Min_XLA_CPU". Also added some misc improvements and debugging aids. PiperOrigin-RevId: 185193314
Diffstat (limited to 'tensorflow/c/BUILD')
-rw-r--r--tensorflow/c/BUILD8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 314cbc657c..25a994be3e 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -91,6 +91,12 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
+ }) + select({
+ "//tensorflow:with_xla_support": [
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/jit",
+ ],
+ "//conditions:default": [],
}),
)
@@ -141,8 +147,10 @@ tf_cuda_library(
],
deps = [
":c_api",
+ "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
"//tensorflow/core:test",
],
)