aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/util.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-26 19:06:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-26 19:10:35 -0700
commitb113d082ac6320adaaa0205cd77ab815ff40bc16 (patch)
treebe81843afe470219bd9cb5f717f77342acda5f7a /tensorflow/python/estimator/util.py
parentabebb5f3fa6799e4fc1f2de1156a7c968c8473b8 (diff)
Exclude 'self' from function arguments returned by util.fn_args for callables
and bounded methods. PiperOrigin-RevId: 173622989
Diffstat (limited to 'tensorflow/python/estimator/util.py')
-rw-r--r--tensorflow/python/estimator/util.py39
1 files changed, 20 insertions, 19 deletions
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index de35e66bdf..12f2592d84 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -19,10 +19,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
+def _is_bounded_method(fn):
+ return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
+
+
+def _is_callable_object(obj):
+ return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)
+
+
def fn_args(fn):
"""Get argument names for function-like object.
@@ -36,22 +46,13 @@ def fn_args(fn):
ValueError: if partial function has positionally bound arguments
"""
_, fn = tf_decorator.unwrap(fn)
-
- # Handle callables.
- if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
- return tuple(tf_inspect.getargspec(fn.__call__).args)
-
- # Handle functools.partial and similar objects.
- if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
- # Handle nested partial.
- original_args = fn_args(fn.func)
- if not original_args:
- return tuple()
-
- return tuple([
- arg for arg in original_args[len(fn.args):]
- if arg not in set((fn.keywords or {}).keys())
- ])
-
- # Handle function.
- return tuple(tf_inspect.getargspec(fn).args)
+ if isinstance(fn, functools.partial):
+ args = fn_args(fn.func)
+ args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
+ else:
+ if _is_callable_object(fn):
+ fn = fn.__call__
+ args = tf_inspect.getargspec(fn).args
+ if _is_bounded_method(fn):
+ args.remove('self')
+ return tuple(args)