diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 16:36:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 16:36:45 -0700 |
commit | 05973093a4716f861db2490dab2bcb8b9a6ee557 (patch) | |
tree | 61f38a2e01908bd5cf2351071ad846706a642bde /tensorflow/python/util | |
parent | 6663959a8a2dd93a4dab9b049767d64761a00adc (diff) | |
parent | efe17306442aa91192df953ae537d3f9b824dae6 (diff) |
Merge pull request #22517 from IMBurbank:master
PiperOrigin-RevId: 215480021
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/tf_inspect.py | 93 | ||||
-rw-r--r-- | tensorflow/python/util/tf_inspect_test.py | 199 |
2 files changed, 247 insertions, 45 deletions
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index 967c872c2a..444e44eaf1 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -36,6 +36,55 @@ else: 'annotations' ]) +if hasattr(_inspect, 'getfullargspec'): + _getfullargspec = _inspect.getfullargspec # pylint: disable=invalid-name + + def _getargspec(target): + """A python3 version of getargspec. + + Calls `getfullargspec` and assigns args, varargs, + varkw, and defaults to a python 2/3 compatible `ArgSpec`. + + The parameter name 'varkw' is changed to 'keywords' to fit the + `ArgSpec` struct. + + Args: + target: the target object to inspect. + + Returns: + An ArgSpec with args, varargs, keywords, and defaults parameters + from FullArgSpec. + """ + fullargspecs = getfullargspec(target) + argspecs = ArgSpec( + args=fullargspecs.args, + varargs=fullargspecs.varargs, + keywords=fullargspecs.varkw, + defaults=fullargspecs.defaults) + return argspecs +else: + _getargspec = _inspect.getargspec + + def _getfullargspec(target): + """A python2 version of getfullargspec. + + Args: + target: the target object to inspect. + + Returns: + A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations. + """ + argspecs = getargspec(target) + fullargspecs = FullArgSpec( + args=argspecs.args, + varargs=argspecs.varargs, + varkw=argspecs.keywords, + defaults=argspecs.defaults, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + return fullargspecs + def currentframe(): """TFDecorator-aware replacement for inspect.currentframe.""" @@ -43,16 +92,18 @@ def currentframe(): def getargspec(obj): - """TFDecorator-aware replacement for inspect.getargspec. + """TFDecorator-aware replacement for `inspect.getargspec`. + + Note: `getfullargspec` is recommended as the python 2/3 compatible + replacement for this function. Args: - obj: A function, partial function, or callable object, possibly - decorated. + obj: A function, partial function, or callable object, possibly decorated. Returns: The `ArgSpec` that describes the signature of the outermost decorator that - changes the callable's signature. If the callable is not decorated, - `inspect.getargspec()` will be called directly on the object. + changes the callable's signature, or the `ArgSpec` that describes + the object if not decorated. Raises: ValueError: When callable's signature can not be expressed with @@ -72,24 +123,24 @@ def getargspec(obj): try: # Python3 will handle most callables here (not partial). - return _inspect.getargspec(target) + return _getargspec(target) except TypeError: pass if isinstance(target, type): try: - return _inspect.getargspec(target.__init__) + return _getargspec(target.__init__) except TypeError: pass try: - return _inspect.getargspec(target.__new__) + return _getargspec(target.__new__) except TypeError: pass # The `type(target)` ensures that if a class is received we don't return # the signature of it's __call__ method. - return _inspect.getargspec(type(target).__call__) + return _getargspec(type(target).__call__) def _get_argspec_for_partial(obj): @@ -172,30 +223,6 @@ def _get_argspec_for_partial(obj): return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:])) -if hasattr(_inspect, 'getfullargspec'): - _getfullargspec = _inspect.getfullargspec -else: - - def _getfullargspec(target): - """A python2 version of getfullargspec. - - Args: - target: the target object to inspect. - Returns: - A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations. - """ - argspecs = getargspec(target) - fullargspecs = FullArgSpec( - args=argspecs.args, - varargs=argspecs.varargs, - varkw=argspecs.keywords, - defaults=argspecs.defaults, - kwonlyargs=[], - kwonlydefaults=None, - annotations={}) - return fullargspecs - - def getfullargspec(obj): """TFDecorator-aware replacement for `inspect.getfullargspec`. diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py index d3b7e4b969..02d075cdff 100644 --- a/tensorflow/python/util/tf_inspect_test.py +++ b/tensorflow/python/util/tf_inspect_test.py @@ -122,18 +122,6 @@ class TfInspectTest(test.TestCase): self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) - def testGetFullArgsSpecForPartial(self): - - def func(a, b): - del a, b - - partial_function = functools.partial(func, 1) - argspec = tf_inspect.FullArgSpec( - args=['b'], varargs=None, varkw=None, defaults=None, - kwonlyargs=[], kwonlydefaults=None, annotations={}) - - self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function)) - def testGetArgSpecOnPartialInvalidArgspec(self): """Tests getargspec on partial function that doesn't have valid argspec.""" @@ -303,6 +291,193 @@ class TfInspectTest(test.TestCase): self.assertEqual(argspec, tf_inspect.getargspec(NewClass)) + def testGetFullArgSpecOnDecoratorsThatDontProvideFullArgSpec(self): + argspec = tf_inspect.getfullargspec(test_decorated_function_with_defaults) + self.assertEqual(['a', 'b', 'c'], argspec.args) + self.assertEqual((2, 'Hello'), argspec.defaults) + + def testGetFullArgSpecOnDecoratorThatChangesFullArgSpec(self): + argspec = tf_inspect.FullArgSpec( + args=['a', 'b', 'c'], + varargs=None, + varkw=None, + defaults=(1, 'hello'), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + decorator = tf_decorator.TFDecorator('', test_undecorated_function, '', + argspec) + self.assertEqual(argspec, tf_inspect.getfullargspec(decorator)) + + def testGetFullArgSpecIgnoresDecoratorsThatDontProvideFullArgSpec(self): + argspec = tf_inspect.FullArgSpec( + args=['a', 'b', 'c'], + varargs=None, + varkw=None, + defaults=(1, 'hello'), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function, + '', argspec) + outer_decorator = tf_decorator.TFDecorator('', inner_decorator) + self.assertEqual(argspec, tf_inspect.getfullargspec(outer_decorator)) + + def testGetFullArgSpecReturnsOutermostDecoratorThatChangesFullArgSpec(self): + outer_argspec = tf_inspect.FullArgSpec( + args=['a'], + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + inner_argspec = tf_inspect.FullArgSpec( + args=['b'], + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function, + '', inner_argspec) + outer_decorator = tf_decorator.TFDecorator('', inner_decorator, '', + outer_argspec) + self.assertEqual(outer_argspec, tf_inspect.getfullargspec(outer_decorator)) + + def testGetFullArgsSpecForPartial(self): + + def func(a, b): + del a, b + + partial_function = functools.partial(func, 1) + argspec = tf_inspect.FullArgSpec( + args=['b'], + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function)) + + def testGetFullArgSpecOnPartialNoArgumentsLeft(self): + """Tests getfullargspec on partial function that prunes all arguments.""" + + def func(m, n): + return 2 * m + n + + partial_func = functools.partial(func, 7, 10) + argspec = tf_inspect.FullArgSpec( + args=[], + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func)) + + def testGetFullArgSpecOnPartialWithVarargs(self): + """Tests getfullargspec on partial function with variable arguments.""" + + def func(m, *arg): + return m + len(arg) + + partial_func = functools.partial(func, 7, 8) + argspec = tf_inspect.FullArgSpec( + args=[], + varargs='arg', + varkw=None, + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func)) + + def testGetFullArgSpecOnPartialWithVarkwargs(self): + """Tests getfullargspec. + + Tests on partial function with variable keyword arguments. + """ + + def func(m, n, **kwarg): + return m * n + len(kwarg) + + partial_func = functools.partial(func, 7) + argspec = tf_inspect.FullArgSpec( + args=['n'], + varargs=None, + varkw='kwarg', + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func)) + + def testGetFullArgSpecOnCallableObject(self): + + class Callable(object): + + def __call__(self, a, b=1, c='hello'): + pass + + argspec = tf_inspect.FullArgSpec( + args=['self', 'a', 'b', 'c'], + varargs=None, + varkw=None, + defaults=(1, 'hello'), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + test_obj = Callable() + self.assertEqual(argspec, tf_inspect.getfullargspec(test_obj)) + + def testGetFullArgSpecOnInitClass(self): + + class InitClass(object): + + def __init__(self, a, b=1, c='hello'): + pass + + argspec = tf_inspect.FullArgSpec( + args=['self', 'a', 'b', 'c'], + varargs=None, + varkw=None, + defaults=(1, 'hello'), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + self.assertEqual(argspec, tf_inspect.getfullargspec(InitClass)) + + def testGetFullArgSpecOnNewClass(self): + + class NewClass(object): + + def __new__(cls, a, b=1, c='hello'): + pass + + argspec = tf_inspect.FullArgSpec( + args=['cls', 'a', 'b', 'c'], + varargs=None, + varkw=None, + defaults=(1, 'hello'), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + self.assertEqual(argspec, tf_inspect.getfullargspec(NewClass)) + def testGetDoc(self): self.assertEqual('Test Decorated Function With Defaults Docstring.', tf_inspect.getdoc(test_decorated_function_with_defaults)) |