diff options
author | Anna R <annarev@google.com> | 2018-07-16 18:11:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-16 18:14:45 -0700 |
commit | d8f3425e5b054dff01b5ece80e8c8a101c4ed816 (patch) | |
tree | 5a6d888dc60afce0c93bbfd8a1c457bb687c71e7 /tensorflow/python/util | |
parent | 2c442d26f36a0f167685fd31b9ecdb4e290c2b29 (diff) |
Handle deprecated fields in api_def.proto.
Also update how canonical endpoint name is set in doc_generator_visitor.py.
PiperOrigin-RevId: 204841165
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/deprecation.py | 34 | ||||
-rw-r--r-- | tensorflow/python/util/deprecation_test.py | 22 | ||||
-rw-r--r-- | tensorflow/python/util/tf_export.py | 42 |
3 files changed, 98 insertions, 0 deletions
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index c8ed2b715d..9e2202eaf8 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -37,6 +37,11 @@ _PRINT_DEPRECATION_WARNINGS = True _PRINTED_WARNING = {} +class DeprecatedNamesAlreadySet(Exception): + """Raised when setting deprecated names multiple times for the same symbol.""" + pass + + def _add_deprecated_function_notice_to_docstring(doc, date, instructions): """Adds a deprecation notice to a docstring for deprecated functions.""" main_text = ['THIS FUNCTION IS DEPRECATED. It will be removed %s.' % @@ -219,6 +224,35 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True): func_or_class.__doc__, None, 'Please use %s instead.' % name)) +def deprecated_endpoints(*args): + """Decorator for marking endpoints deprecated. + + This decorator does not print deprecation messages. + TODO(annarev): eventually start printing deprecation warnings when + @deprecation_endpoints decorator is added. + + Args: + *args: Deprecated endpoint names. + + Returns: + A function that takes symbol as an argument and adds + _tf_deprecated_api_names to that symbol. + _tf_deprecated_api_names would be set to a list of deprecated + endpoint names for the symbol. + """ + def deprecated_wrapper(func): + # pylint: disable=protected-access + if '_tf_deprecated_api_names' in func.__dict__: + raise DeprecatedNamesAlreadySet( + 'Cannot set deprecated names for %s to %s. ' + 'Deprecated names are already set to %s.' % ( + func.__name__, str(args), str(func._tf_deprecated_api_names))) + func._tf_deprecated_api_names = args + # pylint: disable=protected-access + return func + return deprecated_wrapper + + def deprecated(date, instructions, warn_once=True): """Decorator for marking functions or methods deprecated. diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index 1ea695e4d6..90c73a0a58 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -935,5 +935,27 @@ class DeprecationArgumentsTest(test.TestCase): self.assertEqual(new_docs, new_docs_ref) +class DeprecatedEndpointsTest(test.TestCase): + + def testSingleDeprecatedEndpoint(self): + @deprecation.deprecated_endpoints("foo1") + def foo(): + pass + self.assertEqual(("foo1",), foo._tf_deprecated_api_names) + + def testMultipleDeprecatedEndpoint(self): + @deprecation.deprecated_endpoints("foo1", "foo2") + def foo(): + pass + self.assertEqual(("foo1", "foo2"), foo._tf_deprecated_api_names) + + def testCannotSetDeprecatedEndpointsTwice(self): + with self.assertRaises(deprecation.DeprecatedNamesAlreadySet): + @deprecation.deprecated_endpoints("foo1") + @deprecation.deprecated_endpoints("foo2") + def foo(): # pylint: disable=unused-variable + pass + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index c362d588ab..274f32c21f 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -78,6 +78,48 @@ class SymbolAlreadyExposedError(Exception): pass +def get_canonical_name_for_symbol(symbol, api_name=TENSORFLOW_API_NAME): + """Get canonical name for the API symbol. + + Canonical name is the first non-deprecated endpoint name. + + Args: + symbol: API function or class. + api_name: API name (tensorflow or estimator). + + Returns: + Canonical name for the API symbol (for e.g. initializers.zeros) if + canonical name could be determined. Otherwise, returns None. + """ + if not hasattr(symbol, '__dict__'): + return None + api_names_attr = API_ATTRS[api_name].names + _, undecorated_symbol = tf_decorator.unwrap(symbol) + if api_names_attr not in undecorated_symbol.__dict__: + return None + api_names = getattr(undecorated_symbol, api_names_attr) + # TODO(annarev): may be add a separate deprecated attribute + # for estimator names. + deprecated_api_names = undecorated_symbol.__dict__.get( + '_tf_deprecated_api_names', []) + return get_canonical_name(api_names, deprecated_api_names) + + +def get_canonical_name(api_names, deprecated_api_names): + """Get first non-deprecated endpoint name. + + Args: + api_names: API names iterable. + deprecated_api_names: Deprecated API names iterable. + Returns: + Canonical name if there is at least one non-deprecated endpoint. + Otherwise returns None. + """ + return next( + (name for name in api_names if name not in deprecated_api_names), + None) + + class api_export(object): # pylint: disable=invalid-name """Provides ways to export symbols to the TensorFlow API.""" |