diff options
author | Martin Wicke <wicke@google.com> | 2017-01-08 01:05:21 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-08 01:18:42 -0800 |
commit | 2691c1260f792e85a40f598809e84aa60d19e227 (patch) | |
tree | 77aceea44eda407ee4006d38e38d15d90b6cf89d /tensorflow/tools/common | |
parent | 367078ab8c0ca8feab1ca88e871279df150639d5 (diff) |
Fix a bad comment.
Change: 143889993
Diffstat (limited to 'tensorflow/tools/common')
-rw-r--r-- | tensorflow/tools/common/BUILD | 16 | ||||
-rw-r--r-- | tensorflow/tools/common/public_api.py | 78 | ||||
-rw-r--r-- | tensorflow/tools/common/public_api_test.py | 68 |
3 files changed, 162 insertions, 0 deletions
diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index f1d43134b8..96ae9583d7 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -10,6 +10,22 @@ package( ) py_library( + name = "public_api", + srcs = ["public_api.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "public_api_test", + srcs = ["public_api_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":public_api", + "//tensorflow/python:platform_test", + ], +) + +py_library( name = "traverse", srcs = ["traverse.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py new file mode 100644 index 0000000000..5d70cb7b76 --- /dev/null +++ b/tensorflow/tools/common/public_api.py @@ -0,0 +1,78 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Visitor restricting traversal to only the public tensorflow API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect + + +class PublicAPIVisitor(object): + """Visitor to use with `traverse` to visit exactly the public TF API.""" + + def __init__(self, visitor): + """Constructor. + + `visitor` should be a callable suitable as a visitor for `traverse`. It will + be called only for members of the public TensorFlow API. + + Args: + visitor: A visitor to call for the public API. + """ + self._visitor = visitor + + # Modules/classes we do not want to descend into if we hit them. Usually, + # sytem modules exposed through platforms for compatibility reasons. + # Each entry maps a module path to a name to ignore in traversal. + _do_not_descend_map = { + # TODO(drpng): This can be removed once sealed off. + '': ['platform', 'pywrap_tensorflow'], + + # Some implementations have this internal module that we shouldn't expose. + 'flags': ['cpp_flags'], + + # Everything below here is legitimate. + 'app': 'flags', # It'll stay, but it's not officially part of the API + 'test': ['mock'], # Imported for compatibility between py2/3. + } + + def _isprivate(self, name): + """Return whether a name is private.""" + return name.startswith('_') + + def _do_not_descend(self, path, name): + """Safely queries if a specific fully qualified name should be excluded.""" + return (path in self._do_not_descend_map and + name in self._do_not_descend_map[path]) + + def __call__(self, path, parent, children): + """Visitor interface, see `traverse` for details.""" + if inspect.ismodule(parent) and len(path.split('.')) > 10: + raise RuntimeError('Modules nested too deep:\n%s\n\nThis is likely a ' + 'problem with an accidental public import.' % path) + + # Remove things that are not visible. + for name, child in list(children): + if self._isprivate(name): + children.remove((name, child)) + + self._visitor(path, parent, children) + + # Remove things that are visible, but which should not be descended into. + for name, child in list(children): + if self._do_not_descend(path, name): + children.remove((name, child)) diff --git a/tensorflow/tools/common/public_api_test.py b/tensorflow/tools/common/public_api_test.py new file mode 100644 index 0000000000..93a3bcc274 --- /dev/null +++ b/tensorflow/tools/common/public_api_test.py @@ -0,0 +1,68 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.tools.common.public_api.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import googletest +from tensorflow.tools.common import public_api + + +class PublicApiTest(googletest.TestCase): + + class TestVisitor(object): + + def __init__(self): + self.symbols = set() + self.last_parent = None + self.last_children = None + + def __call__(self, path, parent, children): + self.symbols.add(path) + self.last_parent = parent + self.last_children = list(children) # Make a copy to preserve state. + + def test_call_forward(self): + visitor = self.TestVisitor() + children = [('name1', 'thing1'), ('name2', 'thing2')] + public_api.PublicAPIVisitor(visitor)('test', 'dummy', children) + self.assertEqual(set(['test']), visitor.symbols) + self.assertEqual('dummy', visitor.last_parent) + self.assertEqual([('name1', 'thing1'), ('name2', 'thing2')], + visitor.last_children) + + def test_private_child_removal(self): + visitor = self.TestVisitor() + children = [('name1', 'thing1'), ('_name2', 'thing2')] + public_api.PublicAPIVisitor(visitor)('test', 'dummy', children) + # Make sure the private symbols are removed before the visitor is called. + self.assertEqual([('name1', 'thing1')], visitor.last_children) + self.assertEqual([('name1', 'thing1')], children) + + def test_no_descent_child_removal(self): + visitor = self.TestVisitor() + children = [('name1', 'thing1'), ('mock', 'thing2')] + public_api.PublicAPIVisitor(visitor)('test', 'dummy', children) + # Make sure not-to-be-descended-into symbols are removed after the visitor + # is called. + self.assertEqual([('name1', 'thing1'), ('mock', 'thing2')], + visitor.last_children) + self.assertEqual([('name1', 'thing1')], children) + + +if __name__ == '__main__': + googletest.main() |