diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-12-14 09:20:51 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-14 09:24:24 -0800 |
commit | 2cdd76e65c37eda1cac7547ab3d32d8f19f5fef4 (patch) | |
tree | 3f0f82f3d62ef11364a2a92bfa81a185629d1f4e /tensorflow/python/framework/importer.py | |
parent | 934066da90cce263fba7f2049a455a070f0595e6 (diff) |
Refactor Graph._create_op_from_tf_operation to not depend on the op's inputs
Previously, importer._ProcessNewOps may have call
_create_op_from_tf_operation on a newly-imported op before
_create_op_from_tf_operation has been called on all its inputs. This
would fail since _create_op_from_tf_operation contained calls to
Operation.inputs (some indirectly, e.g. through Operation.__init__).
This change factors out the _create_op_from_tf_operation and
Operation.__init__ logic requiring the inputs, and creates a new
Graph._add_new_tf_operations method that creates all the Operation and
then applies the factored-out logic.
This also removes ImportGraphDefTest.TestCyclic and replaces it with a
new test, testWhileLoop. The current Python implementation of
import_graph_def allows any cycle to be imported. However, with the C
API enabled, while loops are the only legal cycles. This test exposes
this case since not all inputs can be available when creating the
Operations forming the while loop cycle (it wasn't exposed before
since the C++ ImportGraphDef function create nodes in topological
order, although this isn't part of the API contract).
PiperOrigin-RevId: 179052930
Diffstat (limited to 'tensorflow/python/framework/importer.py')
-rw-r--r-- | tensorflow/python/framework/importer.py | 8 |
1 files changed, 1 insertions, 7 deletions
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index d74fb25bb3..33c966ad88 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -278,8 +278,6 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map, c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, compat.as_str(name)) - # TODO(skyewm): control dependencies - def _ProcessNewOps(graph): """Processes the newly-added TF_Operations in `graph`.""" @@ -287,11 +285,7 @@ def _ProcessNewOps(graph): # is specified in the attributes. colocation_pairs = {} - for c_op in c_api_util.new_tf_operations(graph): - # pylint: disable=protected-access - new_op = graph._create_op_from_tf_operation(c_op, compute_device=False) - # pylint: enable=protected-access - + for new_op in graph._add_new_tf_operations(compute_devices=False): # pylint: disable=protected-access colocation_names = _GetColocationNames(new_op) if colocation_names: colocation_pairs[new_op] = colocation_names |