Diffstat (limited to 'tensorflow/c/c_api_experimental.cc'):

 tensorflow/c/c_api_experimental.cc | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 95b04f9058..170046c802 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -57,6 +57,33 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
   }
 }

+TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
+                           unsigned char gpu_memory_allow_growth) {
+  tensorflow::ConfigProto config;
+  auto* optimizer_options =
+      config.mutable_graph_options()->mutable_optimizer_options();
+  if (enable_xla_compilation) {
+    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
+
+    // These XLA flags are needed to trigger XLA properly from C (more
+    // generally, non-Python) clients. If this API is called again with
+    // `enable_xla_compilation` set to false, the flags may safely stay set.
+    tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
+        tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+    flags->tf_xla_cpu_global_jit = true;
+    flags->tf_xla_min_cluster_size = 1;
+  } else {
+    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
+  }
+
+  auto* gpu_options = config.mutable_gpu_options();
+  gpu_options->set_allow_growth(gpu_memory_allow_growth);
+
+  TF_Buffer* ret = TF_NewBuffer();
+  TF_CHECK_OK(MessageToBuffer(config, ret));
+  return ret;
+}
+
 const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
   tensorflow::mutex_lock c(graph->mu);
   const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
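
Usage note (not part of the commit): below is a minimal sketch of how a C client
might apply the serialized ConfigProto that TF_CreateConfig returns.
TF_NewSessionOptions, TF_SetConfig, and the TF_Status helpers are part of the
stable C API in tensorflow/c/c_api.h; the main() scaffolding and error printing
are illustrative assumptions only.

#include <stdio.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

int main(void) {
  /* Build a serialized ConfigProto with XLA JIT compilation and GPU
     memory growth both enabled. */
  TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/1,
                                      /*gpu_memory_allow_growth=*/1);

  /* Apply the proto to a fresh TF_SessionOptions, which TF_NewSession
     would then consume. */
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_SetConfig(opts, config->data, config->length, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "TF_SetConfig failed: %s\n", TF_Message(status));
  }

  /* ... create a graph and session with `opts` here ... */

  TF_DeleteSessionOptions(opts);
  TF_DeleteStatus(status);
  TF_DeleteBuffer(config); /* frees the serialized proto */
  return 0;
}

Unlike TF_EnableXLACompilation, which mutates a TF_SessionOptions in place,
TF_CreateConfig hands back a serialized proto in a TF_Buffer, so the caller must
free it with TF_DeleteBuffer once the config has been applied.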