aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/ops/control_flow_ops.py28
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index faebdc3780..75d07454b3 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -2065,21 +2065,25 @@ def cond(pred,
# Build the graph for the true branch in a new context.
context_t = CondContext(pred, pivot_1, branch=1)
- context_t.Enter()
- orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
- if orig_res_t is None:
- raise ValueError("true_fn must have a return value.")
- context_t.ExitResult(res_t)
- context_t.Exit()
+ try:
+ context_t.Enter()
+ orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
+ if orig_res_t is None:
+ raise ValueError("true_fn must have a return value.")
+ context_t.ExitResult(res_t)
+ finally:
+ context_t.Exit()
# Build the graph for the false branch in a new context.
context_f = CondContext(pred, pivot_2, branch=0)
- context_f.Enter()
- orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
- if orig_res_f is None:
- raise ValueError("false_fn must have a return value.")
- context_f.ExitResult(res_f)
- context_f.Exit()
+ try:
+ context_f.Enter()
+ orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
+ if orig_res_f is None:
+ raise ValueError("false_fn must have a return value.")
+ context_f.ExitResult(res_f)
+ finally:
+ context_f.Exit()
if not strict:
orig_res_t = _UnpackIfSingleton(orig_res_t)