aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-26 01:27:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 01:31:05 -0700
commite5cc33df74ec4f761da26c87bb785edfa3fb8280 (patch)
tree67dedd2463f703845bd0444fb6fac1b7c12a80aa /tensorflow/python/util
parentb1dc68e816e2bf6b8acd3651077c890f2f2f3b7b (diff)
Convert device function stack into TraceableStack for use in error message interpolation.
PiperOrigin-RevId: 206120307
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/function_utils.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
index 7bbbde3cd2..61312feafd 100644
--- a/tensorflow/python/util/function_utils.py
+++ b/tensorflow/python/util/function_utils.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import functools
+import six
+
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -55,3 +57,32 @@ def fn_args(fn):
if _is_bounded_method(fn):
args.remove('self')
return tuple(args)
+
+
+def get_func_name(func):
+ """Returns name of passed callable."""
+ _, func = tf_decorator.unwrap(func)
+ if callable(func):
+ if tf_inspect.isfunction(func):
+ return func.__name__
+ elif tf_inspect.ismethod(func):
+ return '%s.%s' % (six.get_method_self(func).__class__.__name__,
+ six.get_method_function(func).__name__)
+ else: # Probably a class instance with __call__
+ return type(func)
+ else:
+ raise ValueError('Argument must be callable')
+
+
+def get_func_code(func):
+ """Returns func_code of passed callable."""
+ _, func = tf_decorator.unwrap(func)
+ if callable(func):
+ if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
+ return six.get_function_code(func)
+ elif hasattr(func, '__call__'):
+ return six.get_function_code(func.__call__)
+ else:
+ raise ValueError('Unhandled callable, type=%s' % type(func))
+ else:
+ raise ValueError('Argument must be callable')