aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2018-09-07 12:20:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 12:24:57 -0700
commita65d3dd42122d3a58985d56118d58c5b4224f38f (patch)
tree9c7bf3a1a3dfa68d73af747b281a5ae327868332 /tensorflow
parent0a375d94b6fd4c3cd0bd5d0a301b3acc65b96d78 (diff)
Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in, then bazel will build TensorFlow API version 2.0. In all other cases, it would build API version 1.*.
PiperOrigin-RevId: 212016666
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/BUILD50
-rw-r--r--tensorflow/api_template.__init__.py22
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/tools/api/generator/api_gen.bzl34
-rw-r--r--tensorflow/tools/api/tests/BUILD5
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py14
6 files changed, 95 insertions, 31 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 2926789953..386e0096ff 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -24,6 +24,11 @@ load(
"//tensorflow/python/tools/api/generator:api_gen.bzl",
"gen_api_init_files", # @unused
)
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files")
+load(
+ "//tensorflow/python/tools/api/generator:api_init_files.bzl",
+ "TENSORFLOW_API_INIT_FILES", # @unused
+)
load(
"//tensorflow/python/tools/api/generator:api_init_files_v1.bzl",
"TENSORFLOW_API_INIT_FILES_V1", # @unused
@@ -33,6 +38,11 @@ load(
"if_ngraph",
)
+# @unused
+TENSORFLOW_API_INIT_FILES_V2 = (
+ TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
+)
+
# Config setting used when building for products
# which requires restricted licenses to be avoided.
config_setting(
@@ -428,6 +438,13 @@ config_setting(
visibility = ["//visibility:public"],
)
+# This flag specifies whether TensorFlow 2.0 API should be built instead
+# of 1.* API. Note that TensorFlow 2.0 API is currently under development.
+config_setting(
+ name = "api_version_2",
+ define_values = {"tf_api_version": "2"},
+)
+
package_group(
name = "internal",
packages = [
@@ -592,13 +609,39 @@ exports_files(
)
gen_api_init_files(
- name = "tensorflow_python_api_gen",
+ name = "tf_python_api_gen_v1",
srcs = ["api_template.__init__.py"],
api_version = 1,
+ output_dir = "_api/v1/",
output_files = TENSORFLOW_API_INIT_FILES_V1,
+ output_package = "tensorflow._api.v1",
+ root_init_template = "api_template.__init__.py",
+)
+
+gen_api_init_files(
+ name = "tf_python_api_gen_v2",
+ srcs = ["api_template.__init__.py"],
+ api_version = 2,
+ compat_api_versions = [1],
+ output_dir = "_api/v2/",
+ output_files = TENSORFLOW_API_INIT_FILES_V2,
+ output_package = "tensorflow._api.v2",
root_init_template = "api_template.__init__.py",
)
+genrule(
+ name = "root_init_gen",
+ srcs = select({
+ "api_version_2": [":tf_python_api_gen_v2"],
+ "//conditions:default": [":tf_python_api_gen_v1"],
+ }),
+ outs = ["__init__.py"],
+ cmd = select({
+ "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
+ "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+ }),
+)
+
py_library(
name = "tensorflow_py",
srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
@@ -613,7 +656,10 @@ py_library(
py_library(
name = "tensorflow_py_no_contrib",
- srcs = [":tensorflow_python_api_gen"],
+ srcs = select({
+ "api_version_2": [":tf_python_api_gen_v2"],
+ "//conditions:default": [":tf_python_api_gen_v1"],
+ }) + [":root_init_gen"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python:no_contrib"],
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 779f65d5b1..53a72b8443 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -18,11 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os as _os
+
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
try:
- import os # pylint: disable=g-import-not-at-top
# Add `estimator` attribute to allow access to estimator APIs via
# "tf.estimator..."
from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top
@@ -30,9 +31,8 @@ try:
# Add `estimator` to the __path__ to allow "from tensorflow.estimator..."
# style imports.
from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top
- __path__ += [os.path.dirname(estimator_api.__file__)]
+ __path__ += [_os.path.dirname(estimator_api.__file__)]
del estimator_api
- del os
except (ImportError, AttributeError):
print('tf.estimator package not installed.')
@@ -45,6 +45,12 @@ del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
+# Make sure directory containing top level submodules is in
+# the __path__ so that "from tensorflow.foo import bar" works.
+_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable
+if _tf_api_dir not in __path__:
+ __path__.append(_tf_api_dir)
+
del absolute_import
del division
del print_function
@@ -54,6 +60,12 @@ del print_function
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
-del python
-del core
+try:
+ del python
+ del core
+except NameError:
+ # Don't fail if these modules are not available.
+ # For e.g. we are using this file for compat.v1 module as well and
+ # 'python', 'core' directories are not under compat/v1.
+ pass
# pylint: enable=undefined-variable
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index ba9c6a2320..19729813a1 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -78,6 +78,7 @@ py_library(
"//tensorflow:__pkg__",
"//tensorflow/python/tools:__pkg__",
"//tensorflow/python/tools/api/generator:__pkg__",
+ "//tensorflow/tools/api/tests:__pkg__",
],
deps = [
":array_ops",
diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl
index 2810d83bd2..271cf2afaf 100644
--- a/tensorflow/python/tools/api/generator/api_gen.bzl
+++ b/tensorflow/python/tools/api/generator/api_gen.bzl
@@ -12,10 +12,15 @@ ESTIMATOR_API_INIT_FILES = [
# END GENERATED ESTIMATOR FILES
]
+def get_compat_files(
+ file_paths,
+ compat_api_version):
+ """Prepends compat/v<compat_api_version> to file_paths."""
+ return ["compat/v%d/%s" % (compat_api_version, f) for f in file_paths]
+
def gen_api_init_files(
name,
output_files = TENSORFLOW_API_INIT_FILES,
- compat_output_files = {},
root_init_template = None,
srcs = [],
api_name = "tensorflow",
@@ -23,7 +28,8 @@ def gen_api_init_files(
compat_api_versions = [],
package = "tensorflow.python",
package_dep = "//tensorflow/python:no_contrib",
- output_package = "tensorflow"):
+ output_package = "tensorflow",
+ output_dir = ""):
"""Creates API directory structure and __init__.py files.
Creates a genrule that generates a directory structure with __init__.py
@@ -37,8 +43,6 @@ def gen_api_init_files(
tf_export. For e.g. if an op is decorated with
@tf_export('module1.module2', 'module3'). Then, output_files should
include module1/module2/__init__.py and module3/__init__.py.
- compat_output_files: Dictionary mapping each compat_api_version to the
- set of __init__.py file paths that should be generated for that version.
root_init_template: Python init file that should be used as template for
root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
template will be replaced with root imports collected by this genrule.
@@ -53,14 +57,16 @@ def gen_api_init_files(
process
package_dep: Python library target containing your package.
output_package: Package where generated API will be added to.
+ output_dir: Subdirectory to output API to.
+ If non-empty, must end with '/'.
"""
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
- api_gen_binary_target = "create_" + package + "_api"
+ api_gen_binary_target = ("create_" + package + "_api_%d") % api_version
native.py_binary(
- name = "create_" + package + "_api",
+ name = api_gen_binary_target,
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
main = "//tensorflow/python/tools/api/generator:create_python_api.py",
srcs_version = "PY2AND3",
@@ -72,14 +78,9 @@ def gen_api_init_files(
],
)
- all_output_files = list(output_files)
+ all_output_files = ["%s%s" % (output_dir, f) for f in output_files]
compat_api_version_flags = ""
for compat_api_version in compat_api_versions:
- compat_files = compat_output_files.get(compat_api_version, [])
- all_output_files.extend([
- "compat/v%d/%s" % (compat_api_version, f)
- for f in compat_files
- ])
compat_api_version_flags += " --compat_apiversion=%d" % compat_api_version
native.genrule(
@@ -87,12 +88,15 @@ def gen_api_init_files(
outs = all_output_files,
cmd = (
"$(location :" + api_gen_binary_target + ") " +
- root_init_template_flag + " --apidir=$(@D) --apiname=" +
- api_name + " --apiversion=" + str(api_version) +
+ root_init_template_flag + " --apidir=$(@D)" + output_dir +
+ " --apiname=" + api_name + " --apiversion=" + str(api_version) +
compat_api_version_flags + " --package=" + package +
" --output_package=" + output_package + " $(OUTS)"
),
srcs = srcs,
tools = [":" + api_gen_binary_target],
- visibility = ["//tensorflow:__pkg__"],
+ visibility = [
+ "//tensorflow:__pkg__",
+ "//tensorflow/tools/api/tests:__pkg__",
+ ],
)
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 8764409e4d..4efa4a9651 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -15,7 +15,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
py_test(
name = "api_compatibility_test",
- srcs = ["api_compatibility_test.py"],
+ srcs = [
+ "api_compatibility_test.py",
+ "//tensorflow:tf_python_api_gen_v2",
+ ],
data = [
"//tensorflow/tools/api/golden:api_golden_v1",
"//tensorflow/tools/api/golden:api_golden_v2",
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 43d19bc99c..99bed5714f 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -34,6 +34,7 @@ import sys
import unittest
import tensorflow as tf
+from tensorflow._api import v2 as tf_v2
from google.protobuf import message
from google.protobuf import text_format
@@ -232,14 +233,14 @@ class ApiCompatibilityTest(test.TestCase):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v1, visitor)
+ traverse.traverse(tf_v2.compat.v1, visitor)
def testNoSubclassOfMessageV2(self):
if not hasattr(tf.compat, 'v2'):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v2, visitor)
+ traverse.traverse(tf_v2, visitor)
def _checkBackwardsCompatibility(
self, root, golden_file_pattern, api_version,
@@ -300,27 +301,24 @@ class ApiCompatibilityTest(test.TestCase):
sys.version_info.major == 2,
'API compabitility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV1(self):
- if not hasattr(tf.compat, 'v1'):
- return
api_version = 1
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v1, golden_file_pattern, api_version)
+ tf_v2.compat.v1, golden_file_pattern, api_version)
@unittest.skipUnless(
sys.version_info.major == 2,
'API compabitility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV2(self):
- if not hasattr(tf.compat, 'v2'):
- return
api_version = 2
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v2, golden_file_pattern, api_version)
+ tf_v2, golden_file_pattern, api_version,
+ additional_private_map={'tf.compat': ['v1']})
if __name__ == '__main__':