aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_experimental.cc
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-02-16 22:05:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-16 22:09:20 -0800
commitfa8c4d16288e3bee4a014b4d51d22dd361721ff4 (patch)
treef42d8b707cc5c7d9cb478fb36065d1b22fe95253 /tensorflow/c/c_api_experimental.cc
parent02bbb131b78fb0924675809ed5b549e594a51ac1 (diff)
Added an experimental C API TF_EnableXLACompilation() to enable XLA compilation.
Also ran "buildozer warn //third_party/tensorflow/c/BUILD" and removed an unused symbol. PiperOrigin-RevId: 186081948
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r--tensorflow/c/c_api_experimental.cc39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
new file mode 100644
index 0000000000..be7f85a5bb
--- /dev/null
+++ b/tensorflow/c/c_api_experimental.cc
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/c_api_experimental.h"
+
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
+ tensorflow::ConfigProto& config = options->options.config;
+ auto* optimizer_options =
+ config.mutable_graph_options()->mutable_optimizer_options();
+ if (enable) {
+ optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
+
+ // These XLA flags are needed to trigger XLA properly from C (more generally
+ // non-Python) clients. If this API is called again with `enable` set to
+ // false, it is safe to keep these flag values as is.
+ tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
+ tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+ flags->tf_xla_cpu_global_jit = true;
+ flags->tf_xla_min_cluster_size = 1;
+ } else {
+ optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
+ }
+}