aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 16:36:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 16:36:45 -0700
commit05973093a4716f861db2490dab2bcb8b9a6ee557 (patch)
tree61f38a2e01908bd5cf2351071ad846706a642bde /tensorflow/python/util
parent6663959a8a2dd93a4dab9b049767d64761a00adc (diff)
parentefe17306442aa91192df953ae537d3f9b824dae6 (diff)
Merge pull request #22517 from IMBurbank:master
PiperOrigin-RevId: 215480021
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/tf_inspect.py93
-rw-r--r--tensorflow/python/util/tf_inspect_test.py199
2 files changed, 247 insertions, 45 deletions
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 967c872c2a..444e44eaf1 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -36,6 +36,55 @@ else:
'annotations'
])
+if hasattr(_inspect, 'getfullargspec'):
+ _getfullargspec = _inspect.getfullargspec # pylint: disable=invalid-name
+
+ def _getargspec(target):
+ """A python3 version of getargspec.
+
+ Calls `getfullargspec` and assigns args, varargs,
+ varkw, and defaults to a python 2/3 compatible `ArgSpec`.
+
+ The parameter name 'varkw' is changed to 'keywords' to fit the
+ `ArgSpec` struct.
+
+ Args:
+ target: the target object to inspect.
+
+ Returns:
+ An ArgSpec with args, varargs, keywords, and defaults parameters
+ from FullArgSpec.
+ """
+ fullargspecs = getfullargspec(target)
+ argspecs = ArgSpec(
+ args=fullargspecs.args,
+ varargs=fullargspecs.varargs,
+ keywords=fullargspecs.varkw,
+ defaults=fullargspecs.defaults)
+ return argspecs
+else:
+ _getargspec = _inspect.getargspec
+
+ def _getfullargspec(target):
+ """A python2 version of getfullargspec.
+
+ Args:
+ target: the target object to inspect.
+
+ Returns:
+ A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
+ """
+ argspecs = getargspec(target)
+ fullargspecs = FullArgSpec(
+ args=argspecs.args,
+ varargs=argspecs.varargs,
+ varkw=argspecs.keywords,
+ defaults=argspecs.defaults,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+ return fullargspecs
+
def currentframe():
"""TFDecorator-aware replacement for inspect.currentframe."""
@@ -43,16 +92,18 @@ def currentframe():
def getargspec(obj):
- """TFDecorator-aware replacement for inspect.getargspec.
+ """TFDecorator-aware replacement for `inspect.getargspec`.
+
+ Note: `getfullargspec` is recommended as the python 2/3 compatible
+ replacement for this function.
Args:
- obj: A function, partial function, or callable object, possibly
- decorated.
+ obj: A function, partial function, or callable object, possibly decorated.
Returns:
The `ArgSpec` that describes the signature of the outermost decorator that
- changes the callable's signature. If the callable is not decorated,
- `inspect.getargspec()` will be called directly on the object.
+ changes the callable's signature, or the `ArgSpec` that describes
+ the object if not decorated.
Raises:
ValueError: When callable's signature can not be expressed with
@@ -72,24 +123,24 @@ def getargspec(obj):
try:
# Python3 will handle most callables here (not partial).
- return _inspect.getargspec(target)
+ return _getargspec(target)
except TypeError:
pass
if isinstance(target, type):
try:
- return _inspect.getargspec(target.__init__)
+ return _getargspec(target.__init__)
except TypeError:
pass
try:
- return _inspect.getargspec(target.__new__)
+ return _getargspec(target.__new__)
except TypeError:
pass
# The `type(target)` ensures that if a class is received we don't return
# the signature of it's __call__ method.
- return _inspect.getargspec(type(target).__call__)
+ return _getargspec(type(target).__call__)
def _get_argspec_for_partial(obj):
@@ -172,30 +223,6 @@ def _get_argspec_for_partial(obj):
return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
-if hasattr(_inspect, 'getfullargspec'):
- _getfullargspec = _inspect.getfullargspec
-else:
-
- def _getfullargspec(target):
- """A python2 version of getfullargspec.
-
- Args:
- target: the target object to inspect.
- Returns:
- A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
- """
- argspecs = getargspec(target)
- fullargspecs = FullArgSpec(
- args=argspecs.args,
- varargs=argspecs.varargs,
- varkw=argspecs.keywords,
- defaults=argspecs.defaults,
- kwonlyargs=[],
- kwonlydefaults=None,
- annotations={})
- return fullargspecs
-
-
def getfullargspec(obj):
"""TFDecorator-aware replacement for `inspect.getfullargspec`.
diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py
index d3b7e4b969..02d075cdff 100644
--- a/tensorflow/python/util/tf_inspect_test.py
+++ b/tensorflow/python/util/tf_inspect_test.py
@@ -122,18 +122,6 @@ class TfInspectTest(test.TestCase):
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
- def testGetFullArgsSpecForPartial(self):
-
- def func(a, b):
- del a, b
-
- partial_function = functools.partial(func, 1)
- argspec = tf_inspect.FullArgSpec(
- args=['b'], varargs=None, varkw=None, defaults=None,
- kwonlyargs=[], kwonlydefaults=None, annotations={})
-
- self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function))
-
def testGetArgSpecOnPartialInvalidArgspec(self):
"""Tests getargspec on partial function that doesn't have valid argspec."""
@@ -303,6 +291,193 @@ class TfInspectTest(test.TestCase):
self.assertEqual(argspec, tf_inspect.getargspec(NewClass))
+ def testGetFullArgSpecOnDecoratorsThatDontProvideFullArgSpec(self):
+ argspec = tf_inspect.getfullargspec(test_decorated_function_with_defaults)
+ self.assertEqual(['a', 'b', 'c'], argspec.args)
+ self.assertEqual((2, 'Hello'), argspec.defaults)
+
+ def testGetFullArgSpecOnDecoratorThatChangesFullArgSpec(self):
+ argspec = tf_inspect.FullArgSpec(
+ args=['a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
+ argspec)
+ self.assertEqual(argspec, tf_inspect.getfullargspec(decorator))
+
+ def testGetFullArgSpecIgnoresDecoratorsThatDontProvideFullArgSpec(self):
+ argspec = tf_inspect.FullArgSpec(
+ args=['a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
+ '', argspec)
+ outer_decorator = tf_decorator.TFDecorator('', inner_decorator)
+ self.assertEqual(argspec, tf_inspect.getfullargspec(outer_decorator))
+
+ def testGetFullArgSpecReturnsOutermostDecoratorThatChangesFullArgSpec(self):
+ outer_argspec = tf_inspect.FullArgSpec(
+ args=['a'],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+ inner_argspec = tf_inspect.FullArgSpec(
+ args=['b'],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
+ '', inner_argspec)
+ outer_decorator = tf_decorator.TFDecorator('', inner_decorator, '',
+ outer_argspec)
+ self.assertEqual(outer_argspec, tf_inspect.getfullargspec(outer_decorator))
+
+ def testGetFullArgsSpecForPartial(self):
+
+ def func(a, b):
+ del a, b
+
+ partial_function = functools.partial(func, 1)
+ argspec = tf_inspect.FullArgSpec(
+ args=['b'],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function))
+
+ def testGetFullArgSpecOnPartialNoArgumentsLeft(self):
+ """Tests getfullargspec on partial function that prunes all arguments."""
+
+ def func(m, n):
+ return 2 * m + n
+
+ partial_func = functools.partial(func, 7, 10)
+ argspec = tf_inspect.FullArgSpec(
+ args=[],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
+
+ def testGetFullArgSpecOnPartialWithVarargs(self):
+ """Tests getfullargspec on partial function with variable arguments."""
+
+ def func(m, *arg):
+ return m + len(arg)
+
+ partial_func = functools.partial(func, 7, 8)
+ argspec = tf_inspect.FullArgSpec(
+ args=[],
+ varargs='arg',
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
+
+ def testGetFullArgSpecOnPartialWithVarkwargs(self):
+ """Tests getfullargspec.
+
+ Tests on partial function with variable keyword arguments.
+ """
+
+ def func(m, n, **kwarg):
+ return m * n + len(kwarg)
+
+ partial_func = functools.partial(func, 7)
+ argspec = tf_inspect.FullArgSpec(
+ args=['n'],
+ varargs=None,
+ varkw='kwarg',
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
+
+ def testGetFullArgSpecOnCallableObject(self):
+
+ class Callable(object):
+
+ def __call__(self, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.FullArgSpec(
+ args=['self', 'a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ test_obj = Callable()
+ self.assertEqual(argspec, tf_inspect.getfullargspec(test_obj))
+
+ def testGetFullArgSpecOnInitClass(self):
+
+ class InitClass(object):
+
+ def __init__(self, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.FullArgSpec(
+ args=['self', 'a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(InitClass))
+
+ def testGetFullArgSpecOnNewClass(self):
+
+ class NewClass(object):
+
+ def __new__(cls, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.FullArgSpec(
+ args=['cls', 'a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(NewClass))
+
def testGetDoc(self):
self.assertEqual('Test Decorated Function With Defaults Docstring.',
tf_inspect.getdoc(test_decorated_function_with_defaults))