Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/cond_v2_impl.py | 48
1 file changed, 20 insertions(+), 28 deletions(-)
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index 195ad11c71..c9aa4d4889 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -282,9 +282,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
      as is.
   2. Tensors in the forward pass graph. These tensors may not be "live"
      when the gradient is being computed. We replace such references by their
-     corresponding tensor in the least common ancestor graph of `grad_graph` and
-     `cond_graph`. Since we export intermediate tensors for all branch
-     functions, this is always possible.
+     corresponding tensor in `cond_graph.outer_graph`. In the case of nested
+     control flow or functions, the gradient logic handling
+     `grad_graph.outer_graph` will make sure the tensor from
+     `cond_graph.outer_graph` is also correctly captured.
 
   Args:
     cond_graph: function.FuncGraph. The forward-pass function.
@@ -296,24 +297,23 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
   new_inputs = []
 
   for t in grad_graph.external_captures:
+    # `t` must either be in `grad_graph.outer_graph` or in the forward
+    # `cond_graph`.
     if t.graph != grad_graph.outer_graph:
-      # `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
-      # tensor to the least common ancestor of the `cond_graph` and
-      # `grad_graph` so that it is "in-scope" for `grad_graph`.
-      # TODO(srbs): `_is_ancestor` calls may be expensive. Compute the least
-      # common ancestor once and re-use.
-      assert _is_ancestor(cond_graph, t.graph)
-      while not _is_ancestor(grad_graph, t.graph):
-        assert isinstance(t.graph, _function.FuncGraph)
-        if t in t.graph.internal_captures:
-          # TODO(srbs): Consider building a map of internal_captures ->
-          # external_captures instead of searching for `t` twice.
-          t = t.graph.external_captures[t.graph.internal_captures.index(t)]
-        else:
-          # Note: All intermediate tensors are output by the If op.
-          # TODO(srbs): .index() calls may be expensive. Optimize.
-          t = t.graph._if.outputs[t.graph.outputs.index(t)]
-      assert _is_ancestor(grad_graph, t.graph)
+      assert t.graph == cond_graph
+      # `internal_captures` are not treated as intermediates and hence not added
+      # to If op outputs. So we get the outer tensor corresponding to those
+      # from the list of `external_captures`.
+      try:
+        t = t.graph._if.outputs[t.graph.outputs.index(t)]
+      except ValueError:
+        index = t.graph.internal_captures.index(t)
+        t = t.graph.external_captures[index]
+
+      # Note: We rely on the capturing logic of the gradient If op graph to
+      # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
+      # and while_v2 handle this while building their gradient functions.
+      assert t.graph == cond_graph.outer_graph
     new_inputs.append(t)
 
   return new_inputs
@@ -492,11 +492,3 @@ def _get_output_shapes(true_graph_outputs, false_graph_outputs):
       for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
   ]
   return output_shapes
-
-
-def _is_ancestor(graph, maybe_ancestor):
-  if maybe_ancestor == graph:
-    return True
-  if isinstance(graph, _function.FuncGraph):
-    return _is_ancestor(graph.outer_graph, maybe_ancestor)
-  return False
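
For readers following the change, below is a minimal standalone sketch of the capture-resolution rule the new code implements. The Graph, Op, and Tensor classes are hypothetical stand-ins for TensorFlow's FuncGraph, the forward If op, and its tensors (not the real API); only the control flow of resolve_grad_input mirrors the patched loop.

    # Hypothetical stand-ins for illustration only -- not TensorFlow's
    # real FuncGraph / If op / Tensor objects.
    class Tensor:
        def __init__(self, graph, name):
            self.graph, self.name = graph, name

        def __repr__(self):
            return self.name


    class Op:
        def __init__(self, outputs):
            self.outputs = outputs


    class Graph:
        """Records the outputs and captures that the resolution logic inspects."""

        def __init__(self, outer_graph=None):
            self.outer_graph = outer_graph
            self.outputs = []            # intermediates exported as If op outputs
            self.internal_captures = []  # inner placeholders for captured tensors
            self.external_captures = []  # the matching tensors in outer_graph
            self._if = None              # the forward If op built from this graph


    def resolve_grad_input(t, cond_graph, grad_graph):
        """Maps one external capture of grad_graph into cond_graph.outer_graph."""
        if t.graph != grad_graph.outer_graph:
            assert t.graph == cond_graph
            try:
                # Intermediates are exported as outputs of the forward If op.
                t = t.graph._if.outputs[t.graph.outputs.index(t)]
            except ValueError:
                # Internal captures are not exported as intermediates; fall back
                # to the outer tensor that was captured in the first place.
                index = t.graph.internal_captures.index(t)
                t = t.graph.external_captures[index]
        assert t.graph == cond_graph.outer_graph
        return t


    # Wire up a tiny example: an outer graph, a forward cond graph that captures
    # one outer tensor and exports one intermediate, and a gradient graph.
    outer = Graph()
    cond = Graph(outer_graph=outer)
    grad = Graph(outer_graph=outer)

    x_outer = Tensor(outer, "x")              # lives in the outer graph
    x_inner = Tensor(cond, "x_captured")      # its placeholder inside cond
    intermediate = Tensor(cond, "mul:0")      # produced inside cond

    cond.internal_captures = [x_inner]
    cond.external_captures = [x_outer]
    cond.outputs = [intermediate]
    cond._if = Op(outputs=[Tensor(outer, "If:0")])  # If op output in outer graph

    print(resolve_grad_input(intermediate, cond, grad))  # -> If:0
    print(resolve_grad_input(x_inner, cond, grad))       # -> x

The try/except ordering in the sketch matches the patch: the common case (intermediates exported as If op outputs) is attempted first, and the ValueError from .index() routes internal captures to the external-capture fallback.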