aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-08 10:43:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 10:55:00 -0700
commitd1588d72a820423cab36977ca97221aba01be713 (patch)
tree79e3eb6b5eedb37a8ac457076366d49c6450c2ac
parent8ef3e7c8c053cb6dad530e13c478bbd406ea2c95 (diff)
Add a utility that allows finding a name for an entity, relative to an existing namespace.
PiperOrigin-RevId: 216211286
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils.py34
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils_test.py19
2 files changed, 53 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index 1416988ea3..29c406c248 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
@@ -67,6 +67,40 @@ def getnamespace(f):
return namespace
+def getqualifiedname(namespace, object_, max_depth=2):
+ """Returns the name by which a value can be referred to in a given namespace.
+
+ This function will recurse inside modules, but it will not search objects for
+ attributes. The recursion depth is controlled by max_depth.
+
+ Args:
+ namespace: Dict[str, Any], the namespace to search into.
+ object_: Any, the value to search.
+ max_depth: Optional[int], a limit to the recursion depth when searching
+ inside modules.
+ Returns: Union[str, None], the fully-qualified name that resolves to the value
+ o, or None if it couldn't be found.
+ """
+ for name, value in namespace.items():
+ # The value may be referenced by more than one symbol, case in which
+ # any symbol will be fine. If the program contains symbol aliases that
+ # change over time, this may capture a symbol that will later point to
+ # something else.
+ # TODO(mdan): Prefer the symbol that matches the value type name.
+ if object_ is value:
+ return name
+
+ # TODO(mdan): Use breadth-first search and avoid visiting modules twice.
+ if max_depth:
+ for name, value in namespace.items():
+ if tf_inspect.ismodule(value):
+ name_in_module = getqualifiedname(value.__dict__, object_,
+ max_depth - 1)
+ if name_in_module is not None:
+ return '{}.{}'.format(name, name_in_module)
+ return None
+
+
def _get_unbound_function(m):
# TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
# The failure case is for tf.keras.Model.
diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
index f3eb027822..11074debfc 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from functools import wraps
+import imp
import six
@@ -127,6 +128,24 @@ class InspectUtilsTest(test.TestCase):
self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
self.assertTrue('local_var' not in ns)
+ def test_getqualifiedname(self):
+ foo = object()
+ qux = imp.new_module('quxmodule')
+ bar = imp.new_module('barmodule')
+ baz = object()
+ bar.baz = baz
+
+ ns = {
+ 'foo': foo,
+ 'bar': bar,
+ 'qux': qux,
+ }
+
+ self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
+ self.assertEqual(inspect_utils.getqualifiedname(ns, foo), 'foo')
+ self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
+ self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')
+
def test_getmethodclass(self):
self.assertEqual(