diff options
author | 2018-01-03 16:08:49 -0800 | |
---|---|---|
committer | 2018-01-03 16:14:33 -0800 | |
commit | 9ad5d74b3bc869260b6c65e6152b266a4e392123 (patch) | |
tree | dd15305b652e67dc83c30113513cbe54d029d401 | |
parent | 9f816c90b621b59286a3b39faf213384d0563401 (diff) |
Fix bug with imported while loops with C API enabled.
Specifically, make control_flow_util.GetOutputContext robust to
imported exit nodes (which don't have control flow contexts). This was
a bug prior to the C API being enabled in that imported exit nodes
would not have contexts, but it happened to not be exposed. Note that
importing a metagraph will add the contexts back after doing the
initial import, but there's still a window where no contexts are
assigned.
PiperOrigin-RevId: 180730785
-rw-r--r-- | tensorflow/python/framework/importer_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_util.py | 5 |
2 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index c57b7d47b8..38ed539a4b 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -342,6 +342,8 @@ class ImportGraphDefTest(test.TestCase): graph = ops.Graph() with graph.as_default(): r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0]) + # Add an op that consumes the while loop output. + math_ops.add(r, 1) graph_def = graph.as_graph_def() # Import the GraphDef and make sure it runs. diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py index bead7c05f6..247c9f7299 100644 --- a/tensorflow/python/ops/control_flow_util.py +++ b/tensorflow/python/ops/control_flow_util.py @@ -88,7 +88,10 @@ def GetLoopConstantEnter(value): def GetOutputContext(op): """Return the control flow context for the output of an op.""" ctxt = op._get_control_flow_context() # pylint: disable=protected-access - if IsLoopExit(op): + # Exit nodes usually have a control flow context, except in the case where the + # exit node was imported via import_graph_def (in which case no nodes have + # control flow contexts). + if ctxt is not None and IsLoopExit(op): ctxt = ctxt.outer_context return ctxt |