aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2018-09-07 12:20:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 12:24:57 -0700
commita65d3dd42122d3a58985d56118d58c5b4224f38f (patch)
tree9c7bf3a1a3dfa68d73af747b281a5ae327868332 /tensorflow/python/tools
parent0a375d94b6fd4c3cd0bd5d0a301b3acc65b96d78 (diff)
Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in, then bazel will build TensorFlow API version 2.0. In all other cases, it would build API version 1.*.
PiperOrigin-RevId: 212016666
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r--tensorflow/python/tools/api/generator/api_gen.bzl34
1 files changed, 19 insertions, 15 deletions
diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl
index 2810d83bd2..271cf2afaf 100644
--- a/tensorflow/python/tools/api/generator/api_gen.bzl
+++ b/tensorflow/python/tools/api/generator/api_gen.bzl
@@ -12,10 +12,15 @@ ESTIMATOR_API_INIT_FILES = [
# END GENERATED ESTIMATOR FILES
]
+def get_compat_files(
+ file_paths,
+ compat_api_version):
+ """Prepends compat/v<compat_api_version> to file_paths."""
+ return ["compat/v%d/%s" % (compat_api_version, f) for f in file_paths]
+
def gen_api_init_files(
name,
output_files = TENSORFLOW_API_INIT_FILES,
- compat_output_files = {},
root_init_template = None,
srcs = [],
api_name = "tensorflow",
@@ -23,7 +28,8 @@ def gen_api_init_files(
compat_api_versions = [],
package = "tensorflow.python",
package_dep = "//tensorflow/python:no_contrib",
- output_package = "tensorflow"):
+ output_package = "tensorflow",
+ output_dir = ""):
"""Creates API directory structure and __init__.py files.
Creates a genrule that generates a directory structure with __init__.py
@@ -37,8 +43,6 @@ def gen_api_init_files(
tf_export. For e.g. if an op is decorated with
@tf_export('module1.module2', 'module3'). Then, output_files should
include module1/module2/__init__.py and module3/__init__.py.
- compat_output_files: Dictionary mapping each compat_api_version to the
- set of __init__.py file paths that should be generated for that version.
root_init_template: Python init file that should be used as template for
root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
template will be replaced with root imports collected by this genrule.
@@ -53,14 +57,16 @@ def gen_api_init_files(
process
package_dep: Python library target containing your package.
output_package: Package where generated API will be added to.
+ output_dir: Subdirectory to output API to.
+ If non-empty, must end with '/'.
"""
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
- api_gen_binary_target = "create_" + package + "_api"
+ api_gen_binary_target = ("create_" + package + "_api_%d") % api_version
native.py_binary(
- name = "create_" + package + "_api",
+ name = api_gen_binary_target,
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
main = "//tensorflow/python/tools/api/generator:create_python_api.py",
srcs_version = "PY2AND3",
@@ -72,14 +78,9 @@ def gen_api_init_files(
],
)
- all_output_files = list(output_files)
+ all_output_files = ["%s%s" % (output_dir, f) for f in output_files]
compat_api_version_flags = ""
for compat_api_version in compat_api_versions:
- compat_files = compat_output_files.get(compat_api_version, [])
- all_output_files.extend([
- "compat/v%d/%s" % (compat_api_version, f)
- for f in compat_files
- ])
compat_api_version_flags += " --compat_apiversion=%d" % compat_api_version
native.genrule(
@@ -87,12 +88,15 @@ def gen_api_init_files(
outs = all_output_files,
cmd = (
"$(location :" + api_gen_binary_target + ") " +
- root_init_template_flag + " --apidir=$(@D) --apiname=" +
- api_name + " --apiversion=" + str(api_version) +
+ root_init_template_flag + " --apidir=$(@D)" + output_dir +
+ " --apiname=" + api_name + " --apiversion=" + str(api_version) +
compat_api_version_flags + " --package=" + package +
" --output_package=" + output_package + " $(OUTS)"
),
srcs = srcs,
tools = [":" + api_gen_binary_target],
- visibility = ["//tensorflow:__pkg__"],
+ visibility = [
+ "//tensorflow:__pkg__",
+ "//tensorflow/tools/api/tests:__pkg__",
+ ],
)