diff options
author | Anna R <annarev@google.com> | 2018-07-12 14:46:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-12 14:51:03 -0700 |
commit | da798407b4ff72f1daa629e054ccd47b162c9d58 (patch) | |
tree | 4bc5251f66dd8bb601d73fd3ec8f035b953bbe6a /tensorflow/python/tools | |
parent | c5e563e57feee793499fae9c3ce28f5176404749 (diff) |
Support passing TensorFlow API names as a separate v1 argument to tf_export.
PiperOrigin-RevId: 204368026
Diffstat (limited to 'tensorflow/python/tools')
3 files changed, 56 insertions, 35 deletions
diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl index f9170610b9..2a32e8a893 100644 --- a/tensorflow/python/tools/api/generator/api_gen.bzl +++ b/tensorflow/python/tools/api/generator/api_gen.bzl @@ -102,37 +102,41 @@ ESTIMATOR_API_INIT_FILES = [ # END GENERATED ESTIMATOR FILES ] -# Creates a genrule that generates a directory structure with __init__.py -# files that import all exported modules (i.e. modules with tf_export -# decorators). -# -# Args: -# name: name of genrule to create. -# output_files: List of __init__.py files that should be generated. -# This list should include file name for every module exported using -# tf_export. For e.g. if an op is decorated with -# @tf_export('module1.module2', 'module3'). Then, output_files should -# include module1/module2/__init__.py and module3/__init__.py. -# root_init_template: Python init file that should be used as template for -# root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this -# template will be replaced with root imports collected by this genrule. -# srcs: genrule sources. If passing root_init_template, the template file -# must be included in sources. -# api_name: Name of the project that you want to generate API files for -# (e.g. "tensorflow" or "estimator"). -# package: Python package containing the @tf_export decorators you want to -# process -# package_dep: Python library target containing your package. - def gen_api_init_files( name, output_files = TENSORFLOW_API_INIT_FILES, root_init_template = None, srcs = [], api_name = "tensorflow", + api_version = 2, package = "tensorflow.python", package_dep = "//tensorflow/python:no_contrib", output_package = "tensorflow"): + """Creates API directory structure and __init__.py files. + + Creates a genrule that generates a directory structure with __init__.py + files that import all exported modules (i.e. modules with tf_export + decorators). + + Args: + name: name of genrule to create. + output_files: List of __init__.py files that should be generated. + This list should include file name for every module exported using + tf_export. For e.g. if an op is decorated with + @tf_export('module1.module2', 'module3'). Then, output_files should + include module1/module2/__init__.py and module3/__init__.py. + root_init_template: Python init file that should be used as template for + root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this + template will be replaced with root imports collected by this genrule. + srcs: genrule sources. If passing root_init_template, the template file + must be included in sources. + api_name: Name of the project that you want to generate API files for + (e.g. "tensorflow" or "estimator"). + api_version: TensorFlow API version to generate. Must be either 1 or 2. + package: Python package containing the @tf_export decorators you want to + process + package_dep: Python library target containing your package. + """ root_init_template_flag = "" if root_init_template: root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" @@ -156,8 +160,8 @@ def gen_api_init_files( cmd = ( "$(location :" + api_gen_binary_target + ") " + root_init_template_flag + " --apidir=$(@D) --apiname=" + - api_name + " --package=" + package + " --output_package=" + - output_package + " $(OUTS)"), + api_name + " --apiversion=" + str(api_version) + " --package=" + package + + " --output_package=" + output_package + " $(OUTS)"), srcs = srcs, tools = [":" + api_gen_binary_target ], visibility = ["//tensorflow:__pkg__"], diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py index e78fe4b738..863c922216 100644 --- a/tensorflow/python/tools/api/generator/create_python_api.py +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -29,6 +29,7 @@ from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_export API_ATTRS = tf_export.API_ATTRS +API_ATTRS_V1 = tf_export.API_ATTRS_V1 _DEFAULT_PACKAGE = 'tensorflow.python' _GENFILES_DIR_SUFFIX = 'genfiles/' @@ -159,13 +160,16 @@ __all__.remove('print_function') return module_text_map -def get_api_init_text(package, output_package, api_name): +def get_api_init_text(package, output_package, api_name, api_version): """Get a map from destination module to __init__.py code for that module. Args: package: Base python package containing python with target tf_export decorators. + output_package: Base output python package where generated API will + be added. api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + api_version: API version you want to generate (`v1` or `v2`). Returns: A dictionary where @@ -173,6 +177,12 @@ def get_api_init_text(package, output_package, api_name): value: (string) text that should be in __init__.py files for corresponding modules. """ + if api_version == 1: + names_attr = API_ATTRS_V1[api_name].names + constants_attr = API_ATTRS_V1[api_name].constants + else: + names_attr = API_ATTRS[api_name].names + constants_attr = API_ATTRS[api_name].constants module_code_builder = _ModuleInitCodeBuilder() # Traverse over everything imported above. Specifically, @@ -193,7 +203,7 @@ def get_api_init_text(package, output_package, api_name): attr = getattr(module, module_contents_name) # If attr is _tf_api_constants attribute, then add the constants. - if module_contents_name == API_ATTRS[api_name].constants: + if module_contents_name == constants_attr: for exports, value in attr: for export in exports: names = export.split('.') @@ -205,9 +215,8 @@ def get_api_init_text(package, output_package, api_name): _, attr = tf_decorator.unwrap(attr) # If attr is a symbol with _tf_api_names attribute, then # add import for it. - if (hasattr(attr, '__dict__') and - API_ATTRS[api_name].names in attr.__dict__): - for export in getattr(attr, API_ATTRS[api_name].names): # pylint: disable=protected-access + if (hasattr(attr, '__dict__') and names_attr in attr.__dict__): + for export in getattr(attr, names_attr): # pylint: disable=protected-access names = export.split('.') dest_module = '.'.join(names[:-1]) module_code_builder.add_import( @@ -297,7 +306,7 @@ def get_module_docstring(module_name, package, api_name): def create_api_files( output_files, package, root_init_template, output_dir, output_package, - api_name): + api_name, api_version): """Creates __init__.py files for the Python API. Args: @@ -309,7 +318,9 @@ def create_api_files( "#API IMPORTS PLACEHOLDER" comment in the template file will be replaced with imports. output_dir: output API root directory. + output_package: Base output package where generated API will be added. api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + api_version: API version to generate (`v1` or `v2`). Raises: ValueError: if an output file is not under api/ directory, @@ -326,7 +337,8 @@ def create_api_files( os.makedirs(os.path.dirname(file_path)) open(file_path, 'a').close() - module_text_map = get_api_init_text(package, output_package, api_name) + module_text_map = get_api_init_text( + package, output_package, api_name, api_version) # Add imports to output files. missing_output_files = [] @@ -385,6 +397,10 @@ def main(): choices=API_ATTRS.keys(), help='The API you want to generate.') parser.add_argument( + '--apiversion', default=2, type=int, + choices=[1, 2], + help='The API version you want to generate.') + parser.add_argument( '--output_package', default='tensorflow', type=str, help='Root output package.') @@ -401,7 +417,8 @@ def main(): # Populate `sys.modules` with modules containing tf_export(). importlib.import_module(args.package) create_api_files(outputs, args.package, args.root_init_template, - args.apidir, args.output_package, args.apiname) + args.apidir, args.output_package, args.apiname, + args.apiversion) if __name__ == '__main__': diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py index 368b4c37e8..a565a49d96 100644 --- a/tensorflow/python/tools/api/generator/create_python_api_test.py +++ b/tensorflow/python/tools/api/generator/create_python_api_test.py @@ -59,7 +59,7 @@ class CreatePythonApiTest(test.TestCase): imports = create_python_api.get_api_init_text( package=create_python_api._DEFAULT_PACKAGE, output_package='tensorflow', - api_name='tensorflow') + api_name='tensorflow', api_version=1) expected_import = ( 'from tensorflow.python.test_module ' 'import test_op as test_op1') @@ -77,7 +77,7 @@ class CreatePythonApiTest(test.TestCase): imports = create_python_api.get_api_init_text( package=create_python_api._DEFAULT_PACKAGE, output_package='tensorflow', - api_name='tensorflow') + api_name='tensorflow', api_version=2) expected_import = ('from tensorflow.python.test_module ' 'import TestClass') self.assertTrue( @@ -88,7 +88,7 @@ class CreatePythonApiTest(test.TestCase): imports = create_python_api.get_api_init_text( package=create_python_api._DEFAULT_PACKAGE, output_package='tensorflow', - api_name='tensorflow') + api_name='tensorflow', api_version=1) expected = ('from tensorflow.python.test_module ' 'import _TEST_CONSTANT') self.assertTrue(expected in str(imports), |