diff options
author | Anna R <annarev@google.com> | 2017-12-04 12:31:03 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-04 12:37:12 -0800 |
commit | 8f1e63d5629bda4f6c91fdec7a3b8418ed96786e (patch) | |
tree | 0f56fddabcd4e5cf91090acf5152cc54b3651e6a /tensorflow/tensorflow.bzl | |
parent | a1c29139ccf441ad4de97c4e7fe2729e6130fcb8 (diff) |
Actually use ApiDef when generating Python API.
PiperOrigin-RevId: 177851421
Diffstat (limited to 'tensorflow/tensorflow.bzl')
-rw-r--r-- | tensorflow/tensorflow.bzl | 27 |
1 files changed, 23 insertions, 4 deletions
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d194b37700..0db915f1b9 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -334,6 +334,7 @@ def tf_gen_op_wrapper_cc(name, " $$(dirname $$(echo $(locations " + api_def_src + ") | cut -d\" \" -f1))") api_def_args_str = ",".join(api_def_args) + native.genrule( name=name + "_genrule", outs=[ @@ -469,7 +470,8 @@ def tf_gen_op_wrapper_py(name, hidden_file=None, generated_target_name=None, op_whitelist=[], - cc_linkopts=[]): + cc_linkopts=[], + api_def_srcs=[]): if (hidden or hidden_file) and op_whitelist: fail('Cannot pass specify both hidden and op_whitelist.') @@ -502,22 +504,39 @@ def tf_gen_op_wrapper_py(name, op_list_arg = "''" op_list_is_whitelist = False + # Prepare ApiDef directories to pass to the genrule. + if not api_def_srcs: + api_def_args_str = "," + else: + api_def_args = [] + for api_def_src in api_def_srcs: + # Add directory of the first ApiDef source to args. + # We are assuming all ApiDefs in a single api_def_src are in the + # same directory. + api_def_args.append( + "$$(dirname $$(echo $(locations " + api_def_src + + ") | cut -d\" \" -f1))") + api_def_args_str = ",".join(api_def_args) + if hidden_file: # `hidden_file` is file containing a list of op names to be hidden in the # generated module. native.genrule( name=name + "_pygenrule", outs=[out], - srcs=[hidden_file], + srcs=api_def_srcs + [hidden_file], tools=[tool_name] + tf_binary_additional_srcs(), - cmd=("$(location " + tool_name + ") @$(location " + hidden_file + ") " + + cmd=("$(location " + tool_name + ") " + api_def_args_str + + " @$(location " + hidden_file + ") " + ("1" if require_shape_functions else "0") + " > $@")) else: native.genrule( name=name + "_pygenrule", outs=[out], + srcs=api_def_srcs, tools=[tool_name] + tf_binary_additional_srcs(), - cmd=("$(location " + tool_name + ") " + op_list_arg + " " + + cmd=("$(location " + tool_name + ") " + api_def_args_str + " " + + op_list_arg + " " + ("1" if require_shape_functions else "0") + " " + ("1" if op_list_is_whitelist else "0") + " > $@")) |