diff options
author | 2018-01-16 16:17:36 -0800 | |
---|---|---|
committer | 2018-01-16 16:23:05 -0800 | |
commit | c9096fd166a9d7fdb62c6cb747a74edb73630b0c (patch) | |
tree | 4c9cd12946750b03b2dc850916fe2e16db3d955e /tensorflow/python/ops/control_flow_util.py | |
parent | 1de8ca3edb22c232b6cd4a87076bd5e0a7f6b86f (diff) |
[TF] Fix XLA Control Flow gradient stacks max_size creation.
Stack creation uses tf.while_loop's maximum_iterations iff the while_loop
was created inside an XLA/TPU context. Added several error checks to ensure
this provides useful error messages if the limited use case is not supported.
PiperOrigin-RevId: 182128135
Diffstat (limited to 'tensorflow/python/ops/control_flow_util.py')
-rw-r--r-- | tensorflow/python/ops/control_flow_util.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py index 247c9f7299..eee31102db 100644 --- a/tensorflow/python/ops/control_flow_util.py +++ b/tensorflow/python/ops/control_flow_util.py @@ -96,7 +96,7 @@ def GetOutputContext(op): return ctxt -def GetContainingWhileContext(ctxt): +def GetContainingWhileContext(ctxt, stop_ctxt=None): """Returns the first ancestor WhileContext of `ctxt`. Returns `ctxt` if `ctxt` is a WhileContext, or None if `ctxt` is not in a @@ -104,13 +104,16 @@ def GetContainingWhileContext(ctxt): Args: ctxt: ControlFlowContext + stop_ctxt: ControlFlowContext, optional. If provided, the search will end + if it sees stop_ctxt. Returns: `ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing - `ctxt`, or None if `ctxt` is not in a while loop. + `ctxt`, or None if `ctxt` is not in a while loop. If `stop_ctxt` is not + `None`, this returns `ctxt` if it matches `stop_ctxt` in its traversal. """ while ctxt: - if ctxt.IsWhileContext(): return ctxt + if ctxt.IsWhileContext() or ctxt == stop_ctxt: return ctxt ctxt = ctxt.outer_context return None |