diff options
author | 2016-08-05 11:36:41 -0800 | |
---|---|---|
committer | 2016-08-05 12:47:34 -0700 | |
commit | 40e6aa7c2a2074230cf3942b9c5409a5c4641840 (patch) | |
tree | 15d19b69a0d373cbc93a91979761be7631aadd88 /tensorflow/python/framework/gen_docs_combined.py | |
parent | cac1c40022e74ad3fc3dba1abab22f31f3efd2c2 (diff) |
Refactor/consolidate internal/external gen_docs_combined
Change: 129470667
Diffstat (limited to 'tensorflow/python/framework/gen_docs_combined.py')
-rw-r--r-- | tensorflow/python/framework/gen_docs_combined.py | 93 |
1 files changed, 60 insertions, 33 deletions
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py index 6355730210..be0fc992cd 100644 --- a/tensorflow/python/framework/gen_docs_combined.py +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import os.path import sys @@ -43,33 +44,60 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by """ -def get_module_to_name(): - return { - tf: "tf", - tf.errors: "tf.errors", - tf.image: "tf.image", - tf.nn: "tf.nn", - tf.nn.rnn_cell: "tf.nn.rnn_cell", - tf.train: "tf.train", - tf.python_io: "tf.python_io", - tf.summary: "tf.summary", - tf.test: "tf.test", - tf.contrib.bayesflow.stochastic_graph: ( - "tf.contrib.bayesflow.stochastic_graph"), - tf.contrib.copy_graph: "tf.contrib.copy_graph", - tf.contrib.distributions: "tf.contrib.distributions", - tf.contrib.ffmpeg: "tf.contrib.ffmpeg", - tf.contrib.framework: "tf.contrib.framework", - tf.contrib.graph_editor: "tf.contrib.graph_editor", - tf.contrib.layers: "tf.contrib.layers", - tf.contrib.learn: "tf.contrib.learn", - tf.contrib.learn.monitors: ( - "tf.contrib.learn.monitors"), - tf.contrib.losses: "tf.contrib.losses", - tf.contrib.rnn: "tf.contrib.rnn", - tf.contrib.metrics: "tf.contrib.metrics", - tf.contrib.util: "tf.contrib.util", - } +def module_names(): + return [ + "tf", + "tf.errors", + "tf.image", + "tf.nn", + "tf.nn.rnn_cell", + "tf.train", + "tf.python_io", + "tf.summary", + "tf.test", + "tf.contrib.bayesflow.stochastic_graph", + "tf.contrib.copy_graph", + "tf.contrib.distributions", + "tf.contrib.ffmpeg", + "tf.contrib.framework", + "tf.contrib.graph_editor", + "tf.contrib.layers", + "tf.contrib.learn", + "tf.contrib.learn.monitors", + "tf.contrib.losses", + "tf.contrib.rnn", + "tf.contrib.metrics", + "tf.contrib.util", + ] + + +def find_module(base_module, name): + if name == "tf": + return base_module + # Special case for ffmpeg is needed since it's not linked in by default due + # to size concerns. + elif name == "tf.contrib.ffmpeg": + return ffmpeg + elif name.startswith("tf."): + subname = name[3:] + subnames = subname.split(".") + parent_module = base_module + for s in subnames: + if not hasattr(parent_module, s): + raise ValueError( + "Module not found: {}. Submodule {} not found in parent module {}." + " Possible candidates are {}".format( + name, s, parent_module.__name__, dir(parent_module))) + parent_module = getattr(parent_module, s) + return parent_module + else: + raise ValueError( + "Invalid module name: {}. Module names must start with 'tf.'".format( + name)) + + +def get_module_to_name(names): + return collections.OrderedDict([(find_module(tf, x), x) for x in names]) def all_libraries(module_to_name, members, documented): @@ -85,15 +113,14 @@ def all_libraries(module_to_name, members, documented): """ def library(name, title, module=None, **args): if module is None: - module = sys.modules["tensorflow.python.ops" + - ("" if name == "ops" else "." + name)] + module = sys.modules["tensorflow.python.ops." + name] return (name + ".md", docs.Library(title=title, module_to_name=module_to_name, members=members, documented=documented, module=module, **args)) - return [ + return collections.OrderedDict([ # Splits of module 'tf'. library("framework", "Building Graphs", framework_lib), library("check_ops", "Asserts and boolean checks."), @@ -180,7 +207,7 @@ def all_libraries(module_to_name, members, documented): library("contrib.util", "Utilities (contrib)", tf.contrib.util), library("contrib.copy_graph", "Copying Graph Elements (contrib)", tf.contrib.copy_graph), - ] + ]) _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange", "HistogramProto", "ConfigProto", "NodeDef", "GraphDef", @@ -205,9 +232,9 @@ def main(unused_argv): # Document libraries documented = set() - module_to_name = get_module_to_name() + module_to_name = get_module_to_name(module_names()) members = docs.collect_members(module_to_name, exclude=EXCLUDE) - libraries = all_libraries(module_to_name, members, documented) + libraries = all_libraries(module_to_name, members, documented).items() # Define catch_all library before calling write_libraries to avoid complaining # about generically hidden symbols. |