aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/importer.py
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-12-14 09:20:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-14 09:24:24 -0800
commit2cdd76e65c37eda1cac7547ab3d32d8f19f5fef4 (patch)
tree3f0f82f3d62ef11364a2a92bfa81a185629d1f4e /tensorflow/python/framework/importer.py
parent934066da90cce263fba7f2049a455a070f0595e6 (diff)
Refactor Graph._create_op_from_tf_operation to not depend on the op's inputs
Previously, importer._ProcessNewOps may have call _create_op_from_tf_operation on a newly-imported op before _create_op_from_tf_operation has been called on all its inputs. This would fail since _create_op_from_tf_operation contained calls to Operation.inputs (some indirectly, e.g. through Operation.__init__). This change factors out the _create_op_from_tf_operation and Operation.__init__ logic requiring the inputs, and creates a new Graph._add_new_tf_operations method that creates all the Operation and then applies the factored-out logic. This also removes ImportGraphDefTest.TestCyclic and replaces it with a new test, testWhileLoop. The current Python implementation of import_graph_def allows any cycle to be imported. However, with the C API enabled, while loops are the only legal cycles. This test exposes this case since not all inputs can be available when creating the Operations forming the while loop cycle (it wasn't exposed before since the C++ ImportGraphDef function create nodes in topological order, although this isn't part of the API contract). PiperOrigin-RevId: 179052930
Diffstat (limited to 'tensorflow/python/framework/importer.py')
-rw-r--r--tensorflow/python/framework/importer.py8
1 files changed, 1 insertions, 7 deletions
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index d74fb25bb3..33c966ad88 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -278,8 +278,6 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options,
compat.as_str(name))
- # TODO(skyewm): control dependencies
-
def _ProcessNewOps(graph):
"""Processes the newly-added TF_Operations in `graph`."""
@@ -287,11 +285,7 @@ def _ProcessNewOps(graph):
# is specified in the attributes.
colocation_pairs = {}
- for c_op in c_api_util.new_tf_operations(graph):
- # pylint: disable=protected-access
- new_op = graph._create_op_from_tf_operation(c_op, compute_device=False)
- # pylint: enable=protected-access
-
+ for new_op in graph._add_new_tf_operations(compute_devices=False): # pylint: disable=protected-access
colocation_names = _GetColocationNames(new_op)
if colocation_names:
colocation_pairs[new_op] = colocation_names