aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2018-07-10 14:33:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 14:37:04 -0700
commit17778e691b8e1cbfcab9ba211e56db1e6d7a9a0c (patch)
tree6b666bee930bb75cc46af2ce2fffe00344f4c615 /tensorflow/python/util
parent0a805f8d9fdf2e16e0866586bdfb9a6151395a85 (diff)
Make sure correct docs are generated when using @deprecated_alias
decorator. Specifically, wrap `NewClass.__init__` method using tf_decorator.make_decorator so that doc generation can pick up correct arguments for `__init__` instead of *args, **kwargs. Also, skip _NewClass when generating API goldens. PiperOrigin-RevId: 204013970
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/deprecation.py38
-rw-r--r--tensorflow/python/util/deprecation_test.py6
2 files changed, 36 insertions, 8 deletions
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 376be39978..c8ed2b715d 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -87,6 +87,27 @@ def _call_location(outer=False):
return '%s:%d' % (entry[1], entry[2])
+def _wrap_decorator(wrapped_function):
+ """Indicate that one function wraps another.
+
+ This decorator wraps a function using `tf_decorator.make_decorator`
+ so that doc generation scripts can pick up original function
+ signature.
+ It would be better to use @functools.wrap decorator, but it would
+ not update function signature to match wrapped function in Python 2.
+
+ Args:
+ wrapped_function: The function that decorated function wraps.
+
+ Returns:
+ Function that accepts wrapper function as an argument and returns
+ `TFDecorator` instance.
+ """
+ def wrapper(wrapper_func):
+ return tf_decorator.make_decorator(wrapped_function, wrapper_func)
+ return wrapper
+
+
def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
"""Deprecate a symbol in favor of a new name with identical semantics.
@@ -144,7 +165,7 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
if tf_inspect.isclass(func_or_class):
# Make a new class with __init__ wrapped in a warning.
- class NewClass(func_or_class): # pylint: disable=missing-docstring
+ class _NewClass(func_or_class): # pylint: disable=missing-docstring
__doc__ = decorator_utils.add_notice_to_docstring(
func_or_class.__doc__, 'Please use %s instead.' % name,
'DEPRECATED CLASS',
@@ -153,27 +174,28 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
__name__ = func_or_class.__name__
__module__ = _call_location(outer=True)
+ @_wrap_decorator(func_or_class.__init__)
def __init__(self, *args, **kwargs):
- if hasattr(NewClass.__init__, '__func__'):
+ if hasattr(_NewClass.__init__, '__func__'):
# Python 2
- NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
+ _NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
else:
# Python 3
- NewClass.__init__.__doc__ = func_or_class.__init__.__doc__
+ _NewClass.__init__.__doc__ = func_or_class.__init__.__doc__
if _PRINT_DEPRECATION_WARNINGS:
# We're making the alias as we speak. The original may have other
# aliases, so we cannot use it to check for whether it's already been
# warned about.
- if NewClass.__init__ not in _PRINTED_WARNING:
+ if _NewClass.__init__ not in _PRINTED_WARNING:
if warn_once:
- _PRINTED_WARNING[NewClass.__init__] = True
+ _PRINTED_WARNING[_NewClass.__init__] = True
logging.warning(
'From %s: The name %s is deprecated. Please use %s instead.\n',
_call_location(), deprecated_name, name)
- super(NewClass, self).__init__(*args, **kwargs)
+ super(_NewClass, self).__init__(*args, **kwargs)
- return NewClass
+ return _NewClass
else:
decorator_utils.validate_callable(func_or_class, 'deprecated')
diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py
index bdd0bc48d2..1ea695e4d6 100644
--- a/tensorflow/python/util/deprecation_test.py
+++ b/tensorflow/python/util/deprecation_test.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
+from tensorflow.python.util import tf_inspect
class DeprecatedAliasTest(test.TestCase):
@@ -73,6 +74,11 @@ class DeprecatedAliasTest(test.TestCase):
self.assertEqual(["test", "deprecated", "deprecated again"],
MyClass.init_args)
+ # Check __init__ signature matches for doc generation.
+ self.assertEqual(
+ tf_inspect.getfullargspec(MyClass.__init__),
+ tf_inspect.getfullargspec(deprecated_cls.__init__))
+
class DeprecationTest(test.TestCase):