aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-07-12 11:56:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-12 12:04:28 -0700
commitc35bd2e9d3d9311bc7fb0f2463869faf1a8a7b50 (patch)
treeb0aebeae9cb451bb0938d997c20cd78738b7ae5f /tensorflow/python/tools
parent0678f10d0f96b46ecabf129cd69a04de2df49a3d (diff)
Internal Change.
PiperOrigin-RevId: 204338153
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r--tensorflow/python/tools/api/generator/BUILD84
-rw-r--r--tensorflow/python/tools/api/generator/api_gen.bzl164
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api.py408
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api_test.py99
-rw-r--r--tensorflow/python/tools/api/generator/doc_srcs.py92
-rw-r--r--tensorflow/python/tools/api/generator/doc_srcs_test.py83
6 files changed, 930 insertions, 0 deletions
diff --git a/tensorflow/python/tools/api/generator/BUILD b/tensorflow/python/tools/api/generator/BUILD
new file mode 100644
index 0000000000..223d1281ba
--- /dev/null
+++ b/tensorflow/python/tools/api/generator/BUILD
@@ -0,0 +1,84 @@
+# Description:
+# Scripts used to generate TensorFlow Python API.
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")
+
# Make the generator script and license visible to other packages that
# instantiate the api_gen.bzl macro.
exports_files(
    [
        "LICENSE",
        "create_python_api.py",
    ],
)
+
# Binary that generates the API __init__.py files. Depends on the
# no-contrib TensorFlow python target so tf_export decorators are
# importable when the generator runs.
py_binary(
    name = "create_python_api",
    srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
    main = "//tensorflow/python/tools/api/generator:create_python_api.py",
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python:no_contrib",
        "//tensorflow/python/tools/api/generator:doc_srcs",
    ],
)
+
# Library describing where generated API modules take their docstrings from.
py_library(
    name = "doc_srcs",
    srcs = ["doc_srcs.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python:util",
    ],
)
+
# Unit test for the generator itself (exercises get_api_init_text with a
# fake exported module).
py_test(
    name = "create_python_api_test",
    srcs = [
        "create_python_api.py",
        "create_python_api_test.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":doc_srcs",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:no_contrib",
    ],
)
+
# Validates doc_srcs entries for the "tensorflow" API. The generated file
# list is appended to args so the test can check each entry names a real
# generated module.
py_test(
    name = "tensorflow_doc_srcs_test",
    srcs = ["doc_srcs_test.py"],
    args = [
        "--package=tensorflow.python",
        "--api_name=tensorflow",
    ] + TENSORFLOW_API_INIT_FILES,
    main = "doc_srcs_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":doc_srcs",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:no_contrib",
    ],
)
+
# Same validation for the "estimator" API; additionally depends on
# estimator_py so its modules are importable during the test.
py_test(
    name = "estimator_doc_srcs_test",
    srcs = ["doc_srcs_test.py"],
    args = [
        "--package=tensorflow.python.estimator",
        "--api_name=estimator",
    ] + ESTIMATOR_API_INIT_FILES,
    main = "doc_srcs_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":doc_srcs",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:no_contrib",
        "//tensorflow/python/estimator:estimator_py",
    ],
)
diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl
new file mode 100644
index 0000000000..f9170610b9
--- /dev/null
+++ b/tensorflow/python/tools/api/generator/api_gen.bzl
@@ -0,0 +1,164 @@
+"""Targets for generating TensorFlow Python API __init__.py files."""
+
# keep sorted
# __init__.py files generated for the tensorflow API, one per exported
# module. The list was out of order in two places ("io" before
# "initializers", "strings" before "saved_model") despite the keep-sorted
# directive; it is now fully sorted. Same 86 entries as before.
TENSORFLOW_API_INIT_FILES = [
    # BEGIN GENERATED FILES
    "__init__.py",
    "app/__init__.py",
    "bitwise/__init__.py",
    "compat/__init__.py",
    "data/__init__.py",
    "debugging/__init__.py",
    "distributions/__init__.py",
    "distributions/bijectors/__init__.py",
    "dtypes/__init__.py",
    "errors/__init__.py",
    "feature_column/__init__.py",
    "gfile/__init__.py",
    "graph_util/__init__.py",
    "image/__init__.py",
    "initializers/__init__.py",
    "io/__init__.py",
    "keras/__init__.py",
    "keras/activations/__init__.py",
    "keras/applications/__init__.py",
    "keras/applications/densenet/__init__.py",
    "keras/applications/inception_resnet_v2/__init__.py",
    "keras/applications/inception_v3/__init__.py",
    "keras/applications/mobilenet/__init__.py",
    "keras/applications/nasnet/__init__.py",
    "keras/applications/resnet50/__init__.py",
    "keras/applications/vgg16/__init__.py",
    "keras/applications/vgg19/__init__.py",
    "keras/applications/xception/__init__.py",
    "keras/backend/__init__.py",
    "keras/callbacks/__init__.py",
    "keras/constraints/__init__.py",
    "keras/datasets/__init__.py",
    "keras/datasets/boston_housing/__init__.py",
    "keras/datasets/cifar10/__init__.py",
    "keras/datasets/cifar100/__init__.py",
    "keras/datasets/fashion_mnist/__init__.py",
    "keras/datasets/imdb/__init__.py",
    "keras/datasets/mnist/__init__.py",
    "keras/datasets/reuters/__init__.py",
    "keras/estimator/__init__.py",
    "keras/initializers/__init__.py",
    "keras/layers/__init__.py",
    "keras/losses/__init__.py",
    "keras/metrics/__init__.py",
    "keras/models/__init__.py",
    "keras/optimizers/__init__.py",
    "keras/preprocessing/__init__.py",
    "keras/preprocessing/image/__init__.py",
    "keras/preprocessing/sequence/__init__.py",
    "keras/preprocessing/text/__init__.py",
    "keras/regularizers/__init__.py",
    "keras/utils/__init__.py",
    "keras/wrappers/__init__.py",
    "keras/wrappers/scikit_learn/__init__.py",
    "layers/__init__.py",
    "linalg/__init__.py",
    "logging/__init__.py",
    "losses/__init__.py",
    "manip/__init__.py",
    "math/__init__.py",
    "metrics/__init__.py",
    "nn/__init__.py",
    "nn/rnn_cell/__init__.py",
    "profiler/__init__.py",
    "python_io/__init__.py",
    "quantization/__init__.py",
    "resource_loader/__init__.py",
    "saved_model/__init__.py",
    "saved_model/builder/__init__.py",
    "saved_model/constants/__init__.py",
    "saved_model/loader/__init__.py",
    "saved_model/main_op/__init__.py",
    "saved_model/signature_constants/__init__.py",
    "saved_model/signature_def_utils/__init__.py",
    "saved_model/tag_constants/__init__.py",
    "saved_model/utils/__init__.py",
    "sets/__init__.py",
    "sparse/__init__.py",
    "spectral/__init__.py",
    "strings/__init__.py",
    "summary/__init__.py",
    "sysconfig/__init__.py",
    "test/__init__.py",
    "train/__init__.py",
    "train/queue_runner/__init__.py",
    "user_ops/__init__.py",
    # END GENERATED FILES
]
+
# keep sorted
# __init__.py files generated for the estimator API.
ESTIMATOR_API_INIT_FILES = [
    # BEGIN GENERATED ESTIMATOR FILES
    "__init__.py",
    "estimator/__init__.py",
    "estimator/export/__init__.py",
    "estimator/inputs/__init__.py",
    # END GENERATED ESTIMATOR FILES
]
+
# Creates a genrule that generates a directory structure with __init__.py
# files that import all exported modules (i.e. modules with tf_export
# decorators).
#
# Args:
#   name: name of genrule to create.
#   output_files: List of __init__.py files that should be generated.
#     This list should include file name for every module exported using
#     tf_export. For e.g. if an op is decorated with
#     @tf_export('module1.module2', 'module3'). Then, output_files should
#     include module1/module2/__init__.py and module3/__init__.py.
#   root_init_template: Python init file that should be used as template for
#     root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
#     template will be replaced with root imports collected by this genrule.
#   srcs: genrule sources. If passing root_init_template, the template file
#     must be included in sources.
#   api_name: Name of the project that you want to generate API files for
#     (e.g. "tensorflow" or "estimator").
#   package: Python package containing the @tf_export decorators you want to
#     process
#   package_dep: Python library target containing your package.
#   output_package: Python package that the generated __init__.py files
#     import from (forwarded to create_python_api as --output_package).

def gen_api_init_files(
    name,
    output_files = TENSORFLOW_API_INIT_FILES,
    root_init_template = None,
    srcs = [],
    api_name = "tensorflow",
    package = "tensorflow.python",
    package_dep = "//tensorflow/python:no_contrib",
    output_package = "tensorflow"):
  root_init_template_flag = ""
  if root_init_template:
    root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"

  # One generator binary per package so each package is built with its own
  # dependency closure.
  api_gen_binary_target = "create_" + package + "_api"
  native.py_binary(
      # Reuse the computed target name (was a duplicated string expression).
      name = api_gen_binary_target,
      srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
      main = "//tensorflow/python/tools/api/generator:create_python_api.py",
      srcs_version = "PY2AND3",
      visibility = ["//visibility:public"],
      deps = [
          package_dep,
          "//tensorflow/python/tools/api/generator:doc_srcs",
      ],
  )

  native.genrule(
      name = name,
      outs = output_files,
      cmd = (
          "$(location :" + api_gen_binary_target + ") " +
          root_init_template_flag + " --apidir=$(@D) --apiname=" +
          api_name + " --package=" + package + " --output_package=" +
          output_package + " $(OUTS)"),
      srcs = srcs,
      tools = [":" + api_gen_binary_target],
      visibility = ["//tensorflow:__pkg__"],
  )
diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py
new file mode 100644
index 0000000000..e78fe4b738
--- /dev/null
+++ b/tensorflow/python/tools/api/generator/create_python_api.py
@@ -0,0 +1,408 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Generates and prints out imports and constants for new TensorFlow python api.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import collections
+import importlib
+import os
+import sys
+
+from tensorflow.python.tools.api.generator import doc_srcs
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_export
+
# Per-API attribute names tf_export uses to tag exported symbols and
# constants (indexed below as API_ATTRS[api_name].names / .constants).
API_ATTRS = tf_export.API_ATTRS

# Package scanned for tf_export decorators when --package is not given.
_DEFAULT_PACKAGE = 'tensorflow.python'
# NOTE(review): _GENFILES_DIR_SUFFIX is unused in this file — confirm it is
# still needed before removing.
_GENFILES_DIR_SUFFIX = 'genfiles/'
_SYMBOLS_TO_SKIP_EXPLICITLY = {
    # Overrides __getattr__, so that unwrapping tf_decorator
    # would have side effects.
    'tensorflow.python.platform.flags.FLAGS'
}
# Header template for generated __init__.py files; %s is filled with the
# module docstring by create_api_files().
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
\"\"\"%s
\"\"\"

from __future__ import print_function

"""
# Footer removes the print_function name from the generated module namespace.
_GENERATED_FILE_FOOTER = '\n\ndel print_function\n'
+
+
class SymbolExposedTwiceError(Exception):
  """Raised when two different symbols are exported under the same name."""
+
+
def format_import(source_module_name, source_name, dest_name):
  """Formats import statement.

  Args:
    source_module_name: (string) Source module to import from.
    source_name: (string) Source symbol name to import.
    dest_name: (string) Destination alias name.

  Returns:
    An import statement string.
  """
  # A falsy module name means a plain module import; otherwise import the
  # symbol out of the module. An "as" alias is added only when the
  # destination name differs from the source name.
  if source_module_name:
    statement = 'from %s import %s' % (source_module_name, source_name)
  else:
    statement = 'import %s' % source_name
  if source_name != dest_name:
    statement += ' as %s' % dest_name
  return statement
+
+
class _ModuleInitCodeBuilder(object):
  """Builds a map from module name to imports included in that module."""

  def __init__(self):
    # module name -> full API name -> set of equivalent import statements.
    self.module_imports = collections.defaultdict(
        lambda: collections.defaultdict(set))
    # full API name -> unique id of the symbol exported under that name.
    self._dest_import_to_id = collections.defaultdict(int)
    # Names that start with underscore in the root module.
    self._underscore_names_in_root = []

  def add_import(
      self, symbol_id, dest_module_name, source_module_name, source_name,
      dest_name):
    """Adds this import to module_imports.

    Args:
      symbol_id: (number) Unique identifier of the symbol to import.
      dest_module_name: (string) Module name to add import to.
      source_module_name: (string) Module to import from.
      source_name: (string) Name of the symbol to import.
      dest_name: (string) Import the symbol using this name.

    Raises:
      SymbolExposedTwiceError: Raised when an import with the same
        dest_name has already been added to dest_module_name.
    """
    import_str = format_import(source_module_name, source_name, dest_name)

    # Detect two *different* symbols exported under one name. An id of -1
    # marks imports (constants, submodules) exempt from this check.
    if dest_module_name:
      full_api_name = dest_module_name + '.' + dest_name
    else:
      full_api_name = dest_name
    if (full_api_name in self._dest_import_to_id and
        symbol_id not in (-1, self._dest_import_to_id[full_api_name])):
      raise SymbolExposedTwiceError(
          'Trying to export multiple symbols with same name: %s.' %
          full_api_name)
    self._dest_import_to_id[full_api_name] = symbol_id

    if dest_name.startswith('_') and not dest_module_name:
      self._underscore_names_in_root.append(dest_name)

    # The same symbol can be available in multiple modules; keep every way
    # of importing it and let build() pick one deterministically.
    self.module_imports[dest_module_name][full_api_name].add(import_str)

  def build(self):
    """Get a map from destination module to __init__.py code for that module.

    Returns:
      A dictionary where
        key: (string) destination module (for e.g. tf or tf.consts).
        value: (string) text that should be in __init__.py files for
          corresponding modules.
    """
    module_text_map = {}
    for dest_module, imports_by_name in self.module_imports.items():
      # For each exported name use the lexicographically-first import
      # option so the output is deterministic.
      chosen_imports = sorted(
          min(options) for options in imports_by_name.values())
      module_text_map[dest_module] = '\n'.join(chosen_imports)

    # Expose exported symbols with underscores in root module
    # since we import from it using * import.
    underscore_names_str = ', '.join(
        '\'%s\'' % name for name in self._underscore_names_in_root)
    # We will always generate a root __init__.py file to let us handle *
    # imports consistently. Be sure to have a root __init__.py file listed in
    # the script outputs.
    module_text_map[''] = module_text_map.get('', '') + '''
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
__all__.remove('print_function')
''' % underscore_names_str

    return module_text_map
+
+
def get_api_init_text(package, output_package, api_name):
  """Get a map from destination module to __init__.py code for that module.

  Args:
    package: Base python package containing python with target tf_export
      decorators.
    output_package: Base output python package where generated API will
      be added (used as the source of submodule imports).
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
  module_code_builder = _ModuleInitCodeBuilder()

  # Traverse over everything imported above. Specifically,
  # we want to traverse over TensorFlow Python modules.
  for module in list(sys.modules.values()):
    # Only look at tensorflow modules.
    if (not module or not hasattr(module, '__name__') or
        module.__name__ is None or package not in module.__name__):
      continue
    # Do not generate __init__.py files for contrib modules for now.
    if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
      continue

    for module_contents_name in dir(module):
      # Skip symbols whose unwrapping would trigger side effects.
      if (module.__name__ + '.' + module_contents_name
          in _SYMBOLS_TO_SKIP_EXPLICITLY):
        continue
      attr = getattr(module, module_contents_name)

      # If attr is _tf_api_constants attribute, then add the constants.
      if module_contents_name == API_ATTRS[api_name].constants:
        for exports, value in attr:
          for export in exports:
            names = export.split('.')
            dest_module = '.'.join(names[:-1])
            # -1 id: constants are exempt from the duplicate-export check.
            module_code_builder.add_import(
                -1, dest_module, module.__name__, value, names[-1])
        continue

      _, attr = tf_decorator.unwrap(attr)
      # If attr is a symbol with _tf_api_names attribute, then
      # add import for it.
      if (hasattr(attr, '__dict__') and
          API_ATTRS[api_name].names in attr.__dict__):
        for export in getattr(attr, API_ATTRS[api_name].names):  # pylint: disable=protected-access
          names = export.split('.')
          dest_module = '.'.join(names[:-1])
          # id(attr) lets add_import detect two different symbols exported
          # under the same fully-qualified name.
          module_code_builder.add_import(
              id(attr), dest_module, module.__name__, module_contents_name,
              names[-1])

  # Import all required modules in their parent modules.
  # For e.g. if we import 'foo.bar.Value'. Then, we also
  # import 'bar' in 'foo'.
  imported_modules = set(module_code_builder.module_imports.keys())
  for module in imported_modules:
    if not module:
      continue
    module_split = module.split('.')
    parent_module = ''  # we import submodules in their parent_module

    for submodule_index in range(len(module_split)):
      # Grow the parent module's dotted name one level at a time.
      if submodule_index > 0:
        parent_module += ('.' + module_split[submodule_index-1] if parent_module
                          else module_split[submodule_index-1])
      import_from = output_package
      if submodule_index > 0:
        import_from += '.' + '.'.join(module_split[:submodule_index])
      # -1 id: the same submodule import may legitimately be added many times.
      module_code_builder.add_import(
          -1, parent_module, import_from,
          module_split[submodule_index], module_split[submodule_index])

  return module_code_builder.build()
