aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-26 10:01:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 10:05:09 -0700
commit883e8b6863511b46d7985e9ff8d1809ffe2a1bc0 (patch)
tree242a7257bfd31687720b7e08a5c7cb833d6c3b79 /tensorflow/python/util
parent86cffb1d9201f8072cef3eb13ef0dc524e0f4535 (diff)
Make function_utils.get_func_code more tolerant of strange objects like functool.partial.
PiperOrigin-RevId: 206175973
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/function_utils.py14
-rw-r--r--tensorflow/python/util/function_utils_test.py78
2 files changed, 87 insertions, 5 deletions
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
index 61312feafd..4e9b07e20a 100644
--- a/tensorflow/python/util/function_utils.py
+++ b/tensorflow/python/util/function_utils.py
@@ -69,20 +69,24 @@ def get_func_name(func):
return '%s.%s' % (six.get_method_self(func).__class__.__name__,
six.get_method_function(func).__name__)
else: # Probably a class instance with __call__
- return type(func)
+ return str(type(func))
else:
raise ValueError('Argument must be callable')
def get_func_code(func):
- """Returns func_code of passed callable."""
+ """Returns func_code of passed callable, or None if not available."""
_, func = tf_decorator.unwrap(func)
if callable(func):
if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
return six.get_function_code(func)
- elif hasattr(func, '__call__'):
+ # Since the object is not a function or method, but is a callable, we will
+ # try to access the __call__method as a function. This works with callable
+ # classes but fails with functool.partial objects despite their __call__
+ # attribute.
+ try:
return six.get_function_code(func.__call__)
- else:
- raise ValueError('Unhandled callable, type=%s' % type(func))
+ except AttributeError:
+ return None
else:
raise ValueError('Argument must be callable')
diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py
index e78cf6a5b0..1588328c26 100644
--- a/tensorflow/python/util/function_utils_test.py
+++ b/tensorflow/python/util/function_utils_test.py
@@ -24,6 +24,16 @@ from tensorflow.python.platform import test
from tensorflow.python.util import function_utils
+def silly_example_function():
+ pass
+
+
+class SillyCallableClass(object):
+
+ def __call__(self):
+ pass
+
+
class FnArgsTest(test.TestCase):
def test_simple_function(self):
@@ -124,5 +134,73 @@ class FnArgsTest(test.TestCase):
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))
+
+class GetFuncNameTest(test.TestCase):
+
+ def testWithSimpleFunction(self):
+ self.assertEqual(
+ 'silly_example_function',
+ function_utils.get_func_name(silly_example_function))
+
+ def testWithClassMethod(self):
+ self.assertEqual(
+ 'GetFuncNameTest.testWithClassMethod',
+ function_utils.get_func_name(self.testWithClassMethod))
+
+ def testWithCallableClass(self):
+ callable_instance = SillyCallableClass()
+ self.assertRegexpMatches(
+ function_utils.get_func_name(callable_instance),
+ '<.*SillyCallableClass.*>')
+
+ def testWithFunctoolsPartial(self):
+ partial = functools.partial(silly_example_function)
+ self.assertRegexpMatches(
+ function_utils.get_func_name(partial),
+ '<.*functools.partial.*>')
+
+ def testWithLambda(self):
+ anon_fn = lambda x: x
+ self.assertEqual('<lambda>', function_utils.get_func_name(anon_fn))
+
+ def testRaisesWithNonCallableObject(self):
+ with self.assertRaises(ValueError):
+ function_utils.get_func_name(None)
+
+
+class GetFuncCodeTest(test.TestCase):
+
+ def testWithSimpleFunction(self):
+ code = function_utils.get_func_code(silly_example_function)
+ self.assertIsNotNone(code)
+ self.assertRegexpMatches(code.co_filename, 'function_utils_test.py')
+
+ def testWithClassMethod(self):
+ code = function_utils.get_func_code(self.testWithClassMethod)
+ self.assertIsNotNone(code)
+ self.assertRegexpMatches(code.co_filename, 'function_utils_test.py')
+
+ def testWithCallableClass(self):
+ callable_instance = SillyCallableClass()
+ code = function_utils.get_func_code(callable_instance)
+ self.assertIsNotNone(code)
+ self.assertRegexpMatches(code.co_filename, 'function_utils_test.py')
+
+ def testWithLambda(self):
+ anon_fn = lambda x: x
+ code = function_utils.get_func_code(anon_fn)
+ self.assertIsNotNone(code)
+ self.assertRegexpMatches(code.co_filename, 'function_utils_test.py')
+
+ def testWithFunctoolsPartial(self):
+ partial = functools.partial(silly_example_function)
+ code = function_utils.get_func_code(partial)
+ self.assertIsNone(code)
+
+ def testRaisesWithNonCallableObject(self):
+ with self.assertRaises(ValueError):
+ function_utils.get_func_code(None)
+
+
if __name__ == '__main__':
test.main()