diff options
author | Michael Case <mikecase@google.com> | 2018-08-23 12:12:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-23 12:15:32 -0700 |
commit | 228d86a075e215cf308d97b8a12bee83ba09328d (patch) | |
tree | 803b6f1b8ef3d2050f65374ad862bb9173f83baf /tensorflow/python/util | |
parent | ca80e798a8d6afa998b24941ad261109bd23a081 (diff) |
Internal Change.
PiperOrigin-RevId: 209977319
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/tf_export.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index 2be4dbb283..a5ac430ce7 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -136,11 +136,14 @@ class api_export(object): # pylint: disable=invalid-name has no effect on exporting a constant. api_name: Name of the API you want to generate (e.g. `tensorflow` or `estimator`). Default is `tensorflow`. + allow_multiple_exports: Allow symbol to be exported multiple time under + different names. """ self._names = args self._names_v1 = kwargs.get('v1', args) self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME) self._overrides = kwargs.get('overrides', []) + self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False) def __call__(self, func): """Calls this decorator. @@ -173,9 +176,10 @@ class api_export(object): # pylint: disable=invalid-name # __dict__ instead of using hasattr to verify that subclasses have # their own _tf_api_names as opposed to just inheriting it. if api_names_attr in func.__dict__: - raise SymbolAlreadyExposedError( - 'Symbol %s is already exposed as %s.' % - (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access + if not self._allow_multiple_exports: + raise SymbolAlreadyExposedError( + 'Symbol %s is already exposed as %s.' % + (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access setattr(func, api_names_attr, names) def export_constant(self, module_name, name): @@ -213,4 +217,5 @@ class api_export(object): # pylint: disable=invalid-name tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME) -estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME) +estimator_export = functools.partial( + api_export, api_name=ESTIMATOR_API_NAME, allow_multiple_exports=True) |