+
+
def get_module(dir_path, relative_to_dir):
  """Get module that corresponds to path relative to relative_to_dir.

  Args:
    dir_path: Path to directory.
    relative_to_dir: Get module relative to this directory.

  Returns:
    Name of module that corresponds to the given directory.
  """
  # Drop the leading prefix, normalize the platform separator to '/', then
  # turn the remaining path into a dotted module name.
  relative_path = dir_path[len(relative_to_dir):].replace(os.sep, '/')
  return relative_path.replace('/', '.').strip('.')
+
+
def get_module_docstring(module_name, package, api_name):
  """Get docstring for the given module.

  This method looks for docstring in the following order:
  1. Checks if module has a docstring specified in doc_srcs.
  2. Checks if module has a docstring source module specified
     in doc_srcs. If it does, gets docstring from that module.
  3. Checks if module with module_name exists under base package.
     If it does, gets docstring from that module.
  4. Returns a default docstring.

  Args:
    module_name: module name relative to tensorflow
      (excluding 'tensorflow.' prefix) to get a docstring for.
    package: Base python package containing python with target tf_export
      decorators.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Returns:
    One-line docstring to describe the module.
  """
  docstring_module_name = module_name

  # Steps 1 & 2: an explicit doc source entry wins over everything else.
  docsrc = doc_srcs.get_doc_sources(api_name).get(module_name)
  if docsrc is not None:
    if docsrc.docstring:
      return docsrc.docstring
    if docsrc.docstring_module_name:
      docstring_module_name = docsrc.docstring_module_name

  # Step 3: reuse the source module's own docstring when it is loaded.
  source_module = sys.modules.get(package + '.' + docstring_module_name)
  if source_module and source_module.__doc__:
    return source_module.__doc__

  # Step 4: generic fallback.
  return 'Public API for tf.%s namespace.' % module_name
