aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/common
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2016-12-27 15:26:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-27 15:45:55 -0800
commit8ac4834beeb7e186d0a1c3794fdc178fa3553d3b (patch)
tree26340eba236c058c39dd6db555b1268bbbe4fbe3 /tensorflow/tools/common
parentc9722179a8ae49ca1438c957e51c737ed713087f (diff)
Make a traversal tool to visit everything in a given Python module/class.
Change: 143061298
Diffstat (limited to 'tensorflow/tools/common')
-rw-r--r--tensorflow/tools/common/BUILD37
-rw-r--r--tensorflow/tools/common/traverse.py91
-rw-r--r--tensorflow/tools/common/traverse_test.py84
3 files changed, 212 insertions, 0 deletions
diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD
new file mode 100644
index 0000000000..f1d43134b8
--- /dev/null
+++ b/tensorflow/tools/common/BUILD
@@ -0,0 +1,37 @@
+# Description:
+# Common functionality for TensorFlow tooling
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(
+ default_visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "traverse",
+ srcs = ["traverse.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "traverse_test",
+ srcs = ["traverse_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":traverse",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
diff --git a/tensorflow/tools/common/traverse.py b/tensorflow/tools/common/traverse.py
new file mode 100644
index 0000000000..443838d968
--- /dev/null
+++ b/tensorflow/tools/common/traverse.py
@@ -0,0 +1,91 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Traversing Python modules and classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import inspect
+import sys
+
+
+__all__ = ['traverse']
+
+
+def _traverse_internal(root, visit, stack, path):
+ """Internal helper for traverse."""
+
+ # Only traverse modules and classes
+ if not inspect.isclass(root) and not inspect.ismodule(root):
+ return
+
+ try:
+ children = inspect.getmembers(root)
+ except ImportError:
+ # On some Python installations, some modules do not support enumerating
+ # members (six in particular), leading to import errors.
+ children = []
+
+ new_stack = stack + [root]
+ visit(path, root, children)
+ for name, child in children:
+ # Do not descend into built-in modules
+ if inspect.ismodule(child) and child.__name__ in sys.builtin_module_names:
+ continue
+
+ # Break cycles
+ if any(child is item for item in new_stack): # `in`, but using `is`
+ continue
+
+ child_path = path + '.' + name if path else name
+ _traverse_internal(child, visit, new_stack, child_path)
+
+
+def traverse(root, visit):
+ """Recursively enumerate all members of `root`.
+
+ Similar to the Python library function `os.path.walk`.
+
+ Traverses the tree of Python objects starting with `root`, depth first.
+ Parent-child relationships in the tree are defined by membership in modules or
+ classes. The function `visit` is called with arguments
+ `(path, parent, children)` for each module or class `parent` found in the tree
+ of python objects starting with `root`. `path` is a string containing the name
+ with which `parent` is reachable from the current context. For example, if
+ `root` is a local class called `X` which contains a class `Y`, `visit` will be
+ called with `('Y', X.Y, children)`).
+
+ If `root` is not a module or class, `visit` is never called. `traverse`
+ never descends into built-in modules.
+
+ `children`, a list of `(name, object)` pairs are determined by
+ `inspect.getmembers`. To avoid visiting parts of the tree, `children` can be
+ modified in place, using `del` or slice assignment.
+
+ Cycles (determined by reference equality, `is`) stop the traversal. A stack of
+ objects is kept to find cycles. Objects forming cycles may appear in
+ `children`, but `visit` will not be called with any object as `parent` which
+ is already in the stack.
+
+ Traversing system modules can take a long time, it is advisable to pass a
+ `visit` callable which blacklists such modules.
+
+ Args:
+ root: A python object with which to start the traversal.
+ visit: A function taking arguments `(path, parent, children)`. Will be
+ called for each object found in the traversal.
+ """
+ _traverse_internal(root, visit, [], '')
diff --git a/tensorflow/tools/common/traverse_test.py b/tensorflow/tools/common/traverse_test.py
new file mode 100644
index 0000000000..eb195ec18e
--- /dev/null
+++ b/tensorflow/tools/common/traverse_test.py
@@ -0,0 +1,84 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Python module traversal."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from tensorflow.python.platform import googletest
+from tensorflow.tools.common import traverse
+
+
+class TestVisitor(object):
+
+ def __init__(self):
+ self.call_log = []
+
+ def __call__(self, path, parent, children):
+ # Do not traverse googletest, it's very deep.
+ for item in list(children):
+ if item[1] is googletest:
+ children.remove(item)
+ self.call_log += [(path, parent, children)]
+
+
+class TraverseTest(googletest.TestCase):
+
+ def test_cycle(self):
+
+ class Cyclist(object):
+ pass
+ Cyclist.cycle = Cyclist
+
+ visitor = TestVisitor()
+ traverse.traverse(Cyclist, visitor)
+ # We simply want to make sure we terminate.
+
+ def test_module(self):
+ visitor = TestVisitor()
+ traverse.traverse(sys.modules[__name__], visitor)
+
+ called = [parent for _, parent, _ in visitor.call_log]
+
+ self.assertIn(TestVisitor, called)
+ self.assertIn(TraverseTest, called)
+ self.assertIn(traverse, called)
+
+ def test_class(self):
+ visitor = TestVisitor()
+ traverse.traverse(TestVisitor, visitor)
+ self.assertEqual(TestVisitor,
+ visitor.call_log[0][1])
+ # There are a bunch of other members, but make sure that the ones we know
+ # about are there.
+ self.assertIn('__init__', [name for name, _ in visitor.call_log[0][2]])
+ self.assertIn('__call__', [name for name, _ in visitor.call_log[0][2]])
+
+ # There are more classes descended into, at least __class__ and
+ # __class__.__base__, neither of which are interesting to us, and which may
+ # change as part of Python version etc., so we don't test for them.
+
+ def test_non_class(self):
+ integer = 5
+ visitor = TestVisitor()
+ traverse.traverse(integer, visitor)
+ self.assertEqual([], visitor.call_log)
+
+
+if __name__ == '__main__':
+ googletest.main()