diff options
author | 2017-10-26 19:06:43 -0700 | |
---|---|---|
committer | 2017-10-26 19:10:35 -0700 | |
commit | b113d082ac6320adaaa0205cd77ab815ff40bc16 (patch) | |
tree | be81843afe470219bd9cb5f717f77342acda5f7a /tensorflow/python/estimator/util.py | |
parent | abebb5f3fa6799e4fc1f2de1156a7c968c8473b8 (diff) |
Exclude 'self' from function arguments returned by util.fn_args for callables
and bounded methods.
PiperOrigin-RevId: 173622989
Diffstat (limited to 'tensorflow/python/estimator/util.py')
-rw-r--r-- | tensorflow/python/estimator/util.py | 39 |
1 files changed, 20 insertions, 19 deletions
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index de35e66bdf..12f2592d84 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -19,10 +19,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +def _is_bounded_method(fn): + return tf_inspect.ismethod(fn) and (fn.__self__ is not None) + + +def _is_callable_object(obj): + return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__) + + def fn_args(fn): """Get argument names for function-like object. @@ -36,22 +46,13 @@ def fn_args(fn): ValueError: if partial function has positionally bound arguments """ _, fn = tf_decorator.unwrap(fn) - - # Handle callables. - if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__): - return tuple(tf_inspect.getargspec(fn.__call__).args) - - # Handle functools.partial and similar objects. - if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'): - # Handle nested partial. - original_args = fn_args(fn.func) - if not original_args: - return tuple() - - return tuple([ - arg for arg in original_args[len(fn.args):] - if arg not in set((fn.keywords or {}).keys()) - ]) - - # Handle function. - return tuple(tf_inspect.getargspec(fn).args) + if isinstance(fn, functools.partial): + args = fn_args(fn.func) + args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])] + else: + if _is_callable_object(fn): + fn = fn.__call__ + args = tf_inspect.getargspec(fn).args + if _is_bounded_method(fn): + args.remove('self') + return tuple(args) |