diff options
author | 2018-06-07 12:05:24 -0700 | |
---|---|---|
committer | 2018-06-07 12:14:50 -0700 | |
commit | 501cf726cbee2ee13efef43884a6552ca211979d (patch) | |
tree | 2a93bae901b9f9d32f5d622e2e4d626668b48b99 /tensorflow/python/util/tf_export.py | |
parent | 4d0d60a82c52c6c71650db33bf826f03559d91fc (diff) |
Internal Change.
PiperOrigin-RevId: 199673803
Diffstat (limited to 'tensorflow/python/util/tf_export.py')
-rw-r--r-- | tensorflow/python/util/tf_export.py | 58 |
1 files changed, 38 insertions, 20 deletions
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index bf3961c692..e154ffb68a 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -41,17 +41,35 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import functools import sys from tensorflow.python.util import tf_decorator +ESTIMATOR_API_NAME = 'estimator' +TENSORFLOW_API_NAME = 'tensorflow' + +_Attributes = collections.namedtuple( + 'ExportedApiAttributes', ['names', 'constants']) + +# Attribute values must be unique to each API. +API_ATTRS = { + TENSORFLOW_API_NAME: _Attributes( + '_tf_api_names', + '_tf_api_constants'), + ESTIMATOR_API_NAME: _Attributes( + '_estimator_api_names', + '_estimator_api_constants') +} + class SymbolAlreadyExposedError(Exception): """Raised when adding API names to symbol that already has API names.""" pass -class tf_export(object): # pylint: disable=invalid-name +class api_export(object): # pylint: disable=invalid-name """Provides ways to export symbols to the TensorFlow API.""" def __init__(self, *args, **kwargs): @@ -63,15 +81,12 @@ class tf_export(object): # pylint: disable=invalid-name overrides: List of symbols that this is overriding (those overrided api exports will be removed). Note: passing overrides has no effect on exporting a constant. - allow_multiple_exports: Allows exporting the same symbol multiple - times with multiple `tf_export` usages. Prefer however, to list all - of the exported names in a single `tf_export` usage when possible. - + api_name: Name of the API you want to generate (e.g. `tensorflow` or + `estimator`). Default is `tensorflow`. """ self._names = args + self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME) self._overrides = kwargs.get('overrides', []) - self._allow_multiple_exports = kwargs.get( - 'allow_multiple_exports', False) def __call__(self, func): """Calls this decorator. @@ -86,25 +101,24 @@ class tf_export(object): # pylint: disable=invalid-name SymbolAlreadyExposedError: Raised when a symbol already has API names and kwarg `allow_multiple_exports` not set. """ + api_names_attr = API_ATTRS[self._api_name].names + # Undecorate overridden names for f in self._overrides: _, undecorated_f = tf_decorator.unwrap(f) - del undecorated_f._tf_api_names # pylint: disable=protected-access + delattr(undecorated_f, api_names_attr) _, undecorated_func = tf_decorator.unwrap(func) # Check for an existing api. We check if attribute name is in # __dict__ instead of using hasattr to verify that subclasses have # their own _tf_api_names as opposed to just inheriting it. - if '_tf_api_names' in undecorated_func.__dict__: - if self._allow_multiple_exports: - undecorated_func._tf_api_names += self._names # pylint: disable=protected-access - else: - raise SymbolAlreadyExposedError( - 'Symbol %s is already exposed as %s.' % - (undecorated_func.__name__, undecorated_func._tf_api_names)) # pylint: disable=protected-access - else: - undecorated_func._tf_api_names = self._names # pylint: disable=protected-access + if api_names_attr in undecorated_func.__dict__: + raise SymbolAlreadyExposedError( + 'Symbol %s is already exposed as %s.' % + (undecorated_func.__name__, getattr( + undecorated_func, api_names_attr))) # pylint: disable=protected-access + setattr(undecorated_func, api_names_attr, self._names) return func def export_constant(self, module_name, name): @@ -126,8 +140,12 @@ class tf_export(object): # pylint: disable=invalid-name name: (string) Current constant name. """ module = sys.modules[module_name] - if not hasattr(module, '_tf_api_constants'): - module._tf_api_constants = [] # pylint: disable=protected-access + if not hasattr(module, API_ATTRS[self._api_name].constants): + setattr(module, API_ATTRS[self._api_name].constants, []) # pylint: disable=protected-access - module._tf_api_constants.append((self._names, name)) + getattr(module, API_ATTRS[self._api_name].constants).append( + (self._names, name)) + +tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME) +estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME) |