diff options
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index faebdc3780..75d07454b3 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -2065,21 +2065,25 @@ def cond(pred, # Build the graph for the true branch in a new context. context_t = CondContext(pred, pivot_1, branch=1) - context_t.Enter() - orig_res_t, res_t = context_t.BuildCondBranch(true_fn) - if orig_res_t is None: - raise ValueError("true_fn must have a return value.") - context_t.ExitResult(res_t) - context_t.Exit() + try: + context_t.Enter() + orig_res_t, res_t = context_t.BuildCondBranch(true_fn) + if orig_res_t is None: + raise ValueError("true_fn must have a return value.") + context_t.ExitResult(res_t) + finally: + context_t.Exit() # Build the graph for the false branch in a new context. context_f = CondContext(pred, pivot_2, branch=0) - context_f.Enter() - orig_res_f, res_f = context_f.BuildCondBranch(false_fn) - if orig_res_f is None: - raise ValueError("false_fn must have a return value.") - context_f.ExitResult(res_f) - context_f.Exit() + try: + context_f.Enter() + orig_res_f, res_f = context_f.BuildCondBranch(false_fn) + if orig_res_f is None: + raise ValueError("false_fn must have a return value.") + context_f.ExitResult(res_f) + finally: + context_f.Exit() if not strict: orig_res_t = _UnpackIfSingleton(orig_res_t) |