aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-01-03 16:08:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-03 16:14:33 -0800
commit9ad5d74b3bc869260b6c65e6152b266a4e392123 (patch)
treedd15305b652e67dc83c30113513cbe54d029d401
parent9f816c90b621b59286a3b39faf213384d0563401 (diff)
Fix bug with imported while loops with C API enabled.
Specifically, make control_flow_util.GetOutputContext robust to imported exit nodes (which don't have control flow contexts). This was a bug prior to the C API being enabled in that imported exit nodes would not have contexts, but it happened to not be exposed. Note that importing a metagraph will add the contexts back after doing the initial import, but there's still a window where no contexts are assigned. PiperOrigin-RevId: 180730785
-rw-r--r--tensorflow/python/framework/importer_test.py2
-rw-r--r--tensorflow/python/ops/control_flow_util.py5
2 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index c57b7d47b8..38ed539a4b 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -342,6 +342,8 @@ class ImportGraphDefTest(test.TestCase):
graph = ops.Graph()
with graph.as_default():
r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0])
+ # Add an op that consumes the while loop output.
+ math_ops.add(r, 1)
graph_def = graph.as_graph_def()
# Import the GraphDef and make sure it runs.
diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py
index bead7c05f6..247c9f7299 100644
--- a/tensorflow/python/ops/control_flow_util.py
+++ b/tensorflow/python/ops/control_flow_util.py
@@ -88,7 +88,10 @@ def GetLoopConstantEnter(value):
def GetOutputContext(op):
"""Return the control flow context for the output of an op."""
ctxt = op._get_control_flow_context() # pylint: disable=protected-access
- if IsLoopExit(op):
+ # Exit nodes usually have a control flow context, except in the case where the
+ # exit node was imported via import_graph_def (in which case no nodes have
+ # control flow contexts).
+ if ctxt is not None and IsLoopExit(op):
ctxt = ctxt.outer_context
return ctxt