aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/common
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-09 14:30:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-09 14:34:40 -0700
commit0058c1f134c7de0733ffed754fe3a3091dea60ca (patch)
tree7c9527d542a8e7809dfa20d8762d1ec0ddd931fd /tensorflow/tools/common
parent7ce6e4f871b0767547a9a5cfb9d19dba79704489 (diff)
Make API do not descend map a bit more precise by including
the root module name. Add ability to mark symbols as private. PiperOrigin-RevId: 158563334
Diffstat (limited to 'tensorflow/tools/common')
-rw-r--r--tensorflow/tools/common/public_api.py54
1 files changed, 40 insertions, 14 deletions
diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py
index cab3b2ff6a..e0acead919 100644
--- a/tensorflow/tools/common/public_api.py
+++ b/tensorflow/tools/common/public_api.py
@@ -36,12 +36,20 @@ class PublicAPIVisitor(object):
visitor: A visitor to call for the public API.
"""
self._visitor = visitor
+ self._root_name = 'tf'
+
+ # Modules/classes we want to suppress entirely.
+ self._private_map = {
+ # Some implementations have this internal module that we shouldn't
+ # expose.
+ 'tf.flags': ['cpp_flags'],
+ }
# Modules/classes we do not want to descend into if we hit them. Usually,
# system modules exposed through platforms for compatibility reasons.
# Each entry maps a module path to a name to ignore in traversal.
self._do_not_descend_map = {
- '': [
+ 'tf': [
'core',
'examples',
'flags', # Don't add flags
@@ -56,18 +64,26 @@ class PublicAPIVisitor(object):
'tensorboard',
],
- # Some implementations have this internal module that we shouldn't
- # expose.
- 'flags': ['cpp_flags'],
-
## Everything below here is legitimate.
# It'll stay, but it's not officially part of the API.
- 'app': ['flags'],
+ 'tf.app': ['flags'],
# Imported for compatibility between py2/3.
- 'test': ['mock'],
+ 'tf.test': ['mock'],
}
@property
+ def private_map(self):
+ """A map from parents to symbols that should not be included at all.
+
+ This map can be edited, but it should not be edited once traversal has
+ begun.
+
+ Returns:
+ The map marking symbols to not include.
+ """
+ return self._private_map
+
+ @property
def do_not_descend_map(self):
"""A map from parents to symbols that should not be descended into.
@@ -79,11 +95,17 @@ class PublicAPIVisitor(object):
"""
return self._do_not_descend_map
- def _isprivate(self, name):
+ def set_root_name(self, root_name):
+ """Override the default root name of 'tf'."""
+ self._root_name = root_name
+
+ def _is_private(self, path, name):
"""Return whether a name is private."""
# TODO(wicke): Find out what names to exclude.
- return (name.startswith('_') and not re.match('__.*__$', name) or
- name in ['__base__', '__class__'])
+ return ((path in self._private_map and
+ name in self._private_map[path]) or
+ (name.startswith('_') and not re.match('__.*__$', name) or
+ name in ['__base__', '__class__']))
def _do_not_descend(self, path, name):
"""Safely queries if a specific fully qualified name should be excluded."""
@@ -95,17 +117,21 @@ class PublicAPIVisitor(object):
# Avoid long waits in cases of pretty unambiguous failure.
if tf_inspect.ismodule(parent) and len(path.split('.')) > 10:
- raise RuntimeError('Modules nested too deep:\n%s\n\nThis is likely a '
- 'problem with an accidental public import.' % path)
+ raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a '
+ 'problem with an accidental public import.' %
+ (self._root_name, path))
+
+ # Includes self._root_name
+ full_path = '.'.join([self._root_name, path]) if path else self._root_name
# Remove things that are not visible.
for name, child in list(children):
- if self._isprivate(name):
+ if self._is_private(full_path, name):
children.remove((name, child))
self._visitor(path, parent, children)
# Remove things that are visible, but which should not be descended into.
for name, child in list(children):
- if self._do_not_descend(path, name):
+ if self._do_not_descend(full_path, name):
children.remove((name, child))