aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-10 12:46:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-10 12:49:15 -0700
commit2a9eef3836c71a595c5c86645d54ff74ea3c1812 (patch)
treefbcd58f992ea52e0d53ca8a4b504330e62559d23 /tensorflow/contrib/learn
parent9c5aaf325bac0b0e180e3b1fe1ed81a88ef2fd55 (diff)
Fix a bug about getting arguments of partial functions.
PiperOrigin-RevId: 196157095
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index e28e6854a5..339c4e0e36 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -1862,12 +1862,12 @@ def _get_arguments(func):
if hasattr(func, "__code__"):
# Regular function.
return tf_inspect.getargspec(func)
- elif hasattr(func, "__call__"):
- # Callable object.
- return _get_arguments(func.__call__)
elif hasattr(func, "func"):
# Partial function.
return _get_arguments(func.func)
+ elif hasattr(func, "__call__"):
+ # Callable object.
+ return _get_arguments(func.__call__)
def _verify_loss_fn_args(loss_fn):