diff options
author | 2018-08-10 18:26:22 -0700 | |
---|---|---|
committer | 2018-08-10 18:26:22 -0700 | |
commit | 1432f3433cb1c10ce3570f62a67f12a1f569ac3d (patch) | |
tree | 911dc526afc0ecca7cd172230e29ba11408a4d69 | |
parent | 3e2e407c49aa14b525ff6c37538ea4506a152798 (diff) | |
parent | 54fee8b09ed3a3b6f87dcd76dc6fb7c388e6482f (diff) |
Merge pull request #20479 from naurril:bug-fix
PiperOrigin-RevId: 208300149
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index faebdc3780..75d07454b3 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -2065,21 +2065,25 @@ def cond(pred, # Build the graph for the true branch in a new context. context_t = CondContext(pred, pivot_1, branch=1) - context_t.Enter() - orig_res_t, res_t = context_t.BuildCondBranch(true_fn) - if orig_res_t is None: - raise ValueError("true_fn must have a return value.") - context_t.ExitResult(res_t) - context_t.Exit() + try: + context_t.Enter() + orig_res_t, res_t = context_t.BuildCondBranch(true_fn) + if orig_res_t is None: + raise ValueError("true_fn must have a return value.") + context_t.ExitResult(res_t) + finally: + context_t.Exit() # Build the graph for the false branch in a new context. context_f = CondContext(pred, pivot_2, branch=0) - context_f.Enter() - orig_res_f, res_f = context_f.BuildCondBranch(false_fn) - if orig_res_f is None: - raise ValueError("false_fn must have a return value.") - context_f.ExitResult(res_f) - context_f.Exit() + try: + context_f.Enter() + orig_res_f, res_f = context_f.BuildCondBranch(false_fn) + if orig_res_f is None: + raise ValueError("false_fn must have a return value.") + context_f.ExitResult(res_f) + finally: + context_f.Exit() if not strict: orig_res_t = _UnpackIfSingleton(orig_res_t) |