aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2018-07-12 14:46:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-12 14:51:03 -0700
commitda798407b4ff72f1daa629e054ccd47b162c9d58 (patch)
tree4bc5251f66dd8bb601d73fd3ec8f035b953bbe6a /tensorflow/python/tools
parentc5e563e57feee793499fae9c3ce28f5176404749 (diff)
Support passing TensorFlow API names as a separate v1 argument to tf_export.
PiperOrigin-RevId: 204368026
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r--tensorflow/python/tools/api/generator/api_gen.bzl52
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api.py33
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api_test.py6
3 files changed, 56 insertions, 35 deletions
diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl
index f9170610b9..2a32e8a893 100644
--- a/tensorflow/python/tools/api/generator/api_gen.bzl
+++ b/tensorflow/python/tools/api/generator/api_gen.bzl
@@ -102,37 +102,41 @@ ESTIMATOR_API_INIT_FILES = [
# END GENERATED ESTIMATOR FILES
]
-# Creates a genrule that generates a directory structure with __init__.py
-# files that import all exported modules (i.e. modules with tf_export
-# decorators).
-#
-# Args:
-# name: name of genrule to create.
-# output_files: List of __init__.py files that should be generated.
-# This list should include file name for every module exported using
-# tf_export. For e.g. if an op is decorated with
-# @tf_export('module1.module2', 'module3'). Then, output_files should
-# include module1/module2/__init__.py and module3/__init__.py.
-# root_init_template: Python init file that should be used as template for
-# root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
-# template will be replaced with root imports collected by this genrule.
-# srcs: genrule sources. If passing root_init_template, the template file
-# must be included in sources.
-# api_name: Name of the project that you want to generate API files for
-# (e.g. "tensorflow" or "estimator").
-# package: Python package containing the @tf_export decorators you want to
-# process
-# package_dep: Python library target containing your package.
-
def gen_api_init_files(
name,
output_files = TENSORFLOW_API_INIT_FILES,
root_init_template = None,
srcs = [],
api_name = "tensorflow",
+ api_version = 2,
package = "tensorflow.python",
package_dep = "//tensorflow/python:no_contrib",
output_package = "tensorflow"):
+ """Creates API directory structure and __init__.py files.
+
+ Creates a genrule that generates a directory structure with __init__.py
+ files that import all exported modules (i.e. modules with tf_export
+ decorators).
+
+ Args:
+ name: name of genrule to create.
+ output_files: List of __init__.py files that should be generated.
+ This list should include file name for every module exported using
+ tf_export. For e.g. if an op is decorated with
+ @tf_export('module1.module2', 'module3'). Then, output_files should
+ include module1/module2/__init__.py and module3/__init__.py.
+ root_init_template: Python init file that should be used as template for
+ root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
+ template will be replaced with root imports collected by this genrule.
+ srcs: genrule sources. If passing root_init_template, the template file
+ must be included in sources.
+ api_name: Name of the project that you want to generate API files for
+ (e.g. "tensorflow" or "estimator").
+ api_version: TensorFlow API version to generate. Must be either 1 or 2.
+ package: Python package containing the @tf_export decorators you want to
+ process
+ package_dep: Python library target containing your package.
+ """
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
@@ -156,8 +160,8 @@ def gen_api_init_files(
cmd = (
"$(location :" + api_gen_binary_target + ") " +
root_init_template_flag + " --apidir=$(@D) --apiname=" +
- api_name + " --package=" + package + " --output_package=" +
- output_package + " $(OUTS)"),
+ api_name + " --apiversion=" + str(api_version) + " --package=" + package +
+ " --output_package=" + output_package + " $(OUTS)"),
srcs = srcs,
tools = [":" + api_gen_binary_target ],
visibility = ["//tensorflow:__pkg__"],
diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py
index e78fe4b738..863c922216 100644
--- a/tensorflow/python/tools/api/generator/create_python_api.py
+++ b/tensorflow/python/tools/api/generator/create_python_api.py
@@ -29,6 +29,7 @@ from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
API_ATTRS = tf_export.API_ATTRS
+API_ATTRS_V1 = tf_export.API_ATTRS_V1
_DEFAULT_PACKAGE = 'tensorflow.python'
_GENFILES_DIR_SUFFIX = 'genfiles/'
@@ -159,13 +160,16 @@ __all__.remove('print_function')
return module_text_map
-def get_api_init_text(package, output_package, api_name):
+def get_api_init_text(package, output_package, api_name, api_version):
"""Get a map from destination module to __init__.py code for that module.
Args:
package: Base python package containing python with target tf_export
decorators.
+ output_package: Base output python package where generated API will
+ be added.
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
+ api_version: API version you want to generate (`v1` or `v2`).
Returns:
A dictionary where
@@ -173,6 +177,12 @@ def get_api_init_text(package, output_package, api_name):
value: (string) text that should be in __init__.py files for
corresponding modules.
"""
+ if api_version == 1:
+ names_attr = API_ATTRS_V1[api_name].names
+ constants_attr = API_ATTRS_V1[api_name].constants
+ else:
+ names_attr = API_ATTRS[api_name].names
+ constants_attr = API_ATTRS[api_name].constants
module_code_builder = _ModuleInitCodeBuilder()
# Traverse over everything imported above. Specifically,
@@ -193,7 +203,7 @@ def get_api_init_text(package, output_package, api_name):
attr = getattr(module, module_contents_name)
# If attr is _tf_api_constants attribute, then add the constants.
- if module_contents_name == API_ATTRS[api_name].constants:
+ if module_contents_name == constants_attr:
for exports, value in attr:
for export in exports:
names = export.split('.')
@@ -205,9 +215,8 @@ def get_api_init_text(package, output_package, api_name):
_, attr = tf_decorator.unwrap(attr)
# If attr is a symbol with _tf_api_names attribute, then
# add import for it.
- if (hasattr(attr, '__dict__') and
- API_ATTRS[api_name].names in attr.__dict__):
- for export in getattr(attr, API_ATTRS[api_name].names): # pylint: disable=protected-access
+ if (hasattr(attr, '__dict__') and names_attr in attr.__dict__):
+ for export in getattr(attr, names_attr): # pylint: disable=protected-access
names = export.split('.')
dest_module = '.'.join(names[:-1])
module_code_builder.add_import(
@@ -297,7 +306,7 @@ def get_module_docstring(module_name, package, api_name):
def create_api_files(
output_files, package, root_init_template, output_dir, output_package,
- api_name):
+ api_name, api_version):
"""Creates __init__.py files for the Python API.
Args:
@@ -309,7 +318,9 @@ def create_api_files(
"#API IMPORTS PLACEHOLDER" comment in the template file will be replaced
with imports.
output_dir: output API root directory.
+ output_package: Base output package where generated API will be added.
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
+ api_version: API version to generate (`v1` or `v2`).
Raises:
ValueError: if an output file is not under api/ directory,
@@ -326,7 +337,8 @@ def create_api_files(
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- module_text_map = get_api_init_text(package, output_package, api_name)
+ module_text_map = get_api_init_text(
+ package, output_package, api_name, api_version)
# Add imports to output files.
missing_output_files = []
@@ -385,6 +397,10 @@ def main():
choices=API_ATTRS.keys(),
help='The API you want to generate.')
parser.add_argument(
+ '--apiversion', default=2, type=int,
+ choices=[1, 2],
+ help='The API version you want to generate.')
+ parser.add_argument(
'--output_package', default='tensorflow', type=str,
help='Root output package.')
@@ -401,7 +417,8 @@ def main():
# Populate `sys.modules` with modules containing tf_export().
importlib.import_module(args.package)
create_api_files(outputs, args.package, args.root_init_template,
- args.apidir, args.output_package, args.apiname)
+ args.apidir, args.output_package, args.apiname,
+ args.apiversion)
if __name__ == '__main__':
diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py
index 368b4c37e8..a565a49d96 100644
--- a/tensorflow/python/tools/api/generator/create_python_api_test.py
+++ b/tensorflow/python/tools/api/generator/create_python_api_test.py
@@ -59,7 +59,7 @@ class CreatePythonApiTest(test.TestCase):
imports = create_python_api.get_api_init_text(
package=create_python_api._DEFAULT_PACKAGE,
output_package='tensorflow',
- api_name='tensorflow')
+ api_name='tensorflow', api_version=1)
expected_import = (
'from tensorflow.python.test_module '
'import test_op as test_op1')
@@ -77,7 +77,7 @@ class CreatePythonApiTest(test.TestCase):
imports = create_python_api.get_api_init_text(
package=create_python_api._DEFAULT_PACKAGE,
output_package='tensorflow',
- api_name='tensorflow')
+ api_name='tensorflow', api_version=2)
expected_import = ('from tensorflow.python.test_module '
'import TestClass')
self.assertTrue(
@@ -88,7 +88,7 @@ class CreatePythonApiTest(test.TestCase):
imports = create_python_api.get_api_init_text(
package=create_python_api._DEFAULT_PACKAGE,
output_package='tensorflow',
- api_name='tensorflow')
+ api_name='tensorflow', api_version=1)
expected = ('from tensorflow.python.test_module '
'import _TEST_CONSTANT')
self.assertTrue(expected in str(imports),