diff options
author | Michael Case <mikecase@google.com> | 2018-07-12 11:56:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-12 12:04:28 -0700 |
commit | c35bd2e9d3d9311bc7fb0f2463869faf1a8a7b50 (patch) | |
tree | b0aebeae9cb451bb0938d997c20cd78738b7ae5f /tensorflow/python/tools | |
parent | 0678f10d0f96b46ecabf129cd69a04de2df49a3d (diff) |
Internal Change.
PiperOrigin-RevId: 204338153
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r-- | tensorflow/python/tools/api/generator/BUILD | 84 | ||||
-rw-r--r-- | tensorflow/python/tools/api/generator/api_gen.bzl | 164 | ||||
-rw-r--r-- | tensorflow/python/tools/api/generator/create_python_api.py | 408 | ||||
-rw-r--r-- | tensorflow/python/tools/api/generator/create_python_api_test.py | 99 | ||||
-rw-r--r-- | tensorflow/python/tools/api/generator/doc_srcs.py | 92 | ||||
-rw-r--r-- | tensorflow/python/tools/api/generator/doc_srcs_test.py | 83 |
6 files changed, 930 insertions, 0 deletions
diff --git a/tensorflow/python/tools/api/generator/BUILD b/tensorflow/python/tools/api/generator/BUILD new file mode 100644 index 0000000000..223d1281ba --- /dev/null +++ b/tensorflow/python/tools/api/generator/BUILD @@ -0,0 +1,84 @@ +# Description: +# Scripts used to generate TensorFlow Python API. + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES") + +exports_files( + [ + "LICENSE", + "create_python_api.py", + ], +) + +py_binary( + name = "create_python_api", + srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"], + main = "//tensorflow/python/tools/api/generator:create_python_api.py", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:no_contrib", + "//tensorflow/python/tools/api/generator:doc_srcs", + ], +) + +py_library( + name = "doc_srcs", + srcs = ["doc_srcs.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:util", + ], +) + +py_test( + name = "create_python_api_test", + srcs = [ + "create_python_api.py", + "create_python_api_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":doc_srcs", + "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", + ], +) + +py_test( + name = "tensorflow_doc_srcs_test", + srcs = ["doc_srcs_test.py"], + args = [ + "--package=tensorflow.python", + "--api_name=tensorflow", + ] + TENSORFLOW_API_INIT_FILES, + main = "doc_srcs_test.py", + srcs_version = "PY2AND3", + deps = [ + ":doc_srcs", + "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", + ], +) + +py_test( + name = "estimator_doc_srcs_test", + srcs = ["doc_srcs_test.py"], + args = [ + "--package=tensorflow.python.estimator", + "--api_name=estimator", + ] + ESTIMATOR_API_INIT_FILES, + main = "doc_srcs_test.py", + srcs_version = "PY2AND3", + deps = [ + ":doc_srcs", + "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", + "//tensorflow/python/estimator:estimator_py", + ], +) diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl new file mode 100644 index 0000000000..f9170610b9 --- /dev/null +++ b/tensorflow/python/tools/api/generator/api_gen.bzl @@ -0,0 +1,164 @@ +"""Targets for generating TensorFlow Python API __init__.py files.""" + +# keep sorted +TENSORFLOW_API_INIT_FILES = [ + # BEGIN GENERATED FILES + "__init__.py", + "app/__init__.py", + "bitwise/__init__.py", + "compat/__init__.py", + "data/__init__.py", + "debugging/__init__.py", + "distributions/__init__.py", + "distributions/bijectors/__init__.py", + "dtypes/__init__.py", + "errors/__init__.py", + "feature_column/__init__.py", + "gfile/__init__.py", + "graph_util/__init__.py", + "image/__init__.py", + "io/__init__.py", + "initializers/__init__.py", + "keras/__init__.py", + "keras/activations/__init__.py", + "keras/applications/__init__.py", + "keras/applications/densenet/__init__.py", + "keras/applications/inception_resnet_v2/__init__.py", + "keras/applications/inception_v3/__init__.py", + "keras/applications/mobilenet/__init__.py", + "keras/applications/nasnet/__init__.py", + "keras/applications/resnet50/__init__.py", + "keras/applications/vgg16/__init__.py", + "keras/applications/vgg19/__init__.py", + "keras/applications/xception/__init__.py", + "keras/backend/__init__.py", + "keras/callbacks/__init__.py", + "keras/constraints/__init__.py", + "keras/datasets/__init__.py", + "keras/datasets/boston_housing/__init__.py", + "keras/datasets/cifar10/__init__.py", + "keras/datasets/cifar100/__init__.py", + "keras/datasets/fashion_mnist/__init__.py", + "keras/datasets/imdb/__init__.py", + "keras/datasets/mnist/__init__.py", + "keras/datasets/reuters/__init__.py", + "keras/estimator/__init__.py", + "keras/initializers/__init__.py", + "keras/layers/__init__.py", + "keras/losses/__init__.py", + "keras/metrics/__init__.py", + "keras/models/__init__.py", + "keras/optimizers/__init__.py", + "keras/preprocessing/__init__.py", + "keras/preprocessing/image/__init__.py", + "keras/preprocessing/sequence/__init__.py", + "keras/preprocessing/text/__init__.py", + "keras/regularizers/__init__.py", + "keras/utils/__init__.py", + "keras/wrappers/__init__.py", + "keras/wrappers/scikit_learn/__init__.py", + "layers/__init__.py", + "linalg/__init__.py", + "logging/__init__.py", + "losses/__init__.py", + "manip/__init__.py", + "math/__init__.py", + "metrics/__init__.py", + "nn/__init__.py", + "nn/rnn_cell/__init__.py", + "profiler/__init__.py", + "python_io/__init__.py", + "quantization/__init__.py", + "resource_loader/__init__.py", + "strings/__init__.py", + "saved_model/__init__.py", + "saved_model/builder/__init__.py", + "saved_model/constants/__init__.py", + "saved_model/loader/__init__.py", + "saved_model/main_op/__init__.py", + "saved_model/signature_constants/__init__.py", + "saved_model/signature_def_utils/__init__.py", + "saved_model/tag_constants/__init__.py", + "saved_model/utils/__init__.py", + "sets/__init__.py", + "sparse/__init__.py", + "spectral/__init__.py", + "summary/__init__.py", + "sysconfig/__init__.py", + "test/__init__.py", + "train/__init__.py", + "train/queue_runner/__init__.py", + "user_ops/__init__.py", + # END GENERATED FILES +] + +# keep sorted +ESTIMATOR_API_INIT_FILES = [ + # BEGIN GENERATED ESTIMATOR FILES + "__init__.py", + "estimator/__init__.py", + "estimator/export/__init__.py", + "estimator/inputs/__init__.py", + # END GENERATED ESTIMATOR FILES +] + +# Creates a genrule that generates a directory structure with __init__.py +# files that import all exported modules (i.e. modules with tf_export +# decorators). +# +# Args: +# name: name of genrule to create. +# output_files: List of __init__.py files that should be generated. +# This list should include file name for every module exported using +# tf_export. For e.g. if an op is decorated with +# @tf_export('module1.module2', 'module3'). Then, output_files should +# include module1/module2/__init__.py and module3/__init__.py. +# root_init_template: Python init file that should be used as template for +# root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this +# template will be replaced with root imports collected by this genrule. +# srcs: genrule sources. If passing root_init_template, the template file +# must be included in sources. +# api_name: Name of the project that you want to generate API files for +# (e.g. "tensorflow" or "estimator"). +# package: Python package containing the @tf_export decorators you want to +# process +# package_dep: Python library target containing your package. + +def gen_api_init_files( + name, + output_files = TENSORFLOW_API_INIT_FILES, + root_init_template = None, + srcs = [], + api_name = "tensorflow", + package = "tensorflow.python", + package_dep = "//tensorflow/python:no_contrib", + output_package = "tensorflow"): + root_init_template_flag = "" + if root_init_template: + root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" + + api_gen_binary_target = "create_" + package + "_api" + native.py_binary( + name = "create_" + package + "_api", + srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"], + main = "//tensorflow/python/tools/api/generator:create_python_api.py", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + package_dep, + "//tensorflow/python/tools/api/generator:doc_srcs", + ], + ) + + native.genrule( + name = name, + outs = output_files, + cmd = ( + "$(location :" + api_gen_binary_target + ") " + + root_init_template_flag + " --apidir=$(@D) --apiname=" + + api_name + " --package=" + package + " --output_package=" + + output_package + " $(OUTS)"), + srcs = srcs, + tools = [":" + api_gen_binary_target ], + visibility = ["//tensorflow:__pkg__"], + ) diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py new file mode 100644 index 0000000000..e78fe4b738 --- /dev/null +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -0,0 +1,408 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Generates and prints out imports and constants for new TensorFlow python api. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import collections +import importlib +import os +import sys + +from tensorflow.python.tools.api.generator import doc_srcs +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_export + +API_ATTRS = tf_export.API_ATTRS + +_DEFAULT_PACKAGE = 'tensorflow.python' +_GENFILES_DIR_SUFFIX = 'genfiles/' +_SYMBOLS_TO_SKIP_EXPLICITLY = { + # Overrides __getattr__, so that unwrapping tf_decorator + # would have side effects. + 'tensorflow.python.platform.flags.FLAGS' +} +_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit. +# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script. +\"\"\"%s +\"\"\" + +from __future__ import print_function + +""" +_GENERATED_FILE_FOOTER = '\n\ndel print_function\n' + + +class SymbolExposedTwiceError(Exception): + """Raised when different symbols are exported with the same name.""" + pass + + +def format_import(source_module_name, source_name, dest_name): + """Formats import statement. + + Args: + source_module_name: (string) Source module to import from. + source_name: (string) Source symbol name to import. + dest_name: (string) Destination alias name. + + Returns: + An import statement string. + """ + if source_module_name: + if source_name == dest_name: + return 'from %s import %s' % (source_module_name, source_name) + else: + return 'from %s import %s as %s' % ( + source_module_name, source_name, dest_name) + else: + if source_name == dest_name: + return 'import %s' % source_name + else: + return 'import %s as %s' % (source_name, dest_name) + + +class _ModuleInitCodeBuilder(object): + """Builds a map from module name to imports included in that module.""" + + def __init__(self): + self.module_imports = collections.defaultdict( + lambda: collections.defaultdict(set)) + self._dest_import_to_id = collections.defaultdict(int) + # Names that start with underscore in the root module. + self._underscore_names_in_root = [] + + def add_import( + self, symbol_id, dest_module_name, source_module_name, source_name, + dest_name): + """Adds this import to module_imports. + + Args: + symbol_id: (number) Unique identifier of the symbol to import. + dest_module_name: (string) Module name to add import to. + source_module_name: (string) Module to import from. + source_name: (string) Name of the symbol to import. + dest_name: (string) Import the symbol using this name. + + Raises: + SymbolExposedTwiceError: Raised when an import with the same + dest_name has already been added to dest_module_name. + """ + import_str = format_import(source_module_name, source_name, dest_name) + + # Check if we are trying to expose two different symbols with same name. + full_api_name = dest_name + if dest_module_name: + full_api_name = dest_module_name + '.' + full_api_name + if (full_api_name in self._dest_import_to_id and + symbol_id != self._dest_import_to_id[full_api_name] and + symbol_id != -1): + raise SymbolExposedTwiceError( + 'Trying to export multiple symbols with same name: %s.' % + full_api_name) + self._dest_import_to_id[full_api_name] = symbol_id + + if not dest_module_name and dest_name.startswith('_'): + self._underscore_names_in_root.append(dest_name) + + # The same symbol can be available in multiple modules. + # We store all possible ways of importing this symbol and later pick just + # one. + self.module_imports[dest_module_name][full_api_name].add(import_str) + + def build(self): + """Get a map from destination module to __init__.py code for that module. + + Returns: + A dictionary where + key: (string) destination module (for e.g. tf or tf.consts). + value: (string) text that should be in __init__.py files for + corresponding modules. + """ + module_text_map = {} + for dest_module, dest_name_to_imports in self.module_imports.items(): + # Sort all possible imports for a symbol and pick the first one. + imports_list = [ + sorted(imports)[0] + for _, imports in dest_name_to_imports.items()] + module_text_map[dest_module] = '\n'.join(sorted(imports_list)) + + # Expose exported symbols with underscores in root module + # since we import from it using * import. + underscore_names_str = ', '.join( + '\'%s\'' % name for name in self._underscore_names_in_root) + # We will always generate a root __init__.py file to let us handle * + # imports consistently. Be sure to have a root __init__.py file listed in + # the script outputs. + module_text_map[''] = module_text_map.get('', '') + ''' +_names_with_underscore = [%s] +__all__ = [_s for _s in dir() if not _s.startswith('_')] +__all__.extend([_s for _s in _names_with_underscore]) +__all__.remove('print_function') +''' % underscore_names_str + + return module_text_map + + +def get_api_init_text(package, output_package, api_name): + """Get a map from destination module to __init__.py code for that module. + + Args: + package: Base python package containing python with target tf_export + decorators. + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + + Returns: + A dictionary where + key: (string) destination module (for e.g. tf or tf.consts). + value: (string) text that should be in __init__.py files for + corresponding modules. + """ + module_code_builder = _ModuleInitCodeBuilder() + + # Traverse over everything imported above. Specifically, + # we want to traverse over TensorFlow Python modules. + for module in list(sys.modules.values()): + # Only look at tensorflow modules. + if (not module or not hasattr(module, '__name__') or + module.__name__ is None or package not in module.__name__): + continue + # Do not generate __init__.py files for contrib modules for now. + if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'): + continue + + for module_contents_name in dir(module): + if (module.__name__ + '.' + module_contents_name + in _SYMBOLS_TO_SKIP_EXPLICITLY): + continue + attr = getattr(module, module_contents_name) + + # If attr is _tf_api_constants attribute, then add the constants. + if module_contents_name == API_ATTRS[api_name].constants: + for exports, value in attr: + for export in exports: + names = export.split('.') + dest_module = '.'.join(names[:-1]) + module_code_builder.add_import( + -1, dest_module, module.__name__, value, names[-1]) + continue + + _, attr = tf_decorator.unwrap(attr) + # If attr is a symbol with _tf_api_names attribute, then + # add import for it. + if (hasattr(attr, '__dict__') and + API_ATTRS[api_name].names in attr.__dict__): + for export in getattr(attr, API_ATTRS[api_name].names): # pylint: disable=protected-access + names = export.split('.') + dest_module = '.'.join(names[:-1]) + module_code_builder.add_import( + id(attr), dest_module, module.__name__, module_contents_name, + names[-1]) + + # Import all required modules in their parent modules. + # For e.g. if we import 'foo.bar.Value'. Then, we also + # import 'bar' in 'foo'. + imported_modules = set(module_code_builder.module_imports.keys()) + for module in imported_modules: + if not module: + continue + module_split = module.split('.') + parent_module = '' # we import submodules in their parent_module + + for submodule_index in range(len(module_split)): + if submodule_index > 0: + parent_module += ('.' + module_split[submodule_index-1] if parent_module + else module_split[submodule_index-1]) + import_from = output_package + if submodule_index > 0: + import_from += '.' + '.'.join(module_split[:submodule_index]) + module_code_builder.add_import( + -1, parent_module, import_from, + module_split[submodule_index], module_split[submodule_index]) + + return module_code_builder.build() + + +def get_module(dir_path, relative_to_dir): + """Get module that corresponds to path relative to relative_to_dir. + + Args: + dir_path: Path to directory. + relative_to_dir: Get module relative to this directory. + + Returns: + Name of module that corresponds to the given directory. + """ + dir_path = dir_path[len(relative_to_dir):] + # Convert path separators to '/' for easier parsing below. + dir_path = dir_path.replace(os.sep, '/') + return dir_path.replace('/', '.').strip('.') + + +def get_module_docstring(module_name, package, api_name): + """Get docstring for the given module. + + This method looks for docstring in the following order: + 1. Checks if module has a docstring specified in doc_srcs. + 2. Checks if module has a docstring source module specified + in doc_srcs. If it does, gets docstring from that module. + 3. Checks if module with module_name exists under base package. + If it does, gets docstring from that module. + 4. Returns a default docstring. + + Args: + module_name: module name relative to tensorflow + (excluding 'tensorflow.' prefix) to get a docstring for. + package: Base python package containing python with target tf_export + decorators. + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + + Returns: + One-line docstring to describe the module. + """ + # Module under base package to get a docstring from. + docstring_module_name = module_name + + doc_sources = doc_srcs.get_doc_sources(api_name) + + if module_name in doc_sources: + docsrc = doc_sources[module_name] + if docsrc.docstring: + return docsrc.docstring + if docsrc.docstring_module_name: + docstring_module_name = docsrc.docstring_module_name + + docstring_module_name = package + '.' + docstring_module_name + if (docstring_module_name in sys.modules and + sys.modules[docstring_module_name].__doc__): + return sys.modules[docstring_module_name].__doc__ + + return 'Public API for tf.%s namespace.' % module_name + + +def create_api_files( + output_files, package, root_init_template, output_dir, output_package, + api_name): + """Creates __init__.py files for the Python API. + + Args: + output_files: List of __init__.py file paths to create. + Each file must be under api/ directory. + package: Base python package containing python with target tf_export + decorators. + root_init_template: Template for top-level __init__.py file. + "#API IMPORTS PLACEHOLDER" comment in the template file will be replaced + with imports. + output_dir: output API root directory. + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + + Raises: + ValueError: if an output file is not under api/ directory, + or output_files list is missing a required file. + """ + module_name_to_file_path = {} + for output_file in output_files: + module_name = get_module(os.path.dirname(output_file), output_dir) + module_name_to_file_path[module_name] = os.path.normpath(output_file) + + # Create file for each expected output in genrule. + for module, file_path in module_name_to_file_path.items(): + if not os.path.isdir(os.path.dirname(file_path)): + os.makedirs(os.path.dirname(file_path)) + open(file_path, 'a').close() + + module_text_map = get_api_init_text(package, output_package, api_name) + + # Add imports to output files. + missing_output_files = [] + for module, text in module_text_map.items(): + # Make sure genrule output file list is in sync with API exports. + if module not in module_name_to_file_path: + module_file_path = '"%s/__init__.py"' % ( + module.replace('.', '/')) + missing_output_files.append(module_file_path) + continue + contents = '' + if module or not root_init_template: + contents = ( + _GENERATED_FILE_HEADER % + get_module_docstring(module, package, api_name) + + text + _GENERATED_FILE_FOOTER) + else: + # Read base init file + with open(root_init_template, 'r') as root_init_template_file: + contents = root_init_template_file.read() + contents = contents.replace('# API IMPORTS PLACEHOLDER', text) + with open(module_name_to_file_path[module], 'w') as fp: + fp.write(contents) + + if missing_output_files: + raise ValueError( + 'Missing outputs for python_api_gen genrule:\n%s.' + 'Make sure all required outputs are in the ' + 'tensorflow/tools/api/generator/api_gen.bzl file.' % + ',\n'.join(sorted(missing_output_files))) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + 'outputs', metavar='O', type=str, nargs='+', + help='If a single file is passed in, then we we assume it contains a ' + 'semicolon-separated list of Python files that we expect this script to ' + 'output. If multiple files are passed in, then we assume output files ' + 'are listed directly as arguments.') + parser.add_argument( + '--package', default=_DEFAULT_PACKAGE, type=str, + help='Base package that imports modules containing the target tf_export ' + 'decorators.') + parser.add_argument( + '--root_init_template', default='', type=str, + help='Template for top level __init__.py file. ' + '"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.') + parser.add_argument( + '--apidir', type=str, required=True, + help='Directory where generated output files are placed. ' + 'gendir should be a prefix of apidir. Also, apidir ' + 'should be a prefix of every directory in outputs.') + parser.add_argument( + '--apiname', required=True, type=str, + choices=API_ATTRS.keys(), + help='The API you want to generate.') + parser.add_argument( + '--output_package', default='tensorflow', type=str, + help='Root output package.') + + args = parser.parse_args() + + if len(args.outputs) == 1: + # If we only get a single argument, then it must be a file containing + # list of outputs. + with open(args.outputs[0]) as output_list_file: + outputs = [line.strip() for line in output_list_file.read().split(';')] + else: + outputs = args.outputs + + # Populate `sys.modules` with modules containing tf_export(). + importlib.import_module(args.package) + create_api_files(outputs, args.package, args.root_init_template, + args.apidir, args.output_package, args.apiname) + + +if __name__ == '__main__': + main() diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py new file mode 100644 index 0000000000..368b4c37e8 --- /dev/null +++ b/tensorflow/python/tools/api/generator/create_python_api_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for create_python_api.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import imp +import sys + +from tensorflow.python.platform import test +from tensorflow.python.tools.api.generator import create_python_api +from tensorflow.python.util.tf_export import tf_export + + +@tf_export('test_op', 'test_op1') +def test_op(): + pass + + +@tf_export('TestClass', 'NewTestClass') +class TestClass(object): + pass + + +_TEST_CONSTANT = 5 +_MODULE_NAME = 'tensorflow.python.test_module' + + +class CreatePythonApiTest(test.TestCase): + + def setUp(self): + # Add fake op to a module that has 'tensorflow' in the name. + sys.modules[_MODULE_NAME] = imp.new_module(_MODULE_NAME) + setattr(sys.modules[_MODULE_NAME], 'test_op', test_op) + setattr(sys.modules[_MODULE_NAME], 'TestClass', TestClass) + test_op.__module__ = _MODULE_NAME + TestClass.__module__ = _MODULE_NAME + tf_export('consts._TEST_CONSTANT').export_constant( + _MODULE_NAME, '_TEST_CONSTANT') + + def tearDown(self): + del sys.modules[_MODULE_NAME] + + def testFunctionImportIsAdded(self): + imports = create_python_api.get_api_init_text( + package=create_python_api._DEFAULT_PACKAGE, + output_package='tensorflow', + api_name='tensorflow') + expected_import = ( + 'from tensorflow.python.test_module ' + 'import test_op as test_op1') + self.assertTrue( + expected_import in str(imports), + msg='%s not in %s' % (expected_import, str(imports))) + + expected_import = ('from tensorflow.python.test_module ' + 'import test_op') + self.assertTrue( + expected_import in str(imports), + msg='%s not in %s' % (expected_import, str(imports))) + + def testClassImportIsAdded(self): + imports = create_python_api.get_api_init_text( + package=create_python_api._DEFAULT_PACKAGE, + output_package='tensorflow', + api_name='tensorflow') + expected_import = ('from tensorflow.python.test_module ' + 'import TestClass') + self.assertTrue( + 'TestClass' in str(imports), + msg='%s not in %s' % (expected_import, str(imports))) + + def testConstantIsAdded(self): + imports = create_python_api.get_api_init_text( + package=create_python_api._DEFAULT_PACKAGE, + output_package='tensorflow', + api_name='tensorflow') + expected = ('from tensorflow.python.test_module ' + 'import _TEST_CONSTANT') + self.assertTrue(expected in str(imports), + msg='%s not in %s' % (expected, str(imports))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py new file mode 100644 index 0000000000..ad1988494d --- /dev/null +++ b/tensorflow/python/tools/api/generator/doc_srcs.py @@ -0,0 +1,92 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Specifies sources of doc strings for API modules.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.util import tf_export + + +# Specifies docstring source for a module. +# Only one of docstring or docstring_module_name should be set. +# * If docstring is set, then we will use this docstring when +# for the module. +# * If docstring_module_name is set, then we will copy the docstring +# from docstring source module. +DocSource = collections.namedtuple( + 'DocSource', ['docstring', 'docstring_module_name']) +# Each attribute of DocSource is optional. +DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields) + +_TENSORFLOW_DOC_SOURCES = { + 'app': DocSource(docstring_module_name='platform.app'), + 'compat': DocSource(docstring_module_name='util.compat'), + 'distributions': DocSource( + docstring_module_name='ops.distributions.distributions'), + 'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'), + 'errors': DocSource(docstring_module_name='framework.errors'), + 'gfile': DocSource(docstring_module_name='platform.gfile'), + 'graph_util': DocSource(docstring_module_name='framework.graph_util'), + 'image': DocSource(docstring_module_name='ops.image_ops'), + 'keras.estimator': DocSource(docstring_module_name='keras.estimator'), + 'linalg': DocSource(docstring_module_name='ops.linalg_ops'), + 'logging': DocSource(docstring_module_name='ops.logging_ops'), + 'losses': DocSource(docstring_module_name='ops.losses.losses'), + 'manip': DocSource(docstring_module_name='ops.manip_ops'), + 'math': DocSource(docstring_module_name='ops.math_ops'), + 'metrics': DocSource(docstring_module_name='ops.metrics'), + 'nn': DocSource(docstring_module_name='ops.nn_ops'), + 'nn.rnn_cell': DocSource(docstring_module_name='ops.rnn_cell'), + 'python_io': DocSource(docstring_module_name='lib.io.python_io'), + 'resource_loader': DocSource( + docstring_module_name='platform.resource_loader'), + 'sets': DocSource(docstring_module_name='ops.sets'), + 'sparse': DocSource(docstring_module_name='ops.sparse_ops'), + 'spectral': DocSource(docstring_module_name='ops.spectral_ops'), + 'strings': DocSource(docstring_module_name='ops.string_ops'), + 'sysconfig': DocSource(docstring_module_name='platform.sysconfig'), + 'test': DocSource(docstring_module_name='platform.test'), + 'train': DocSource(docstring_module_name='training.training'), + 'train.queue_runner': DocSource( + docstring_module_name='training.queue_runner'), +} + +_ESTIMATOR_DOC_SOURCES = { + 'estimator': DocSource( + docstring_module_name='estimator_lib'), + 'estimator.export': DocSource( + docstring_module_name='export.export_lib'), + 'estimator.inputs': DocSource( + docstring_module_name='inputs.inputs'), +} + + +def get_doc_sources(api_name): + """Get a map from module to a DocSource object. + + Args: + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + + Returns: + Map from module name to DocSource object. + """ + if api_name == tf_export.TENSORFLOW_API_NAME: + return _TENSORFLOW_DOC_SOURCES + if api_name == tf_export.ESTIMATOR_API_NAME: + return _ESTIMATOR_DOC_SOURCES + return {} diff --git a/tensorflow/python/tools/api/generator/doc_srcs_test.py b/tensorflow/python/tools/api/generator/doc_srcs_test.py new file mode 100644 index 0000000000..481d9874a4 --- /dev/null +++ b/tensorflow/python/tools/api/generator/doc_srcs_test.py @@ -0,0 +1,83 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for tensorflow.python.tools.api.generator.doc_srcs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import importlib +import sys + +from tensorflow.python.platform import test +from tensorflow.python.tools.api.generator import doc_srcs + + +FLAGS = None + + +class DocSrcsTest(test.TestCase): + + def testModulesAreValidAPIModules(self): + for module_name in doc_srcs.get_doc_sources(FLAGS.api_name): + # Convert module_name to corresponding __init__.py file path. + file_path = module_name.replace('.', '/') + if file_path: + file_path += '/' + file_path += '__init__.py' + + self.assertIn( + file_path, FLAGS.outputs, + msg='%s is not a valid API module' % module_name) + + def testHaveDocstringOrDocstringModule(self): + for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): + self.assertFalse( + docsrc.docstring and docsrc.docstring_module_name, + msg=('%s contains DocSource has both a docstring and a ' + 'docstring_module_name. Only one of "docstring" or ' + '"docstring_module_name" should be set.') % (module_name)) + + def testDocstringModulesAreValidModules(self): + for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): + if docsrc.docstring_module_name: + doc_module_name = '.'.join([ + FLAGS.package, docsrc.docstring_module_name]) + self.assertIn( + doc_module_name, sys.modules, + msg=('docsources_module %s is not a valid module under %s.' % + (docsrc.docstring_module_name, FLAGS.package))) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + 'outputs', metavar='O', type=str, nargs='+', + help='create_python_api output files.') + parser.add_argument( + '--package', type=str, + help='Base package that imports modules containing the target tf_export ' + 'decorators.') + parser.add_argument( + '--api_name', type=str, + help='API name: tensorflow or estimator') + FLAGS, unparsed = parser.parse_known_args() + + importlib.import_module(FLAGS.package) + + # Now update argv, so that unittest library does not get confused. + sys.argv = [sys.argv[0]] + unparsed + test.main() |