aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2018-07-16 18:11:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 18:14:45 -0700
commitd8f3425e5b054dff01b5ece80e8c8a101c4ed816 (patch)
tree5a6d888dc60afce0c93bbfd8a1c457bb687c71e7 /tensorflow/python/util
parent2c442d26f36a0f167685fd31b9ecdb4e290c2b29 (diff)
Handle deprecated fields in api_def.proto.
Also update how canonical endpoint name is set in doc_generator_visitor.py. PiperOrigin-RevId: 204841165
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/deprecation.py34
-rw-r--r--tensorflow/python/util/deprecation_test.py22
-rw-r--r--tensorflow/python/util/tf_export.py42
3 files changed, 98 insertions, 0 deletions
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index c8ed2b715d..9e2202eaf8 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -37,6 +37,11 @@ _PRINT_DEPRECATION_WARNINGS = True
_PRINTED_WARNING = {}
+class DeprecatedNamesAlreadySet(Exception):
+ """Raised when setting deprecated names multiple times for the same symbol."""
+ pass
+
+
def _add_deprecated_function_notice_to_docstring(doc, date, instructions):
"""Adds a deprecation notice to a docstring for deprecated functions."""
main_text = ['THIS FUNCTION IS DEPRECATED. It will be removed %s.' %
@@ -219,6 +224,35 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
func_or_class.__doc__, None, 'Please use %s instead.' % name))
+def deprecated_endpoints(*args):
+ """Decorator for marking endpoints deprecated.
+
+ This decorator does not print deprecation messages.
+ TODO(annarev): eventually start printing deprecation warnings when
+ @deprecation_endpoints decorator is added.
+
+ Args:
+ *args: Deprecated endpoint names.
+
+ Returns:
+ A function that takes symbol as an argument and adds
+ _tf_deprecated_api_names to that symbol.
+ _tf_deprecated_api_names would be set to a list of deprecated
+ endpoint names for the symbol.
+ """
+ def deprecated_wrapper(func):
+ # pylint: disable=protected-access
+ if '_tf_deprecated_api_names' in func.__dict__:
+ raise DeprecatedNamesAlreadySet(
+ 'Cannot set deprecated names for %s to %s. '
+ 'Deprecated names are already set to %s.' % (
+ func.__name__, str(args), str(func._tf_deprecated_api_names)))
+ func._tf_deprecated_api_names = args
+ # pylint: disable=protected-access
+ return func
+ return deprecated_wrapper
+
+
def deprecated(date, instructions, warn_once=True):
"""Decorator for marking functions or methods deprecated.
diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py
index 1ea695e4d6..90c73a0a58 100644
--- a/tensorflow/python/util/deprecation_test.py
+++ b/tensorflow/python/util/deprecation_test.py
@@ -935,5 +935,27 @@ class DeprecationArgumentsTest(test.TestCase):
self.assertEqual(new_docs, new_docs_ref)
+class DeprecatedEndpointsTest(test.TestCase):
+
+ def testSingleDeprecatedEndpoint(self):
+ @deprecation.deprecated_endpoints("foo1")
+ def foo():
+ pass
+ self.assertEqual(("foo1",), foo._tf_deprecated_api_names)
+
+ def testMultipleDeprecatedEndpoint(self):
+ @deprecation.deprecated_endpoints("foo1", "foo2")
+ def foo():
+ pass
+ self.assertEqual(("foo1", "foo2"), foo._tf_deprecated_api_names)
+
+ def testCannotSetDeprecatedEndpointsTwice(self):
+ with self.assertRaises(deprecation.DeprecatedNamesAlreadySet):
+ @deprecation.deprecated_endpoints("foo1")
+ @deprecation.deprecated_endpoints("foo2")
+ def foo(): # pylint: disable=unused-variable
+ pass
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index c362d588ab..274f32c21f 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -78,6 +78,48 @@ class SymbolAlreadyExposedError(Exception):
pass
+def get_canonical_name_for_symbol(symbol, api_name=TENSORFLOW_API_NAME):
+ """Get canonical name for the API symbol.
+
+ Canonical name is the first non-deprecated endpoint name.
+
+ Args:
+ symbol: API function or class.
+ api_name: API name (tensorflow or estimator).
+
+ Returns:
+ Canonical name for the API symbol (for e.g. initializers.zeros) if
+ canonical name could be determined. Otherwise, returns None.
+ """
+ if not hasattr(symbol, '__dict__'):
+ return None
+ api_names_attr = API_ATTRS[api_name].names
+ _, undecorated_symbol = tf_decorator.unwrap(symbol)
+ if api_names_attr not in undecorated_symbol.__dict__:
+ return None
+ api_names = getattr(undecorated_symbol, api_names_attr)
+ # TODO(annarev): may be add a separate deprecated attribute
+ # for estimator names.
+ deprecated_api_names = undecorated_symbol.__dict__.get(
+ '_tf_deprecated_api_names', [])
+ return get_canonical_name(api_names, deprecated_api_names)
+
+
+def get_canonical_name(api_names, deprecated_api_names):
+ """Get first non-deprecated endpoint name.
+
+ Args:
+ api_names: API names iterable.
+ deprecated_api_names: Deprecated API names iterable.
+ Returns:
+ Canonical name if there is at least one non-deprecated endpoint.
+ Otherwise returns None.
+ """
+ return next(
+ (name for name in api_names if name not in deprecated_api_names),
+ None)
+
+
class api_export(object): # pylint: disable=invalid-name
"""Provides ways to export symbols to the TensorFlow API."""