aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util/tf_export.py
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-06-07 12:05:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 12:14:50 -0700
commit501cf726cbee2ee13efef43884a6552ca211979d (patch)
tree2a93bae901b9f9d32f5d622e2e4d626668b48b99 /tensorflow/python/util/tf_export.py
parent4d0d60a82c52c6c71650db33bf826f03559d91fc (diff)
Internal Change.
PiperOrigin-RevId: 199673803
Diffstat (limited to 'tensorflow/python/util/tf_export.py')
-rw-r--r--tensorflow/python/util/tf_export.py58
1 files changed, 38 insertions, 20 deletions
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index bf3961c692..e154ffb68a 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -41,17 +41,35 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import functools
import sys
from tensorflow.python.util import tf_decorator
+ESTIMATOR_API_NAME = 'estimator'
+TENSORFLOW_API_NAME = 'tensorflow'
+
+_Attributes = collections.namedtuple(
+ 'ExportedApiAttributes', ['names', 'constants'])
+
+# Attribute values must be unique to each API.
+API_ATTRS = {
+ TENSORFLOW_API_NAME: _Attributes(
+ '_tf_api_names',
+ '_tf_api_constants'),
+ ESTIMATOR_API_NAME: _Attributes(
+ '_estimator_api_names',
+ '_estimator_api_constants')
+}
+
class SymbolAlreadyExposedError(Exception):
"""Raised when adding API names to symbol that already has API names."""
pass
-class tf_export(object): # pylint: disable=invalid-name
+class api_export(object): # pylint: disable=invalid-name
"""Provides ways to export symbols to the TensorFlow API."""
def __init__(self, *args, **kwargs):
@@ -63,15 +81,12 @@ class tf_export(object): # pylint: disable=invalid-name
overrides: List of symbols that this is overriding
(those overrided api exports will be removed). Note: passing overrides
has no effect on exporting a constant.
- allow_multiple_exports: Allows exporting the same symbol multiple
- times with multiple `tf_export` usages. Prefer however, to list all
- of the exported names in a single `tf_export` usage when possible.
-
+ api_name: Name of the API you want to generate (e.g. `tensorflow` or
+ `estimator`). Default is `tensorflow`.
"""
self._names = args
+ self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
- self._allow_multiple_exports = kwargs.get(
- 'allow_multiple_exports', False)
def __call__(self, func):
"""Calls this decorator.
@@ -86,25 +101,24 @@ class tf_export(object): # pylint: disable=invalid-name
SymbolAlreadyExposedError: Raised when a symbol already has API names
and kwarg `allow_multiple_exports` not set.
"""
+ api_names_attr = API_ATTRS[self._api_name].names
+
# Undecorate overridden names
for f in self._overrides:
_, undecorated_f = tf_decorator.unwrap(f)
- del undecorated_f._tf_api_names # pylint: disable=protected-access
+ delattr(undecorated_f, api_names_attr)
_, undecorated_func = tf_decorator.unwrap(func)
# Check for an existing api. We check if attribute name is in
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
- if '_tf_api_names' in undecorated_func.__dict__:
- if self._allow_multiple_exports:
- undecorated_func._tf_api_names += self._names # pylint: disable=protected-access
- else:
- raise SymbolAlreadyExposedError(
- 'Symbol %s is already exposed as %s.' %
- (undecorated_func.__name__, undecorated_func._tf_api_names)) # pylint: disable=protected-access
- else:
- undecorated_func._tf_api_names = self._names # pylint: disable=protected-access
+ if api_names_attr in undecorated_func.__dict__:
+ raise SymbolAlreadyExposedError(
+ 'Symbol %s is already exposed as %s.' %
+ (undecorated_func.__name__, getattr(
+ undecorated_func, api_names_attr))) # pylint: disable=protected-access
+ setattr(undecorated_func, api_names_attr, self._names)
return func
def export_constant(self, module_name, name):
@@ -126,8 +140,12 @@ class tf_export(object): # pylint: disable=invalid-name
name: (string) Current constant name.
"""
module = sys.modules[module_name]
- if not hasattr(module, '_tf_api_constants'):
- module._tf_api_constants = [] # pylint: disable=protected-access
+ if not hasattr(module, API_ATTRS[self._api_name].constants):
+ setattr(module, API_ATTRS[self._api_name].constants, [])
# pylint: disable=protected-access
- module._tf_api_constants.append((self._names, name))
+ getattr(module, API_ATTRS[self._api_name].constants).append(
+ (self._names, name))
+
+tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
+estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME)