diff options
-rw-r--r-- | tensorflow/python/framework/importer_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_util.py | 5 |
2 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index c57b7d47b8..38ed539a4b 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -342,6 +342,8 @@ class ImportGraphDefTest(test.TestCase): graph = ops.Graph() with graph.as_default(): r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0]) + # Add an op that consumes the while loop output. + math_ops.add(r, 1) graph_def = graph.as_graph_def() # Import the GraphDef and make sure it runs. diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py index bead7c05f6..247c9f7299 100644 --- a/tensorflow/python/ops/control_flow_util.py +++ b/tensorflow/python/ops/control_flow_util.py @@ -88,7 +88,10 @@ def GetLoopConstantEnter(value): def GetOutputContext(op): """Return the control flow context for the output of an op.""" ctxt = op._get_control_flow_context() # pylint: disable=protected-access - if IsLoopExit(op): + # Exit nodes usually have a control flow context, except in the case where the + # exit node was imported via import_graph_def (in which case no nodes have + # control flow contexts). + if ctxt is not None and IsLoopExit(op): ctxt = ctxt.outer_context return ctxt |