Diffstat (limited to 'tensorflow/python/tools/api/generator/create_python_api.py')
-rw-r--r--  tensorflow/python/tools/api/generator/create_python_api.py | 425
1 file changed, 425 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py
new file mode 100644
index 0000000000..863c922216
--- /dev/null
+++ b/tensorflow/python/tools/api/generator/create_python_api.py
@@ -0,0 +1,425 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Generates and prints out imports and constants for new TensorFlow python api.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import collections
+import importlib
+import os
+import sys
+
+from tensorflow.python.tools.api.generator import doc_srcs
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_export
+
+API_ATTRS = tf_export.API_ATTRS
+API_ATTRS_V1 = tf_export.API_ATTRS_V1
+
+_DEFAULT_PACKAGE = 'tensorflow.python'
+_GENFILES_DIR_SUFFIX = 'genfiles/'
+_SYMBOLS_TO_SKIP_EXPLICITLY = {
+    # FLAGS overrides __getattr__, so unwrapping the tf_decorator here
+    # would have side effects.
+ 'tensorflow.python.platform.flags.FLAGS'
+}
+_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
+# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
+\"\"\"%s
+\"\"\"
+
+from __future__ import print_function
+
+"""
+_GENERATED_FILE_FOOTER = '\n\ndel print_function\n'
+
+
+class SymbolExposedTwiceError(Exception):
+ """Raised when different symbols are exported with the same name."""
+ pass
+
+
+def format_import(source_module_name, source_name, dest_name):
+ """Formats import statement.
+
+ Args:
+ source_module_name: (string) Source module to import from.
+ source_name: (string) Source symbol name to import.
+ dest_name: (string) Destination alias name.
+
+ Returns:
+ An import statement string.
+ """
+ if source_module_name:
+ if source_name == dest_name:
+ return 'from %s import %s' % (source_module_name, source_name)
+ else:
+ return 'from %s import %s as %s' % (
+ source_module_name, source_name, dest_name)
+ else:
+ if source_name == dest_name:
+ return 'import %s' % source_name
+ else:
+ return 'import %s as %s' % (source_name, dest_name)
+
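+# Illustrative outputs of format_import (a sketch only; the module and symbol
+# names below are just examples, not part of the generated API):
+#   format_import('tensorflow.python.ops.array_ops', 'zeros', 'zeros')
+#     -> 'from tensorflow.python.ops.array_ops import zeros'
+#   format_import('tensorflow.python.ops.array_ops', 'zeros', 'zeros_alias')
+#     -> 'from tensorflow.python.ops.array_ops import zeros as zeros_alias'
+#   format_import(None, 'collections', 'collections')
+#     -> 'import collections'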
+
+class _ModuleInitCodeBuilder(object):
+ """Builds a map from module name to imports included in that module."""
+
+ def __init__(self):
+ self.module_imports = collections.defaultdict(
+ lambda: collections.defaultdict(set))
+ self._dest_import_to_id = collections.defaultdict(int)
+ # Names that start with underscore in the root module.
+ self._underscore_names_in_root = []
+
+ def add_import(
+ self, symbol_id, dest_module_name, source_module_name, source_name,
+ dest_name):
+ """Adds this import to module_imports.
+
+ Args:
+ symbol_id: (number) Unique identifier of the symbol to import.
+ dest_module_name: (string) Module name to add import to.
+ source_module_name: (string) Module to import from.
+ source_name: (string) Name of the symbol to import.
+ dest_name: (string) Import the symbol using this name.
+
+ Raises:
+ SymbolExposedTwiceError: Raised when an import with the same
+ dest_name has already been added to dest_module_name.
+ """
+ import_str = format_import(source_module_name, source_name, dest_name)
+
+ # Check if we are trying to expose two different symbols with same name.
+ full_api_name = dest_name
+ if dest_module_name:
+ full_api_name = dest_module_name + '.' + full_api_name
+ if (full_api_name in self._dest_import_to_id and
+ symbol_id != self._dest_import_to_id[full_api_name] and
+ symbol_id != -1):
+ raise SymbolExposedTwiceError(
+ 'Trying to export multiple symbols with same name: %s.' %
+ full_api_name)
+ self._dest_import_to_id[full_api_name] = symbol_id
+
+ if not dest_module_name and dest_name.startswith('_'):
+ self._underscore_names_in_root.append(dest_name)
+
+ # The same symbol can be available in multiple modules.
+ # We store all possible ways of importing this symbol and later pick just
+ # one.
+ self.module_imports[dest_module_name][full_api_name].add(import_str)
+
+ def build(self):
+ """Get a map from destination module to __init__.py code for that module.
+
+ Returns:
+ A dictionary where
+      key: (string) destination module (e.g. tf or tf.consts).
+ value: (string) text that should be in __init__.py files for
+ corresponding modules.
+ """
+ module_text_map = {}
+ for dest_module, dest_name_to_imports in self.module_imports.items():
+ # Sort all possible imports for a symbol and pick the first one.
+ imports_list = [
+ sorted(imports)[0]
+ for _, imports in dest_name_to_imports.items()]
+ module_text_map[dest_module] = '\n'.join(sorted(imports_list))
+
+ # Expose exported symbols with underscores in root module
+ # since we import from it using * import.
+ underscore_names_str = ', '.join(
+ '\'%s\'' % name for name in self._underscore_names_in_root)
+ # We will always generate a root __init__.py file to let us handle *
+ # imports consistently. Be sure to have a root __init__.py file listed in
+ # the script outputs.
+ module_text_map[''] = module_text_map.get('', '') + '''
+_names_with_underscore = [%s]
+__all__ = [_s for _s in dir() if not _s.startswith('_')]
+__all__.extend([_s for _s in _names_with_underscore])
+__all__.remove('print_function')
+''' % underscore_names_str
+
+ return module_text_map
+
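+# A hedged sketch of the map build() returns (the entries here are
+# hypothetical, not actual generated output):
+#   {'': 'from tensorflow.python.ops.array_ops import zeros\n...',
+#    'nn': 'from tensorflow.python.ops.nn_ops import relu'}
+# i.e. one entry per destination module, holding the import lines for that
+# module's generated __init__.py. The root ('') entry additionally gets the
+# __all__ / _names_with_underscore epilogue appended above.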
+
+def get_api_init_text(package, output_package, api_name, api_version):
+ """Get a map from destination module to __init__.py code for that module.
+
+ Args:
+    package: Base Python package containing modules with target tf_export
+ decorators.
+ output_package: Base output python package where generated API will
+ be added.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
+ api_version: API version you want to generate (`v1` or `v2`).
+
+ Returns:
+ A dictionary where
+      key: (string) destination module (e.g. tf or tf.consts).
+ value: (string) text that should be in __init__.py files for
+ corresponding modules.
+ """
+ if api_version == 1:
+ names_attr = API_ATTRS_V1[api_name].names
+ constants_attr = API_ATTRS_V1[api_name].constants
+ else:
+ names_attr = API_ATTRS[api_name].names
+ constants_attr = API_ATTRS[api_name].constants
+ module_code_builder = _ModuleInitCodeBuilder()
+
+ # Traverse over everything imported above. Specifically,
+ # we want to traverse over TensorFlow Python modules.
+ for module in list(sys.modules.values()):
+ # Only look at tensorflow modules.
+ if (not module or not hasattr(module, '__name__') or
+ module.__name__ is None or package not in module.__name__):
+ continue
+ # Do not generate __init__.py files for contrib modules for now.
+ if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
+ continue
+
+ for module_contents_name in dir(module):
+ if (module.__name__ + '.' + module_contents_name
+ in _SYMBOLS_TO_SKIP_EXPLICITLY):
+ continue
+ attr = getattr(module, module_contents_name)
+
+ # If attr is _tf_api_constants attribute, then add the constants.
+ if module_contents_name == constants_attr:
+ for exports, value in attr:
+ for export in exports:
+ names = export.split('.')
+ dest_module = '.'.join(names[:-1])
+ module_code_builder.add_import(
+ -1, dest_module, module.__name__, value, names[-1])
+ continue
+
+ _, attr = tf_decorator.unwrap(attr)
+ # If attr is a symbol with _tf_api_names attribute, then
+ # add import for it.
+ if (hasattr(attr, '__dict__') and names_attr in attr.__dict__):
+ for export in getattr(attr, names_attr): # pylint: disable=protected-access
+ names = export.split('.')
+ dest_module = '.'.join(names[:-1])
+ module_code_builder.add_import(
+ id(attr), dest_module, module.__name__, module_contents_name,
+ names[-1])
+
+  # Import all required modules in their parent modules.
+  # For example, if we import 'foo.bar.Value', then we also
+  # import 'bar' in 'foo'.
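+  # As an illustration (module names here are hypothetical), for a destination
+  # module 'foo.bar' the loop below adds
+  #   'from <output_package> import foo'      to the root module, and
+  #   'from <output_package>.foo import bar'  to module 'foo'.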
+ imported_modules = set(module_code_builder.module_imports.keys())
+ for module in imported_modules:
+ if not module:
+ continue
+ module_split = module.split('.')
+ parent_module = '' # we import submodules in their parent_module
+
+ for submodule_index in range(len(module_split)):
+ if submodule_index > 0:
+ parent_module += ('.' + module_split[submodule_index-1] if parent_module
+ else module_split[submodule_index-1])
+ import_from = output_package
+ if submodule_index > 0:
+ import_from += '.' + '.'.join(module_split[:submodule_index])
+ module_code_builder.add_import(
+ -1, parent_module, import_from,
+ module_split[submodule_index], module_split[submodule_index])
+
+ return module_code_builder.build()
+
+
+def get_module(dir_path, relative_to_dir):
+ """Get module that corresponds to path relative to relative_to_dir.
+
+ Args:
+ dir_path: Path to directory.
+ relative_to_dir: Get module relative to this directory.
+
+ Returns:
+ Name of module that corresponds to the given directory.
+ """
+ dir_path = dir_path[len(relative_to_dir):]
+ # Convert path separators to '/' for easier parsing below.
+ dir_path = dir_path.replace(os.sep, '/')
+ return dir_path.replace('/', '.').strip('.')
+
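+# For example (the directory paths here are hypothetical):
+#   get_module('/genfiles/api/tf/nn', '/genfiles/api')  ->  'tf.nn'
+#   get_module('/genfiles/api', '/genfiles/api')        ->  ''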
+
+def get_module_docstring(module_name, package, api_name):
+ """Get docstring for the given module.
+
+  This method looks for a docstring in the following order:
+ 1. Checks if module has a docstring specified in doc_srcs.
+ 2. Checks if module has a docstring source module specified
+ in doc_srcs. If it does, gets docstring from that module.
+ 3. Checks if module with module_name exists under base package.
+ If it does, gets docstring from that module.
+ 4. Returns a default docstring.
+
+ Args:
+ module_name: module name relative to tensorflow
+ (excluding 'tensorflow.' prefix) to get a docstring for.
+    package: Base Python package containing modules with target tf_export
+ decorators.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
+
+ Returns:
+ One-line docstring to describe the module.
+ """
+ # Module under base package to get a docstring from.
+ docstring_module_name = module_name
+
+ doc_sources = doc_srcs.get_doc_sources(api_name)
+
+ if module_name in doc_sources:
+ docsrc = doc_sources[module_name]
+ if docsrc.docstring:
+ return docsrc.docstring
+ if docsrc.docstring_module_name:
+ docstring_module_name = docsrc.docstring_module_name
+
+ docstring_module_name = package + '.' + docstring_module_name
+ if (docstring_module_name in sys.modules and
+ sys.modules[docstring_module_name].__doc__):
+ return sys.modules[docstring_module_name].__doc__
+
+ return 'Public API for tf.%s namespace.' % module_name
+
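+# For example, for a module with no doc_srcs entry and no importable docstring
+# source, get_module_docstring('foo', 'tensorflow.python', 'tensorflow') falls
+# back to 'Public API for tf.foo namespace.' ('foo' is a hypothetical module).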
+
+def create_api_files(
+ output_files, package, root_init_template, output_dir, output_package,
+ api_name, api_version):
+ """Creates __init__.py files for the Python API.
+
+ Args:
+ output_files: List of __init__.py file paths to create.
+ Each file must be under api/ directory.
+    package: Base Python package containing modules with target tf_export
+ decorators.
+ root_init_template: Template for top-level __init__.py file.
+ "#API IMPORTS PLACEHOLDER" comment in the template file will be replaced
+ with imports.
+ output_dir: output API root directory.
+ output_package: Base output package where generated API will be added.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
+ api_version: API version to generate (`v1` or `v2`).
+
+ Raises:
+    ValueError: if an output file is not under the api/ directory,
+      or the output_files list is missing a required file.
+ """
+ module_name_to_file_path = {}
+ for output_file in output_files:
+ module_name = get_module(os.path.dirname(output_file), output_dir)
+ module_name_to_file_path[module_name] = os.path.normpath(output_file)
+
+ # Create file for each expected output in genrule.
+ for module, file_path in module_name_to_file_path.items():
+ if not os.path.isdir(os.path.dirname(file_path)):
+ os.makedirs(os.path.dirname(file_path))
+ open(file_path, 'a').close()
+
+ module_text_map = get_api_init_text(
+ package, output_package, api_name, api_version)
+
+ # Add imports to output files.
+ missing_output_files = []
+ for module, text in module_text_map.items():
+ # Make sure genrule output file list is in sync with API exports.
+ if module not in module_name_to_file_path:
+ module_file_path = '"%s/__init__.py"' % (
+ module.replace('.', '/'))
+ missing_output_files.append(module_file_path)
+ continue
+ contents = ''
+ if module or not root_init_template:
+ contents = (
+ _GENERATED_FILE_HEADER %
+ get_module_docstring(module, package, api_name) +
+ text + _GENERATED_FILE_FOOTER)
+ else:
+ # Read base init file
+ with open(root_init_template, 'r') as root_init_template_file:
+ contents = root_init_template_file.read()
+ contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
+ with open(module_name_to_file_path[module], 'w') as fp:
+ fp.write(contents)
+
+ if missing_output_files:
+ raise ValueError(
+        'Missing outputs for python_api_gen genrule:\n%s.\n'
+ 'Make sure all required outputs are in the '
+ 'tensorflow/tools/api/generator/api_gen.bzl file.' %
+ ',\n'.join(sorted(missing_output_files)))
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ 'outputs', metavar='O', type=str, nargs='+',
+    help='If a single file is passed in, then we assume it contains a '
+ 'semicolon-separated list of Python files that we expect this script to '
+ 'output. If multiple files are passed in, then we assume output files '
+ 'are listed directly as arguments.')
+ parser.add_argument(
+ '--package', default=_DEFAULT_PACKAGE, type=str,
+ help='Base package that imports modules containing the target tf_export '
+ 'decorators.')
+ parser.add_argument(
+ '--root_init_template', default='', type=str,
+ help='Template for top level __init__.py file. '
+ '"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
+ parser.add_argument(
+ '--apidir', type=str, required=True,
+ help='Directory where generated output files are placed. '
+ 'gendir should be a prefix of apidir. Also, apidir '
+ 'should be a prefix of every directory in outputs.')
+ parser.add_argument(
+ '--apiname', required=True, type=str,
+ choices=API_ATTRS.keys(),
+ help='The API you want to generate.')
+ parser.add_argument(
+ '--apiversion', default=2, type=int,
+ choices=[1, 2],
+ help='The API version you want to generate.')
+ parser.add_argument(
+ '--output_package', default='tensorflow', type=str,
+ help='Root output package.')
+
+ args = parser.parse_args()
+
+ if len(args.outputs) == 1:
+    # If we only get a single argument, then it must be a file containing
+    # the list of outputs.
+ with open(args.outputs[0]) as output_list_file:
+ outputs = [line.strip() for line in output_list_file.read().split(';')]
+ else:
+ outputs = args.outputs
+
+ # Populate `sys.modules` with modules containing tf_export().
+ importlib.import_module(args.package)
+ create_api_files(outputs, args.package, args.root_init_template,
+ args.apidir, args.output_package, args.apiname,
+ args.apiversion)
+
+
+if __name__ == '__main__':
+ main()
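+
+# A hedged example invocation (paths and output files are hypothetical; the
+# real genrule supplies the actual output list and directories):
+#   python create_python_api.py \
+#     --apidir=/tmp/api --apiname=tensorflow --apiversion=2 \
+#     --package=tensorflow.python --output_package=tensorflow \
+#     /tmp/api/__init__.py /tmp/api/nn/__init__.py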