aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2017-02-13 20:45:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-13 21:05:19 -0800
commitd96c3b7d4089ca49892ec5000ddb6e1a3b90c6c3 (patch)
tree9f4dcdc27c0ea11c4286ae76cfb4a357df5341c4 /tensorflow/contrib/graph_editor
parent4a215b750b10db92906cae66dec30f139a184e02 (diff)
Fix documentation and guide for graph_editor
-Seal and expose reroute -Remove unavailable aliases from guide Change: 147429891
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--tensorflow/contrib/graph_editor/reroute.py54
1 files changed, 29 insertions, 25 deletions
diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py
index 4c5f281bad..c14bcac3be 100644
--- a/tensorflow/contrib/graph_editor/reroute.py
+++ b/tensorflow/contrib/graph_editor/reroute.py
@@ -18,11 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.graph_editor import subgraph
-from tensorflow.contrib.graph_editor import util
-from tensorflow.python.framework import ops as tf_ops
+from tensorflow.contrib.graph_editor import subgraph as _subgraph
+from tensorflow.contrib.graph_editor import util as _util
+from tensorflow.python.framework import ops as _tf_ops
-__all__ = [
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
"swap_ts",
"reroute_ts",
"swap_inputs",
@@ -46,8 +48,8 @@ def _check_ts_compatibility(ts0, ts1):
ValueError: if any pair of tensors (same index in ts0 and ts1) have
a dtype or a shape which is not compatible.
"""
- ts0 = util.make_list_of_t(ts0)
- ts1 = util.make_list_of_t(ts1)
+ ts0 = _util.make_list_of_t(ts0)
+ ts1 = _util.make_list_of_t(ts1)
if len(ts0) != len(ts1):
raise ValueError("ts0 and ts1 have different sizes: {} != {}".format(
len(ts0), len(ts1)))
@@ -176,13 +178,13 @@ def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None):
converted to a list of `tf.Operation`.
"""
a2b, b2a = _RerouteMode.check(mode)
- ts0 = util.make_list_of_t(ts0)
- ts1 = util.make_list_of_t(ts1)
+ ts0 = _util.make_list_of_t(ts0)
+ ts1 = _util.make_list_of_t(ts1)
_check_ts_compatibility(ts0, ts1)
if cannot_modify is not None:
- cannot_modify = frozenset(util.make_list_of_op(cannot_modify))
+ cannot_modify = frozenset(_util.make_list_of_op(cannot_modify))
if can_modify is not None:
- can_modify = frozenset(util.make_list_of_op(can_modify))
+ can_modify = frozenset(_util.make_list_of_op(can_modify))
nb_update_inputs = 0
precomputed_consumers = []
# precompute consumers to avoid issue with repeated tensors:
@@ -268,11 +270,11 @@ def _reroute_sgv_remap(sgv0, sgv1, mode):
ValueError: if sgv0 and sgv1 do not belong to the same graph.
"""
a2b, b2a = _RerouteMode.check(mode)
- if not isinstance(sgv0, subgraph.SubGraphView):
+ if not isinstance(sgv0, _subgraph.SubGraphView):
raise TypeError("Expected a SubGraphView, got {}".format(type(sgv0)))
- if not isinstance(sgv1, subgraph.SubGraphView):
+ if not isinstance(sgv1, _subgraph.SubGraphView):
raise TypeError("Expected a SubGraphView, got {}".format(type(sgv1)))
- util.check_graphs(sgv0, sgv1)
+ _util.check_graphs(sgv0, sgv1)
sgv0_ = sgv0.copy()
sgv1_ = sgv1.copy()
# pylint: disable=protected-access
@@ -327,13 +329,13 @@ def _reroute_sgv_inputs(sgv0, sgv1, mode):
StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
the same rules than the function subgraph.make_view.
"""
- sgv0 = subgraph.make_view(sgv0)
- sgv1 = subgraph.make_view(sgv1)
- util.check_graphs(sgv0, sgv1)
+ sgv0 = _subgraph.make_view(sgv0)
+ sgv1 = _subgraph.make_view(sgv1)
+ _util.check_graphs(sgv0, sgv1)
can_modify = sgv0.ops + sgv1.ops
# also allow consumers of passthrough to be modified:
- can_modify += util.get_consuming_ops(sgv0.passthroughs)
- can_modify += util.get_consuming_ops(sgv1.passthroughs)
+ can_modify += _util.get_consuming_ops(sgv0.passthroughs)
+ can_modify += _util.get_consuming_ops(sgv1.passthroughs)
_reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=can_modify)
_reroute_sgv_remap(sgv0, sgv1, mode)
return sgv0, sgv1
@@ -357,9 +359,9 @@ def _reroute_sgv_outputs(sgv0, sgv1, mode):
StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
the same rules than the function subgraph.make_view.
"""
- sgv0 = subgraph.make_view(sgv0)
- sgv1 = subgraph.make_view(sgv1)
- util.check_graphs(sgv0, sgv1)
+ sgv0 = _subgraph.make_view(sgv0)
+ sgv1 = _subgraph.make_view(sgv1)
+ _util.check_graphs(sgv0, sgv1)
cannot_modify = sgv0.ops + sgv1.ops
_reroute_ts(sgv0.outputs, sgv1.outputs, mode, cannot_modify=cannot_modify)
return sgv0, sgv1
@@ -432,9 +434,9 @@ def remove_control_inputs(op, cops):
TypeError: if op is not a `tf.Operation`.
ValueError: if any cop in cops is not a control input of op.
"""
- if not isinstance(op, tf_ops.Operation):
+ if not isinstance(op, _tf_ops.Operation):
raise TypeError("Expected a tf.Operation, got: {}", type(op))
- cops = util.make_list_of_op(cops, allow_graph=False)
+ cops = _util.make_list_of_op(cops, allow_graph=False)
for cop in cops:
if cop not in op.control_inputs:
raise ValueError("{} is not a control_input of {}".format(op.name,
@@ -457,9 +459,9 @@ def add_control_inputs(op, cops):
TypeError: if op is not a tf.Operation
ValueError: if any cop in cops is already a control input of op.
"""
- if not isinstance(op, tf_ops.Operation):
+ if not isinstance(op, _tf_ops.Operation):
raise TypeError("Expected a tf.Operation, got: {}", type(op))
- cops = util.make_list_of_op(cops, allow_graph=False)
+ cops = _util.make_list_of_op(cops, allow_graph=False)
for cop in cops:
if cop in op.control_inputs:
raise ValueError("{} is already a control_input of {}".format(op.name,
@@ -468,3 +470,5 @@ def add_control_inputs(op, cops):
op._control_inputs += cops
op._recompute_node_def()
# pylint: enable=protected-access
+
+remove_undocumented(__name__, _allowed_symbols)