author     Saurabh Saxena <srbs@google.com>                 2018-10-05 17:34:30 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-05 17:39:58 -0700
commit     213d76a6ed77a696883502c53a3a4f81d2ee4042 (patch)
tree       d701196115c416f23b6861621ce4df79eaee5262 /tensorflow/python
parent     4831740f90eaf266a99d3ffa7d390d54325b689f (diff)
Simplify the logic for bubbling captured tensors when building the cond_v2 grad.
The current logic tries to bubble the forward-pass tensor to the outermost
graph. That is not always possible, e.g. when the cond is inside a while loop
the bubbling would need to understand the while_loop's accumulator logic. So
instead, the cond grad now captures tensors from the forward If op's graph.
When the grad If op is built, these tensors are appropriately captured by the
surrounding FuncGraph.
PiperOrigin-RevId: 215993009
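
For intuition, here is a minimal, self-contained sketch of the capture-chaining
behavior this change relies on. The `Tensor` and `Graph` classes below are toy
stand-ins (assumed names, not TensorFlow's actual FuncGraph); the point is only
that each graph resolves captures against its immediate outer graph, so a
tensor placed one level out is re-captured outward automatically by any
enclosing function or control-flow graph.

class Tensor(object):
  """Toy tensor: just records the graph it lives in and a name."""

  def __init__(self, graph, name):
    self.graph = graph
    self.name = name


class Graph(object):
  """Toy FuncGraph-like container with one-level-out capture lists."""

  def __init__(self, outer_graph=None):
    self.outer_graph = outer_graph
    self.internal_captures = []  # Placeholders living inside this graph.
    self.external_captures = []  # The corresponding tensors, one level out.

  def capture(self, tensor):
    """Returns a placeholder in this graph standing in for `tensor`."""
    if tensor.graph is self:
      return tensor
    if tensor.graph is not self.outer_graph:
      # Chain the capture: bring `tensor` into the outer graph first. Each
      # graph only ever talks to its immediate outer graph.
      assert self.outer_graph is not None
      tensor = self.outer_graph.capture(tensor)
    if tensor in self.external_captures:
      return self.internal_captures[self.external_captures.index(tensor)]
    placeholder = Tensor(self, tensor.name + "_captured")
    self.internal_captures.append(placeholder)
    self.external_captures.append(tensor)
    return placeholder


# Usage: a tensor defined two graphs out is captured into the middle graph
# automatically when the innermost graph captures it.
outer = Graph()
while_body = Graph(outer_graph=outer)
cond_branch = Graph(outer_graph=while_body)
x = Tensor(outer, "x")
x_in_branch = cond_branch.capture(x)
assert cond_branch.external_captures[0].graph is while_body

This is why the new logic only needs to map a forward-pass tensor into
`cond_graph.outer_graph`: if the cond is nested inside a while_v2 body, the
while loop's own gradient machinery sees the capture at its level and handles
any accumulation there.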
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py |  6
-rw-r--r--  tensorflow/python/ops/cond_v2_impl.py                      | 48
2 files changed, 22 insertions, 32 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 7fae5249aa..baea5c0f6d 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -661,8 +661,7 @@ class ControlFlowTest(test.TestCase):
       sess.run(r)
 
   def testCondGrad_1(self):
-    graph = ops.Graph()
-    with graph.as_default():
+    with self.cached_session():
       x = constant_op.constant(10.0, name="x")
       pred = math_ops.less(1, 2)
       fn1 = lambda: array_ops.identity(x)
@@ -670,8 +669,7 @@ class ControlFlowTest(test.TestCase):
       r = control_flow_ops.cond(pred, fn1, fn2)
       grad = gradients_impl.gradients(r, [x])[0]
 
-    with self.cached_session():
-      self.assertAllEqual(1.0, grad.eval())
+      self.assertAllEqual(1.0, grad.eval())
 
   def testCondGrad_2(self):
     with self.cached_session():
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index 195ad11c71..c9aa4d4889 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -282,9 +282,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
       as is.
    2. Tensors in the forward pass graph. These tensors may not be "live"
       when the gradient is being computed. We replace such references by their
-      corresponding tensor in the least common ancestor graph of `grad_graph` and
-      `cond_graph`. Since we export intermediate tensors for all branch
-      functions, this is always possible.
+      corresponding tensor in `cond_graph.outer_graph`. In the case of nested
+      control flow or functions, the gradient logic handling
+      `grad_graph.outer_graph` will make sure the tensor from
+      `cond_graph.outer_graph` is also correctly captured.
 
   Args:
     cond_graph: function.FuncGraph. The forward-pass function.
@@ -296,24 +297,23 @@
   new_inputs = []
 
   for t in grad_graph.external_captures:
+    # `t` must either be in `grad_graph.outer_graph` or in the forward
+    # `cond_graph`.
     if t.graph != grad_graph.outer_graph:
-      # `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
-      # tensor to the least common ancestor of the `cond_graph` and
-      # `grad_graph` so that it is "in-scope" for `grad_graph`.
-      # TODO(srbs): `_is_ancestor` calls may be expensive. Compute the least
-      # common ancestor once and re-use.
-      assert _is_ancestor(cond_graph, t.graph)
-      while not _is_ancestor(grad_graph, t.graph):
-        assert isinstance(t.graph, _function.FuncGraph)
-        if t in t.graph.internal_captures:
-          # TODO(srbs): Consider building a map of internal_captures ->
-          # external_captures instead of searching for `t` twice.
-          t = t.graph.external_captures[t.graph.internal_captures.index(t)]
-        else:
-          # Note: All intermediate tensors are output by the If op.
-          # TODO(srbs): .index() calls may be expensive. Optimize.
-          t = t.graph._if.outputs[t.graph.outputs.index(t)]
-      assert _is_ancestor(grad_graph, t.graph)
+      assert t.graph == cond_graph
+      # `internal_captures` are not treated as intermediates and hence not added
+      # to If op outputs. So we get the outer tensor corresponding to those
+      # from the list of `external_captures`.
+      try:
+        t = t.graph._if.outputs[t.graph.outputs.index(t)]
+      except ValueError:
+        index = t.graph.internal_captures.index(t)
+        t = t.graph.external_captures[index]
+
+      # Note: We rely on the capturing logic of the gradient If op graph to
+      # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
+      # and while_v2 handle this while building their gradient functions.
+      assert t.graph == cond_graph.outer_graph
     new_inputs.append(t)
 
   return new_inputs
@@ -492,11 +492,3 @@ def _get_output_shapes(true_graph_outputs, false_graph_outputs):
       for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
   ]
   return output_shapes
-
-
-def _is_ancestor(graph, maybe_ancestor):
-  if maybe_ancestor == graph:
-    return True
-  if isinstance(graph, _function.FuncGraph):
-    return _is_ancestor(graph.outer_graph, maybe_ancestor)
-  return False
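
Distilled from the new code in the diff above, resolving one captured tensor
now reduces to a single-level mapping. Here is a sketch using the same assumed
attributes as the diff (`_if`, `outputs`, `internal_captures`,
`external_captures` on the forward FuncGraph); it is illustrative, not the
exact implementation:

def _resolve_one_capture(t, cond_graph, grad_graph):
  """Maps one grad_graph capture to a tensor usable by the grad If op."""
  if t.graph == grad_graph.outer_graph:
    # E.g. incoming output gradients; already in scope, kept as is.
    return t
  # Otherwise `t` must be a forward-pass tensor.
  assert t.graph == cond_graph
  try:
    # Intermediates are all exported as outputs of the forward If op, so the
    # outer tensor is the If op output at the matching index.
    return cond_graph._if.outputs[cond_graph.outputs.index(t)]
  except ValueError:
    # Captures are not exported as outputs; their outer-graph counterparts
    # sit at the same index in `external_captures`.
    index = cond_graph.internal_captures.index(t)
    return cond_graph.external_captures[index]

Either way the result lives in `cond_graph.outer_graph`, where the ordinary
FuncGraph capturing sketched earlier takes over when the gradient If op is
built.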