aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/control_flow_util.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-01-16 16:17:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-16 16:23:05 -0800
commitc9096fd166a9d7fdb62c6cb747a74edb73630b0c (patch)
tree4c9cd12946750b03b2dc850916fe2e16db3d955e /tensorflow/python/ops/control_flow_util.py
parent1de8ca3edb22c232b6cd4a87076bd5e0a7f6b86f (diff)
[TF] Fix XLA Control Flow gradient stacks max_size creation.
Stack creation uses tf.while_loop's maximum_iterations iff the while_loop was created inside an XLA/TPU context. Added several error checks to ensure this provides useful error messages if the limited use case is not supported. PiperOrigin-RevId: 182128135
Diffstat (limited to 'tensorflow/python/ops/control_flow_util.py')
-rw-r--r--tensorflow/python/ops/control_flow_util.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py
index 247c9f7299..eee31102db 100644
--- a/tensorflow/python/ops/control_flow_util.py
+++ b/tensorflow/python/ops/control_flow_util.py
@@ -96,7 +96,7 @@ def GetOutputContext(op):
return ctxt
-def GetContainingWhileContext(ctxt):
+def GetContainingWhileContext(ctxt, stop_ctxt=None):
"""Returns the first ancestor WhileContext of `ctxt`.
Returns `ctxt` if `ctxt` is a WhileContext, or None if `ctxt` is not in a
@@ -104,13 +104,16 @@ def GetContainingWhileContext(ctxt):
Args:
ctxt: ControlFlowContext
+ stop_ctxt: ControlFlowContext, optional. If provided, the search will end
+ if it sees stop_ctxt.
Returns:
`ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing
- `ctxt`, or None if `ctxt` is not in a while loop.
+ `ctxt`, or None if `ctxt` is not in a while loop. If `stop_ctxt` is not
+ `None`, this returns `ctxt` if it matches `stop_ctxt` in its traversal.
"""
while ctxt:
- if ctxt.IsWhileContext(): return ctxt
+ if ctxt.IsWhileContext() or ctxt == stop_ctxt: return ctxt
ctxt = ctxt.outer_context
return None