diff options
author | Anna R <annarev@google.com> | 2018-07-10 14:33:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-10 14:37:04 -0700 |
commit | 17778e691b8e1cbfcab9ba211e56db1e6d7a9a0c (patch) | |
tree | 6b666bee930bb75cc46af2ce2fffe00344f4c615 /tensorflow/python/util | |
parent | 0a805f8d9fdf2e16e0866586bdfb9a6151395a85 (diff) |
Make sure correct docs are generated when using @deprecated_alias
decorator. Specifically, wrap `NewClass.__init__` method using tf_decorator.make_decorator
so that doc generation can pick up correct arguments for `__init__` instead
of *args, **kwargs.
Also, skip _NewClass when generating API goldens.
PiperOrigin-RevId: 204013970
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/deprecation.py | 38 | ||||
-rw-r--r-- | tensorflow/python/util/deprecation_test.py | 6 |
2 files changed, 36 insertions, 8 deletions
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index 376be39978..c8ed2b715d 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -87,6 +87,27 @@ def _call_location(outer=False): return '%s:%d' % (entry[1], entry[2]) +def _wrap_decorator(wrapped_function): + """Indicate that one function wraps another. + + This decorator wraps a function using `tf_decorator.make_decorator` + so that doc generation scripts can pick up original function + signature. + It would be better to use @functools.wrap decorator, but it would + not update function signature to match wrapped function in Python 2. + + Args: + wrapped_function: The function that decorated function wraps. + + Returns: + Function that accepts wrapper function as an argument and returns + `TFDecorator` instance. + """ + def wrapper(wrapper_func): + return tf_decorator.make_decorator(wrapped_function, wrapper_func) + return wrapper + + def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True): """Deprecate a symbol in favor of a new name with identical semantics. @@ -144,7 +165,7 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True): if tf_inspect.isclass(func_or_class): # Make a new class with __init__ wrapped in a warning. - class NewClass(func_or_class): # pylint: disable=missing-docstring + class _NewClass(func_or_class): # pylint: disable=missing-docstring __doc__ = decorator_utils.add_notice_to_docstring( func_or_class.__doc__, 'Please use %s instead.' % name, 'DEPRECATED CLASS', @@ -153,27 +174,28 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True): __name__ = func_or_class.__name__ __module__ = _call_location(outer=True) + @_wrap_decorator(func_or_class.__init__) def __init__(self, *args, **kwargs): - if hasattr(NewClass.__init__, '__func__'): + if hasattr(_NewClass.__init__, '__func__'): # Python 2 - NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__ + _NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__ else: # Python 3 - NewClass.__init__.__doc__ = func_or_class.__init__.__doc__ + _NewClass.__init__.__doc__ = func_or_class.__init__.__doc__ if _PRINT_DEPRECATION_WARNINGS: # We're making the alias as we speak. The original may have other # aliases, so we cannot use it to check for whether it's already been # warned about. - if NewClass.__init__ not in _PRINTED_WARNING: + if _NewClass.__init__ not in _PRINTED_WARNING: if warn_once: - _PRINTED_WARNING[NewClass.__init__] = True + _PRINTED_WARNING[_NewClass.__init__] = True logging.warning( 'From %s: The name %s is deprecated. Please use %s instead.\n', _call_location(), deprecated_name, name) - super(NewClass, self).__init__(*args, **kwargs) + super(_NewClass, self).__init__(*args, **kwargs) - return NewClass + return _NewClass else: decorator_utils.validate_callable(func_or_class, 'deprecated') diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index bdd0bc48d2..1ea695e4d6 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +from tensorflow.python.util import tf_inspect class DeprecatedAliasTest(test.TestCase): @@ -73,6 +74,11 @@ class DeprecatedAliasTest(test.TestCase): self.assertEqual(["test", "deprecated", "deprecated again"], MyClass.init_args) + # Check __init__ signature matches for doc generation. + self.assertEqual( + tf_inspect.getfullargspec(MyClass.__init__), + tf_inspect.getfullargspec(deprecated_cls.__init__)) + class DeprecationTest(test.TestCase): |