+
+
def create_api_files(
    output_files, package, root_init_template, output_dir, output_package,
    api_name):
  """Creates __init__.py files for the Python API.

  Args:
    output_files: List of __init__.py file paths to create.
      Each file must be under api/ directory.
    package: Base python package containing python with target tf_export
      decorators.
    root_init_template: Template for top-level __init__.py file.
      "#API IMPORTS PLACEHOLDER" comment in the template file will be replaced
      with imports.
    output_dir: output API root directory.
    output_package: Base output package that generated __init__.py files
      import submodules from.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Raises:
    ValueError: if an output file is not under api/ directory,
      or output_files list is missing a required file.
  """
  # Map the dotted module name of each output file's directory to its path.
  module_name_to_file_path = {}
  for output_file in output_files:
    module_name = get_module(os.path.dirname(output_file), output_dir)
    module_name_to_file_path[module_name] = os.path.normpath(output_file)

  # Create file for each expected output in genrule.
  for module, file_path in module_name_to_file_path.items():
    if not os.path.isdir(os.path.dirname(file_path)):
      os.makedirs(os.path.dirname(file_path))
    # Touch the file so every declared genrule output exists even if no
    # imports end up targeting it.
    open(file_path, 'a').close()

  module_text_map = get_api_init_text(package, output_package, api_name)

  # Add imports to output files.
  missing_output_files = []
  for module, text in module_text_map.items():
    # Make sure genrule output file list is in sync with API exports.
    if module not in module_name_to_file_path:
      module_file_path = '"%s/__init__.py"' % (
          module.replace('.', '/'))
      missing_output_files.append(module_file_path)
      continue
    contents = ''
    if module or not root_init_template:
      # Non-root modules (and the root when no template was supplied) get
      # the machine-generated header plus the collected imports.
      contents = (
          _GENERATED_FILE_HEADER %
          get_module_docstring(module, package, api_name) +
          text + _GENERATED_FILE_FOOTER)
    else:
      # Read base init file
      with open(root_init_template, 'r') as root_init_template_file:
        contents = root_init_template_file.read()
      contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
    with open(module_name_to_file_path[module], 'w') as fp:
      fp.write(contents)

  if missing_output_files:
    raise ValueError(
        'Missing outputs for python_api_gen genrule:\n%s.'
        'Make sure all required outputs are in the '
        'tensorflow/tools/api/generator/api_gen.bzl file.' %
        ',\n'.join(sorted(missing_output_files)))
