diff options
author | Anna R <annarev@google.com> | 2018-09-07 12:20:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 12:24:57 -0700 |
commit | a65d3dd42122d3a58985d56118d58c5b4224f38f (patch) | |
tree | 9c7bf3a1a3dfa68d73af747b281a5ae327868332 | |
parent | 0a375d94b6fd4c3cd0bd5d0a301b3acc65b96d78 (diff) |
Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in, then bazel will build TensorFlow API version 2.0. In all other cases, it would build API version 1.*.
PiperOrigin-RevId: 212016666
-rw-r--r-- | tensorflow/BUILD | 50 | ||||
-rw-r--r-- | tensorflow/api_template.__init__.py | 22 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/tools/api/generator/api_gen.bzl | 34 | ||||
-rw-r--r-- | tensorflow/tools/api/tests/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/tools/api/tests/api_compatibility_test.py | 14 |
6 files changed, 95 insertions, 31 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 2926789953..386e0096ff 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -24,6 +24,11 @@ load( "//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files", # @unused ) +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files") +load( + "//tensorflow/python/tools/api/generator:api_init_files.bzl", + "TENSORFLOW_API_INIT_FILES", # @unused +) load( "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "TENSORFLOW_API_INIT_FILES_V1", # @unused @@ -33,6 +38,11 @@ load( "if_ngraph", ) +# @unused +TENSORFLOW_API_INIT_FILES_V2 = ( + TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) + # Config setting used when building for products # which requires restricted licenses to be avoided. config_setting( @@ -428,6 +438,13 @@ config_setting( visibility = ["//visibility:public"], ) +# This flag specifies whether TensorFlow 2.0 API should be built instead +# of 1.* API. Note that TensorFlow 2.0 API is currently under development. +config_setting( + name = "api_version_2", + define_values = {"tf_api_version": "2"}, +) + package_group( name = "internal", packages = [ @@ -592,13 +609,39 @@ exports_files( ) gen_api_init_files( - name = "tensorflow_python_api_gen", + name = "tf_python_api_gen_v1", srcs = ["api_template.__init__.py"], api_version = 1, + output_dir = "_api/v1/", output_files = TENSORFLOW_API_INIT_FILES_V1, + output_package = "tensorflow._api.v1", + root_init_template = "api_template.__init__.py", +) + +gen_api_init_files( + name = "tf_python_api_gen_v2", + srcs = ["api_template.__init__.py"], + api_version = 2, + compat_api_versions = [1], + output_dir = "_api/v2/", + output_files = TENSORFLOW_API_INIT_FILES_V2, + output_package = "tensorflow._api.v2", root_init_template = "api_template.__init__.py", ) +genrule( + name = "root_init_gen", + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }), + outs = ["__init__.py"], + cmd = select({ + "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + }), +) + py_library( name = "tensorflow_py", srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], @@ -613,7 +656,10 @@ py_library( py_library( name = "tensorflow_py_no_contrib", - srcs = [":tensorflow_python_api_gen"], + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }) + [":root_init_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 779f65d5b1..53a72b8443 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -18,11 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os as _os + # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import try: - import os # pylint: disable=g-import-not-at-top # Add `estimator` attribute to allow access to estimator APIs via # "tf.estimator..." from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top @@ -30,9 +31,8 @@ try: # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." # style imports. from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top - __path__ += [os.path.dirname(estimator_api.__file__)] + __path__ += [_os.path.dirname(estimator_api.__file__)] del estimator_api - del os except (ImportError, AttributeError): print('tf.estimator package not installed.') @@ -45,6 +45,12 @@ del LazyLoader from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +if _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) + del absolute_import del division del print_function @@ -54,6 +60,12 @@ del print_function # must come from this module. So python adds these symbols for the # resolution to succeed. # pylint: disable=undefined-variable -del python -del core +try: + del python + del core +except NameError: + # Don't fail if these modules are not available. + # For e.g. we are using this file for compat.v1 module as well and + # 'python', 'core' directories are not under compat/v1. + pass # pylint: enable=undefined-variable diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index ba9c6a2320..19729813a1 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -78,6 +78,7 @@ py_library( "//tensorflow:__pkg__", "//tensorflow/python/tools:__pkg__", "//tensorflow/python/tools/api/generator:__pkg__", + "//tensorflow/tools/api/tests:__pkg__", ], deps = [ ":array_ops", diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl index 2810d83bd2..271cf2afaf 100644 --- a/tensorflow/python/tools/api/generator/api_gen.bzl +++ b/tensorflow/python/tools/api/generator/api_gen.bzl @@ -12,10 +12,15 @@ ESTIMATOR_API_INIT_FILES = [ # END GENERATED ESTIMATOR FILES ] +def get_compat_files( + file_paths, + compat_api_version): + """Prepends compat/v<compat_api_version> to file_paths.""" + return ["compat/v%d/%s" % (compat_api_version, f) for f in file_paths] + def gen_api_init_files( name, output_files = TENSORFLOW_API_INIT_FILES, - compat_output_files = {}, root_init_template = None, srcs = [], api_name = "tensorflow", @@ -23,7 +28,8 @@ def gen_api_init_files( compat_api_versions = [], package = "tensorflow.python", package_dep = "//tensorflow/python:no_contrib", - output_package = "tensorflow"): + output_package = "tensorflow", + output_dir = ""): """Creates API directory structure and __init__.py files. Creates a genrule that generates a directory structure with __init__.py @@ -37,8 +43,6 @@ def gen_api_init_files( tf_export. For e.g. if an op is decorated with @tf_export('module1.module2', 'module3'). Then, output_files should include module1/module2/__init__.py and module3/__init__.py. - compat_output_files: Dictionary mapping each compat_api_version to the - set of __init__.py file paths that should be generated for that version. root_init_template: Python init file that should be used as template for root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this template will be replaced with root imports collected by this genrule. @@ -53,14 +57,16 @@ def gen_api_init_files( process package_dep: Python library target containing your package. output_package: Package where generated API will be added to. + output_dir: Subdirectory to output API to. + If non-empty, must end with '/'. """ root_init_template_flag = "" if root_init_template: root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" - api_gen_binary_target = "create_" + package + "_api" + api_gen_binary_target = ("create_" + package + "_api_%d") % api_version native.py_binary( - name = "create_" + package + "_api", + name = api_gen_binary_target, srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"], main = "//tensorflow/python/tools/api/generator:create_python_api.py", srcs_version = "PY2AND3", @@ -72,14 +78,9 @@ def gen_api_init_files( ], ) - all_output_files = list(output_files) + all_output_files = ["%s%s" % (output_dir, f) for f in output_files] compat_api_version_flags = "" for compat_api_version in compat_api_versions: - compat_files = compat_output_files.get(compat_api_version, []) - all_output_files.extend([ - "compat/v%d/%s" % (compat_api_version, f) - for f in compat_files - ]) compat_api_version_flags += " --compat_apiversion=%d" % compat_api_version native.genrule( @@ -87,12 +88,15 @@ def gen_api_init_files( outs = all_output_files, cmd = ( "$(location :" + api_gen_binary_target + ") " + - root_init_template_flag + " --apidir=$(@D) --apiname=" + - api_name + " --apiversion=" + str(api_version) + + root_init_template_flag + " --apidir=$(@D)" + output_dir + + " --apiname=" + api_name + " --apiversion=" + str(api_version) + compat_api_version_flags + " --package=" + package + " --output_package=" + output_package + " $(OUTS)" ), srcs = srcs, tools = [":" + api_gen_binary_target], - visibility = ["//tensorflow:__pkg__"], + visibility = [ + "//tensorflow:__pkg__", + "//tensorflow/tools/api/tests:__pkg__", + ], ) diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 8764409e4d..4efa4a9651 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -15,7 +15,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") py_test( name = "api_compatibility_test", - srcs = ["api_compatibility_test.py"], + srcs = [ + "api_compatibility_test.py", + "//tensorflow:tf_python_api_gen_v2", + ], data = [ "//tensorflow/tools/api/golden:api_golden_v1", "//tensorflow/tools/api/golden:api_golden_v2", diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 43d19bc99c..99bed5714f 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -34,6 +34,7 @@ import sys import unittest import tensorflow as tf +from tensorflow._api import v2 as tf_v2 from google.protobuf import message from google.protobuf import text_format @@ -232,14 +233,14 @@ class ApiCompatibilityTest(test.TestCase): return visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf.compat.v1, visitor) + traverse.traverse(tf_v2.compat.v1, visitor) def testNoSubclassOfMessageV2(self): if not hasattr(tf.compat, 'v2'): return visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf.compat.v2, visitor) + traverse.traverse(tf_v2, visitor) def _checkBackwardsCompatibility( self, root, golden_file_pattern, api_version, @@ -300,27 +301,24 @@ class ApiCompatibilityTest(test.TestCase): sys.version_info.major == 2, 'API compabitility test goldens are generated using python2.') def testAPIBackwardsCompatibilityV1(self): - if not hasattr(tf.compat, 'v1'): - return api_version = 1 golden_file_pattern = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) self._checkBackwardsCompatibility( - tf.compat.v1, golden_file_pattern, api_version) + tf_v2.compat.v1, golden_file_pattern, api_version) @unittest.skipUnless( sys.version_info.major == 2, 'API compabitility test goldens are generated using python2.') def testAPIBackwardsCompatibilityV2(self): - if not hasattr(tf.compat, 'v2'): - return api_version = 2 golden_file_pattern = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) self._checkBackwardsCompatibility( - tf.compat.v2, golden_file_pattern, api_version) + tf_v2, golden_file_pattern, api_version, + additional_private_map={'tf.compat': ['v1']}) if __name__ == '__main__': |