diff options
author | Mingsheng Hong <hongm@google.com> | 2018-02-09 14:27:03 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-09 14:35:39 -0800 |
commit | 3590c452ea8485d063874138eec92411297a9abb (patch) | |
tree | fae3755a1858a7c31011e3b866108fb96c7ff779 /tensorflow/c/BUILD | |
parent | ed5f003cc2c542c3c545369f71d4b57429da33fc (diff) |
Enabled XLA for TF C API.
Summary of changes:
1. Set MarkForCompilationPassFlags::tf_xla_cpu_global_jit default to true in
C_API unit test env when XLA-execute is intended. Together with setting session
config config.graph_options.optimizer_options.global_jit_level to > 0, this
turns on XLA for the entire graph (eligible nodes only, with _Arg and _RetVal
nodes excluded).
We decided against defaulting MarkForCompilationPassFlags::tf_xla_cpu_global_jit
to true, due to performance concerns with the single-threaded nature of the XLA
CPU backend (see
https://www.tensorflow.org/performance/xla/jit#turning_on_jit_compilation).
2. In FindCompilationCandidates() during MarkForCompilationPass, skip compiling
any '_Arg'-typed nodes. This is necessary to avoid hitting a "Invalid argument
number" error during MarkForCompilationPass.
3. Extended C API based build rules to link in XLA libraries, and added unit
test "CAPI.Session_Min_XLA_CPU".
Also added some misc improvements and debugging aids.
PiperOrigin-RevId: 185193314
Diffstat (limited to 'tensorflow/c/BUILD')
-rw-r--r-- | tensorflow/c/BUILD | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 314cbc657c..25a994be3e 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -91,6 +91,12 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], + }) + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + ], + "//conditions:default": [], }), ) @@ -141,8 +147,10 @@ tf_cuda_library( ], deps = [ ":c_api", + "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", "//tensorflow/core:test", ], ) |