aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/common
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2017-01-08 01:05:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-08 01:18:42 -0800
commit2691c1260f792e85a40f598809e84aa60d19e227 (patch)
tree77aceea44eda407ee4006d38e38d15d90b6cf89d /tensorflow/tools/common
parent367078ab8c0ca8feab1ca88e871279df150639d5 (diff)
Fix a bad comment.
Change: 143889993
Diffstat (limited to 'tensorflow/tools/common')
-rw-r--r--tensorflow/tools/common/BUILD16
-rw-r--r--tensorflow/tools/common/public_api.py78
-rw-r--r--tensorflow/tools/common/public_api_test.py68
3 files changed, 162 insertions, 0 deletions
diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD
index f1d43134b8..96ae9583d7 100644
--- a/tensorflow/tools/common/BUILD
+++ b/tensorflow/tools/common/BUILD
@@ -10,6 +10,22 @@ package(
)
py_library(
+ name = "public_api",
+ srcs = ["public_api.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "public_api_test",
+ srcs = ["public_api_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":public_api",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_library(
name = "traverse",
srcs = ["traverse.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py
new file mode 100644
index 0000000000..5d70cb7b76
--- /dev/null
+++ b/tensorflow/tools/common/public_api.py
@@ -0,0 +1,78 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Visitor restricting traversal to only the public tensorflow API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import inspect
+
+
+class PublicAPIVisitor(object):
+ """Visitor to use with `traverse` to visit exactly the public TF API."""
+
+ def __init__(self, visitor):
+ """Constructor.
+
+ `visitor` should be a callable suitable as a visitor for `traverse`. It will
+ be called only for members of the public TensorFlow API.
+
+ Args:
+ visitor: A visitor to call for the public API.
+ """
+ self._visitor = visitor
+
+ # Modules/classes we do not want to descend into if we hit them. Usually,
+ # sytem modules exposed through platforms for compatibility reasons.
+ # Each entry maps a module path to a name to ignore in traversal.
+ _do_not_descend_map = {
+ # TODO(drpng): This can be removed once sealed off.
+ '': ['platform', 'pywrap_tensorflow'],
+
+ # Some implementations have this internal module that we shouldn't expose.
+ 'flags': ['cpp_flags'],
+
+ # Everything below here is legitimate.
+ 'app': 'flags', # It'll stay, but it's not officially part of the API
+ 'test': ['mock'], # Imported for compatibility between py2/3.
+ }
+
+ def _isprivate(self, name):
+ """Return whether a name is private."""
+ return name.startswith('_')
+
+ def _do_not_descend(self, path, name):
+ """Safely queries if a specific fully qualified name should be excluded."""
+ return (path in self._do_not_descend_map and
+ name in self._do_not_descend_map[path])
+
+ def __call__(self, path, parent, children):
+ """Visitor interface, see `traverse` for details."""
+ if inspect.ismodule(parent) and len(path.split('.')) > 10:
+ raise RuntimeError('Modules nested too deep:\n%s\n\nThis is likely a '
+ 'problem with an accidental public import.' % path)
+
+ # Remove things that are not visible.
+ for name, child in list(children):
+ if self._isprivate(name):
+ children.remove((name, child))
+
+ self._visitor(path, parent, children)
+
+ # Remove things that are visible, but which should not be descended into.
+ for name, child in list(children):
+ if self._do_not_descend(path, name):
+ children.remove((name, child))
diff --git a/tensorflow/tools/common/public_api_test.py b/tensorflow/tools/common/public_api_test.py
new file mode 100644
index 0000000000..93a3bcc274
--- /dev/null
+++ b/tensorflow/tools/common/public_api_test.py
@@ -0,0 +1,68 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.tools.common.public_api."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.platform import googletest
+from tensorflow.tools.common import public_api
+
+
+class PublicApiTest(googletest.TestCase):
+
+ class TestVisitor(object):
+
+ def __init__(self):
+ self.symbols = set()
+ self.last_parent = None
+ self.last_children = None
+
+ def __call__(self, path, parent, children):
+ self.symbols.add(path)
+ self.last_parent = parent
+ self.last_children = list(children) # Make a copy to preserve state.
+
+ def test_call_forward(self):
+ visitor = self.TestVisitor()
+ children = [('name1', 'thing1'), ('name2', 'thing2')]
+ public_api.PublicAPIVisitor(visitor)('test', 'dummy', children)
+ self.assertEqual(set(['test']), visitor.symbols)
+ self.assertEqual('dummy', visitor.last_parent)
+ self.assertEqual([('name1', 'thing1'), ('name2', 'thing2')],
+ visitor.last_children)
+
+ def test_private_child_removal(self):
+ visitor = self.TestVisitor()
+ children = [('name1', 'thing1'), ('_name2', 'thing2')]
+ public_api.PublicAPIVisitor(visitor)('test', 'dummy', children)
+ # Make sure the private symbols are removed before the visitor is called.
+ self.assertEqual([('name1', 'thing1')], visitor.last_children)
+ self.assertEqual([('name1', 'thing1')], children)
+
+ def test_no_descent_child_removal(self):
+ visitor = self.TestVisitor()
+ children = [('name1', 'thing1'), ('mock', 'thing2')]
+ public_api.PublicAPIVisitor(visitor)('test', 'dummy', children)
+ # Make sure not-to-be-descended-into symbols are removed after the visitor
+ # is called.
+ self.assertEqual([('name1', 'thing1'), ('mock', 'thing2')],
+ visitor.last_children)
+ self.assertEqual([('name1', 'thing1')], children)
+
+
+if __name__ == '__main__':
+ googletest.main()