aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_test_util.cc
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-02-09 14:27:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-09 14:35:39 -0800
commit3590c452ea8485d063874138eec92411297a9abb (patch)
treefae3755a1858a7c31011e3b866108fb96c7ff779 /tensorflow/c/c_test_util.cc
parented5f003cc2c542c3c545369f71d4b57429da33fc (diff)
Enabled XLA for TF C API.
Summary of changes: 1. Set MarkForCompilationPassFlags::tf_xla_cpu_global_jit default to true in C_API unit test env when XLA-execute is intended. Together with setting session config config.graph_options.optimizer_options.global_jit_level to > 0, this turns on XLA for the entire graph (eligible nodes only, with _Arg and _RetVal nodes excluded). We decided against defaulting MarkForCompilationPassFlags::tf_xla_cpu_global_jit to true, due to performance concerns with the single-threaded nature of the XLA CPU backend (see https://www.tensorflow.org/performance/xla/jit#turning_on_jit_compilation). 2. In FindCompilationCandidates() during MarkForCompilationPass, skip compiling any '_Arg'-typed nodes. This is necessary to avoid hitting a "Invalid argument number" error during MarkForCompilationPass. 3. Extended C API based build rules to link in XLA libraries, and added unit test "CAPI.Session_Min_XLA_CPU". Also added some misc improvements and debugging aids. PiperOrigin-RevId: 185193314
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r--tensorflow/c/c_test_util.cc17
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 3c1d5b5bf8..2c5f08d672 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -15,11 +15,13 @@ limitations under the License.
#include "tensorflow/c/c_test_util.h"
+#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/session_options.h"
using tensorflow::GraphDef;
using tensorflow::NodeDef;
@@ -390,8 +392,21 @@ std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def) {
return names;
}
-CSession::CSession(TF_Graph* graph, TF_Status* s) {
+CSession::CSession(TF_Graph* graph, TF_Status* s, bool use_XLA) {
TF_SessionOptions* opts = TF_NewSessionOptions();
+ tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
+ tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+ flags->tf_xla_cpu_global_jit = use_XLA;
+ if (use_XLA) {
+ tensorflow::ConfigProto config;
+ config.mutable_graph_options()
+ ->mutable_optimizer_options()
+ ->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
+ std::string contents;
+ contents.resize(config.ByteSizeLong());
+ config.SerializeToArray(&contents[0], contents.size());
+ TF_SetConfig(opts, contents.data(), contents.size(), s);
+ }
session_ = TF_NewSession(graph, opts, s);
TF_DeleteSessionOptions(opts);
}