aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tensorflow.bzl
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2017-12-04 12:31:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 12:37:12 -0800
commit8f1e63d5629bda4f6c91fdec7a3b8418ed96786e (patch)
tree0f56fddabcd4e5cf91090acf5152cc54b3651e6a /tensorflow/tensorflow.bzl
parenta1c29139ccf441ad4de97c4e7fe2729e6130fcb8 (diff)
Actually use ApiDef when generating Python API.
PiperOrigin-RevId: 177851421
Diffstat (limited to 'tensorflow/tensorflow.bzl')
-rw-r--r--tensorflow/tensorflow.bzl27
1 files changed, 23 insertions, 4 deletions
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index d194b37700..0db915f1b9 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -334,6 +334,7 @@ def tf_gen_op_wrapper_cc(name,
" $$(dirname $$(echo $(locations " + api_def_src +
") | cut -d\" \" -f1))")
api_def_args_str = ",".join(api_def_args)
+
native.genrule(
name=name + "_genrule",
outs=[
@@ -469,7 +470,8 @@ def tf_gen_op_wrapper_py(name,
hidden_file=None,
generated_target_name=None,
op_whitelist=[],
- cc_linkopts=[]):
+ cc_linkopts=[],
+ api_def_srcs=[]):
if (hidden or hidden_file) and op_whitelist:
fail('Cannot pass specify both hidden and op_whitelist.')
@@ -502,22 +504,39 @@ def tf_gen_op_wrapper_py(name,
op_list_arg = "''"
op_list_is_whitelist = False
+ # Prepare ApiDef directories to pass to the genrule.
+ if not api_def_srcs:
+ api_def_args_str = ","
+ else:
+ api_def_args = []
+ for api_def_src in api_def_srcs:
+ # Add directory of the first ApiDef source to args.
+ # We are assuming all ApiDefs in a single api_def_src are in the
+ # same directory.
+ api_def_args.append(
+ "$$(dirname $$(echo $(locations " + api_def_src +
+ ") | cut -d\" \" -f1))")
+ api_def_args_str = ",".join(api_def_args)
+
if hidden_file:
# `hidden_file` is file containing a list of op names to be hidden in the
# generated module.
native.genrule(
name=name + "_pygenrule",
outs=[out],
- srcs=[hidden_file],
+ srcs=api_def_srcs + [hidden_file],
tools=[tool_name] + tf_binary_additional_srcs(),
- cmd=("$(location " + tool_name + ") @$(location " + hidden_file + ") " +
+ cmd=("$(location " + tool_name + ") " + api_def_args_str +
+ " @$(location " + hidden_file + ") " +
("1" if require_shape_functions else "0") + " > $@"))
else:
native.genrule(
name=name + "_pygenrule",
outs=[out],
+ srcs=api_def_srcs,
tools=[tool_name] + tf_binary_additional_srcs(),
- cmd=("$(location " + tool_name + ") " + op_list_arg + " " +
+ cmd=("$(location " + tool_name + ") " + api_def_args_str + " " +
+ op_list_arg + " " +
("1" if require_shape_functions else "0") + " " +
("1" if op_list_is_whitelist else "0") + " > $@"))