aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/aot/tfcompile.bzl
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/aot/tfcompile.bzl')
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl15
1 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index ee291c12d0..1e22b760b8 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -119,7 +119,7 @@ def tf_library(name, graph, config,
out_nodes_file,
] + freeze_saver_srcs,
outs=[freeze_file],
- cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
+ cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" +
freeze_args),
tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"],
tags=tags,
@@ -130,6 +130,10 @@ def tf_library(name, graph, config,
header_file = name + ".h"
object_file = name + ".o"
ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_")
+ if type(tfcompile_flags) == type(""):
+ flags = tfcompile_flags
+ else:
+ flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
native.genrule(
name=("gen_" + name),
srcs=[
@@ -148,7 +152,7 @@ def tf_library(name, graph, config,
" --target_triple=" + target_llvm_triple() +
" --out_header=$(@D)/" + header_file +
" --out_object=$(@D)/" + object_file +
- " " + (tfcompile_flags or "")),
+ flags),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,
@@ -185,7 +189,7 @@ def tf_library(name, graph, config,
" --cpp_class=" + cpp_class +
" --target_triple=" + target_llvm_triple() +
" --out_session_module=$(@D)/" + session_module_pb +
- " " + (tfcompile_flags or "")),
+ flags),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,
@@ -195,8 +199,7 @@ def tf_library(name, graph, config,
# The cc_library rule packaging up the header and object file, and needed
# kernel implementations.
- need_xla_data_proto = (tfcompile_flags and
- tfcompile_flags.find("--gen_program_shape") != -1)
+ need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
native.cc_library(
name=name,
srcs=[object_file],
@@ -253,7 +256,7 @@ def tf_library(name, graph, config,
],
outs=[test_file],
cmd=("sed " + sed_replace +
- " $(location //tensorflow/compiler/aot:test.cc) " +
+ " $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " +
"> $(OUTS)"),
tags=tags,
)