+
+
def main():
  """Parses flags, imports the target package and writes the API files."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs', metavar='O', type=str, nargs='+',
      # Fixed duplicated word ("we we") in the original help text.
      help='If a single file is passed in, then we assume it contains a '
      'semicolon-separated list of Python files that we expect this script to '
      'output. If multiple files are passed in, then we assume output files '
      'are listed directly as arguments.')
  parser.add_argument(
      '--package', default=_DEFAULT_PACKAGE, type=str,
      help='Base package that imports modules containing the target tf_export '
      'decorators.')
  parser.add_argument(
      '--root_init_template', default='', type=str,
      help='Template for top level __init__.py file. '
      '"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
  parser.add_argument(
      '--apidir', type=str, required=True,
      help='Directory where generated output files are placed. '
      'gendir should be a prefix of apidir. Also, apidir '
      'should be a prefix of every directory in outputs.')
  parser.add_argument(
      '--apiname', required=True, type=str,
      choices=API_ATTRS.keys(),
      help='The API you want to generate.')
  parser.add_argument(
      '--output_package', default='tensorflow', type=str,
      help='Root output package.')

  args = parser.parse_args()

  if len(args.outputs) == 1:
    # If we only get a single argument, then it must be a file containing
    # list of outputs.
    with open(args.outputs[0]) as output_list_file:
      outputs = [line.strip() for line in output_list_file.read().split(';')]
  else:
    outputs = args.outputs

  # Populate `sys.modules` with modules containing tf_export().
  importlib.import_module(args.package)
  create_api_files(outputs, args.package, args.root_init_template,
                   args.apidir, args.output_package, args.apiname)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py
new file mode 100644
index 0000000000..368b4c37e8
--- /dev/null
+++ b/tensorflow/python/tools/api/generator/create_python_api_test.py
@@ -0,0 +1,99 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for create_python_api."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import imp
+import sys
+
+from tensorflow.python.platform import test
+from tensorflow.python.tools.api.generator import create_python_api
+from tensorflow.python.util.tf_export import tf_export
+
+
@tf_export('test_op', 'test_op1')
def test_op():
  """Fake op exported under two names to exercise alias import generation."""
  pass
+
+
@tf_export('TestClass', 'NewTestClass')
class TestClass(object):
  """Fake class exported under two names for import-generation tests."""
  pass
+
+
+_TEST_CONSTANT = 5
+_MODULE_NAME = 'tensorflow.python.test_module'
+
+
class CreatePythonApiTest(test.TestCase):
  """Checks get_api_init_text output for a fake exported module."""

  def setUp(self):
    # Add fake op to a module that has 'tensorflow' in the name.
    sys.modules[_MODULE_NAME] = imp.new_module(_MODULE_NAME)
    setattr(sys.modules[_MODULE_NAME], 'test_op', test_op)
    setattr(sys.modules[_MODULE_NAME], 'TestClass', TestClass)
    test_op.__module__ = _MODULE_NAME
    TestClass.__module__ = _MODULE_NAME
    tf_export('consts._TEST_CONSTANT').export_constant(
        _MODULE_NAME, '_TEST_CONSTANT')

  def tearDown(self):
    del sys.modules[_MODULE_NAME]

  def testFunctionImportIsAdded(self):
    imports = create_python_api.get_api_init_text(
        package=create_python_api._DEFAULT_PACKAGE,
        output_package='tensorflow',
        api_name='tensorflow')
    # The op is exported under two names; both imports must be generated.
    expected_import = (
        'from tensorflow.python.test_module '
        'import test_op as test_op1')
    self.assertTrue(
        expected_import in str(imports),
        msg='%s not in %s' % (expected_import, str(imports)))

    expected_import = ('from tensorflow.python.test_module '
                       'import test_op')
    self.assertTrue(
        expected_import in str(imports),
        msg='%s not in %s' % (expected_import, str(imports)))

  def testClassImportIsAdded(self):
    imports = create_python_api.get_api_init_text(
        package=create_python_api._DEFAULT_PACKAGE,
        output_package='tensorflow',
        api_name='tensorflow')
    expected_import = ('from tensorflow.python.test_module '
                       'import TestClass')
    # Assert on the full import statement. The original only checked the
    # substring 'TestClass', which any import mentioning the class name
    # would have satisfied.
    self.assertTrue(
        expected_import in str(imports),
        msg='%s not in %s' % (expected_import, str(imports)))

  def testConstantIsAdded(self):
    imports = create_python_api.get_api_init_text(
        package=create_python_api._DEFAULT_PACKAGE,
        output_package='tensorflow',
        api_name='tensorflow')
    expected = ('from tensorflow.python.test_module '
                'import _TEST_CONSTANT')
    self.assertTrue(expected in str(imports),
                    msg='%s not in %s' % (expected, str(imports)))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py
new file mode 100644
index 0000000000..ad1988494d
--- /dev/null
+++ b/tensorflow/python/tools/api/generator/doc_srcs.py
@@ -0,0 +1,92 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Specifies sources of doc strings for API modules."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.util import tf_export
+
+
# Specifies docstring source for a module.
# Only one of docstring or docstring_module_name should be set.
# * If docstring is set, then we will use this docstring
#   for the module.
# * If docstring_module_name is set, then we will copy the docstring
#   from the docstring source module.
DocSource = collections.namedtuple(
    'DocSource', ['docstring', 'docstring_module_name'])
# Each attribute of DocSource is optional.
DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
+
# API module name (relative to the API root) -> docstring source for
# modules generated for the tensorflow API.
_TENSORFLOW_DOC_SOURCES = {
    'app': DocSource(docstring_module_name='platform.app'),
    'compat': DocSource(docstring_module_name='util.compat'),
    'distributions': DocSource(
        docstring_module_name='ops.distributions.distributions'),
    'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
    'errors': DocSource(docstring_module_name='framework.errors'),
    'gfile': DocSource(docstring_module_name='platform.gfile'),
    'graph_util': DocSource(docstring_module_name='framework.graph_util'),
    'image': DocSource(docstring_module_name='ops.image_ops'),
    'keras.estimator': DocSource(docstring_module_name='keras.estimator'),
    'linalg': DocSource(docstring_module_name='ops.linalg_ops'),
    'logging': DocSource(docstring_module_name='ops.logging_ops'),
    'losses': DocSource(docstring_module_name='ops.losses.losses'),
    'manip': DocSource(docstring_module_name='ops.manip_ops'),
    'math': DocSource(docstring_module_name='ops.math_ops'),
    'metrics': DocSource(docstring_module_name='ops.metrics'),
    'nn': DocSource(docstring_module_name='ops.nn_ops'),
    'nn.rnn_cell': DocSource(docstring_module_name='ops.rnn_cell'),
    'python_io': DocSource(docstring_module_name='lib.io.python_io'),
    'resource_loader': DocSource(
        docstring_module_name='platform.resource_loader'),
    'sets': DocSource(docstring_module_name='ops.sets'),
    'sparse': DocSource(docstring_module_name='ops.sparse_ops'),
    'spectral': DocSource(docstring_module_name='ops.spectral_ops'),
    'strings': DocSource(docstring_module_name='ops.string_ops'),
    'sysconfig': DocSource(docstring_module_name='platform.sysconfig'),
    'test': DocSource(docstring_module_name='platform.test'),
    'train': DocSource(docstring_module_name='training.training'),
    'train.queue_runner': DocSource(
        docstring_module_name='training.queue_runner'),
}
+
# API module name -> docstring source for modules generated for the
# estimator API.
_ESTIMATOR_DOC_SOURCES = {
    'estimator': DocSource(
        docstring_module_name='estimator_lib'),
    'estimator.export': DocSource(
        docstring_module_name='export.export_lib'),
    'estimator.inputs': DocSource(
        docstring_module_name='inputs.inputs'),
}
+
+
def get_doc_sources(api_name):
  """Get a map from module to a DocSource object.

  Args:
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Returns:
    Map from module name to DocSource object; empty for unknown API names.
  """
  # Dispatch on the API name; anything unrecognized has no doc sources.
  doc_sources_by_api = {
      tf_export.TENSORFLOW_API_NAME: _TENSORFLOW_DOC_SOURCES,
      tf_export.ESTIMATOR_API_NAME: _ESTIMATOR_DOC_SOURCES,
  }
  return doc_sources_by_api.get(api_name, {})
diff --git a/tensorflow/python/tools/api/generator/doc_srcs_test.py b/tensorflow/python/tools/api/generator/doc_srcs_test.py
new file mode 100644
index 0000000000..481d9874a4
--- /dev/null
+++ b/tensorflow/python/tools/api/generator/doc_srcs_test.py
@@ -0,0 +1,83 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for tensorflow.python.tools.api.generator.doc_srcs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import importlib
+import sys
+
+from tensorflow.python.platform import test
+from tensorflow.python.tools.api.generator import doc_srcs
+
+
+FLAGS = None
+
+
class DocSrcsTest(test.TestCase):
  """Validates doc_srcs entries against the generated API file list.

  Relies on module-level FLAGS being populated by the __main__ block.
  """

  def testModulesAreValidAPIModules(self):
    for module_name in doc_srcs.get_doc_sources(FLAGS.api_name):
      # Convert module_name to corresponding __init__.py file path.
      file_path = module_name.replace('.', '/')
      if file_path:
        file_path += '/'
      file_path += '__init__.py'

      self.assertIn(
          file_path, FLAGS.outputs,
          msg='%s is not a valid API module' % module_name)

  def testHaveDocstringOrDocstringModule(self):
    for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
      # Original message was garbled ("contains DocSource has both ...").
      self.assertFalse(
          docsrc.docstring and docsrc.docstring_module_name,
          msg=('%s has both a docstring and a docstring_module_name set. '
               'Only one of "docstring" or "docstring_module_name" should '
               'be set.') % (module_name))

  def testDocstringModulesAreValidModules(self):
    for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
      if docsrc.docstring_module_name:
        doc_module_name = '.'.join([
            FLAGS.package, docsrc.docstring_module_name])
        self.assertIn(
            doc_module_name, sys.modules,
            msg=('docsources_module %s is not a valid module under %s.' %
                 (docsrc.docstring_module_name, FLAGS.package)))
+
+
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs', metavar='O', type=str, nargs='+',
      help='create_python_api output files.')
  parser.add_argument(
      '--package', type=str,
      help='Base package that imports modules containing the target tf_export '
      'decorators.')
  parser.add_argument(
      '--api_name', type=str,
      help='API name: tensorflow or estimator')
  # Parsed flags land in the module-level FLAGS read by the test methods;
  # unrecognized args are handed back to the test runner below.
  FLAGS, unparsed = parser.parse_known_args()

  # Importing the package fills sys.modules, which the tests inspect.
  importlib.import_module(FLAGS.package)

  # Now update argv, so that unittest library does not get confused.
  sys.argv = [sys.argv[0]] + unparsed
  test.main()