From d1588d72a820423cab36977ca97221aba01be713 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Mon, 8 Oct 2018 10:43:03 -0700 Subject: Add a utility that allows finding a name for an entity, relative to an existing namespace. PiperOrigin-RevId: 216211286 --- tensorflow/python/autograph/pyct/inspect_utils.py | 34 ++++++++++++++++++++++ .../python/autograph/pyct/inspect_utils_test.py | 19 ++++++++++++ 2 files changed, 53 insertions(+) diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py index 1416988ea3..29c406c248 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils.py +++ b/tensorflow/python/autograph/pyct/inspect_utils.py @@ -67,6 +67,40 @@ def getnamespace(f): return namespace +def getqualifiedname(namespace, object_, max_depth=2): + """Returns the name by which a value can be referred to in a given namespace. + + This function will recurse inside modules, but it will not search objects for + attributes. The recursion depth is controlled by max_depth. + + Args: + namespace: Dict[str, Any], the namespace to search into. + object_: Any, the value to search. + max_depth: Optional[int], a limit to the recursion depth when searching + inside modules. + Returns: Union[str, None], the fully-qualified name that resolves to the value + o, or None if it couldn't be found. + """ + for name, value in namespace.items(): + # The value may be referenced by more than one symbol, case in which + # any symbol will be fine. If the program contains symbol aliases that + # change over time, this may capture a symbol that will later point to + # something else. + # TODO(mdan): Prefer the symbol that matches the value type name. + if object_ is value: + return name + + # TODO(mdan): Use breadth-first search and avoid visiting modules twice. + if max_depth: + for name, value in namespace.items(): + if tf_inspect.ismodule(value): + name_in_module = getqualifiedname(value.__dict__, object_, + max_depth - 1) + if name_in_module is not None: + return '{}.{}'.format(name, name_in_module) + return None + + def _get_unbound_function(m): # TODO(mdan): Figure out why six.get_unbound_function fails in some cases. # The failure case is for tf.keras.Model. diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py index f3eb027822..11074debfc 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils_test.py +++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from functools import wraps +import imp import six @@ -127,6 +128,24 @@ class InspectUtilsTest(test.TestCase): self.assertEqual(ns['closed_over_primitive'], closed_over_primitive) self.assertTrue('local_var' not in ns) + def test_getqualifiedname(self): + foo = object() + qux = imp.new_module('quxmodule') + bar = imp.new_module('barmodule') + baz = object() + bar.baz = baz + + ns = { + 'foo': foo, + 'bar': bar, + 'qux': qux, + } + + self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils)) + self.assertEqual(inspect_utils.getqualifiedname(ns, foo), 'foo') + self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar') + self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz') + def test_getmethodclass(self): self.assertEqual( -- cgit v1.2.3