aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-08-23 12:12:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 12:15:32 -0700
commit228d86a075e215cf308d97b8a12bee83ba09328d (patch)
tree803b6f1b8ef3d2050f65374ad862bb9173f83baf /tensorflow/python/util
parentca80e798a8d6afa998b24941ad261109bd23a081 (diff)
Internal Change.
PiperOrigin-RevId: 209977319
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/tf_export.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index 2be4dbb283..a5ac430ce7 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -136,11 +136,14 @@ class api_export(object): # pylint: disable=invalid-name
has no effect on exporting a constant.
api_name: Name of the API you want to generate (e.g. `tensorflow` or
`estimator`). Default is `tensorflow`.
+ allow_multiple_exports: Allow symbol to be exported multiple time under
+ different names.
"""
self._names = args
self._names_v1 = kwargs.get('v1', args)
self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
+ self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)
def __call__(self, func):
"""Calls this decorator.
@@ -173,9 +176,10 @@ class api_export(object): # pylint: disable=invalid-name
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
if api_names_attr in func.__dict__:
- raise SymbolAlreadyExposedError(
- 'Symbol %s is already exposed as %s.' %
- (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access
+ if not self._allow_multiple_exports:
+ raise SymbolAlreadyExposedError(
+ 'Symbol %s is already exposed as %s.' %
+ (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access
setattr(func, api_names_attr, names)
def export_constant(self, module_name, name):
@@ -213,4 +217,5 @@ class api_export(object): # pylint: disable=invalid-name
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
-estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME)
+estimator_export = functools.partial(
+ api_export, api_name=ESTIMATOR_API_NAME, allow_multiple_exports=True)