aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/eager/function.py10
-rw-r--r--tensorflow/python/eager/function_test.py13
2 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 54363ffcba..f315fa296c 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -960,11 +960,17 @@ class _PolymorphicFunction(object):
self._lock = threading.Lock()
fullargspec = tf_inspect.getfullargspec(self._python_function)
+ if tf_inspect.ismethod(self._python_function):
+ # Remove `self`: default arguments shouldn't be matched to it.
+ args = fullargspec.args[1:]
+ else:
+ args = fullargspec.args
+
# A cache mapping from argument name to index, for canonicalizing
# arguments that are called in a keyword-like fashion.
- self._args_to_indices = {arg: i for i, arg in enumerate(fullargspec.args)}
+ self._args_to_indices = {arg: i for i, arg in enumerate(args)}
# A cache mapping from arg index to default value, for canonicalization.
- offset = len(fullargspec.args) - len(fullargspec.defaults or [])
+ offset = len(args) - len(fullargspec.defaults or [])
self._arg_indices_to_default_values = {
offset + index: default
for index, default in enumerate(fullargspec.defaults or [])
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index b9e29635f8..b7c9334c33 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1233,6 +1233,19 @@ class FunctionTest(test.TestCase):
self.assertEqual(one.numpy(), 1.0)
self.assertEqual(two.numpy(), 2)
+ def testDefuningInstanceMethodWithDefaultArgument(self):
+
+ integer = constant_op.constant(2, dtypes.int64)
+
+ class Foo(object):
+
+ @function.defun
+ def func(self, other=integer):
+ return other
+
+ foo = Foo()
+ self.assertEqual(foo.func().numpy(), int(integer))
+
def testPythonCallWithSideEffects(self):
state = []