Diffstat (limited to 'tensorflow/python/ops/cond_v2_impl.py')
-rw-r--r--  tensorflow/python/ops/cond_v2_impl.py  139
1 file changed, 94 insertions(+), 45 deletions(-)
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index d310f83dca..5cd0cb34de 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -135,6 +135,10 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
def _IfGrad(op, *grads): # pylint: disable=invalid-name
"""The gradient of an If op produced by cond_v2."""
true_graph, false_graph = _get_func_graphs(op)
+ # Note: op.graph != ops.get_default_graph() when we are computing the gradient
+ # of a nested cond.
+ assert true_graph._outer_graph == op.graph
+ assert false_graph._outer_graph == op.graph
# Create grad functions that compute the gradient of the true/false forward
# graphs. These functions will capture tensors from the forward pass
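
For context, a minimal nested-cond sketch (hypothetical usage, assuming TF 1.x graph mode with the cond_v2 code path enabled for tf.cond) that exercises the case the new asserts document: the gradient of the inner If op is built while op.graph is the outer branch's _FuncGraph rather than ops.get_default_graph().

import tensorflow as tf

x = tf.constant(2.0)
pred_outer = tf.placeholder(tf.bool)
pred_inner = tf.placeholder(tf.bool)

# Nested cond: the inner If op is created inside the outer true branch.
out = tf.cond(
    pred_outer,
    lambda: tf.cond(pred_inner, lambda: x * x, lambda: 3.0 * x),
    lambda: x)

# Differentiating through both conds invokes _IfGrad for each If op; for the
# inner one, op.graph is the outer branch's function graph.
grads = tf.gradients(out, x)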
@@ -147,15 +151,16 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
assert ([t.dtype for t in true_grad_graph.outputs] ==
[t.dtype for t in false_grad_graph.outputs])
- # Match up the captured grad function inputs with outputs of 'op' and other
- # external tensors.
- true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph)
- false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph)
+ # Resolve references to forward graph tensors in grad graphs and ensure
+ # they are in-scope, i.e., belong to one of the outer graphs of the grad graph.
+ true_grad_extra_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
+ false_grad_extra_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
# Make the inputs to true_grad_graph and false_grad_graph match. Note that
# this modifies true_grad_graph and false_grad_graph.
grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
- true_grad_inputs, false_grad_inputs)
+ true_grad_extra_inputs,
+ false_grad_extra_inputs)
# Add all intermediate tensors as function outputs so they're available for
# higher-order gradient computations.
@@ -199,11 +204,20 @@ def _get_func_graphs(if_op):
input_shapes = [t.shape for t in extra_inputs]
func_name = if_op.get_attr(branch_name).name
fdef = if_op.graph._get_function(func_name).definition
- func_graph = _function_def_to_graph.function_def_to_graph(
- fdef, input_shapes)
+ # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
+ # in the case of nested if ops or when the gradient is being computed
+ # from inside a Defun. We build the `func_graph` with `if_op.graph` as its
+ # `outer_graph`. This resembles how the `_FuncGraph` was built in the
+ # forward pass. We need this so that we can resolve references to tensors
+ # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
+ with if_op.graph.as_default():
+ func_graph = _function_def_to_graph.function_def_to_graph(
+ fdef, input_shapes)
func_graph.extra_inputs = extra_inputs
func_graph.extra_args = func_graph.inputs
func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))
+ # Set the if op so that the gradient code can use it.
+ func_graph._if = if_op
return func_graph
return (_get_func_graph_for_branch("then_branch"),
@@ -240,7 +254,7 @@ def _grad_fn(func_graph, grads):
# Build the gradient graph. Note that this builds the gradient computation of
# func_graph in the current graph, which requires capturing tensors from
# func_graph. The captured func_graph tensors are resolved to external tensors
- # in _get_grad_inputs.
+ # in _resolve_grad_inputs.
result = _gradients_impl._GradientsHelper(
ys, func_graph.inputs, grad_ys=grad_ys,
src_graph=func_graph)
@@ -261,43 +275,49 @@ def _create_grad_func(func_graph, grads, name):
[], [], name)
-def _get_grad_inputs(if_op, cond_graph, grad_graph):
- """Returns the tensors we should pass to grad_graph.
+def _resolve_grad_inputs(cond_graph, grad_graph):
+ """Returns the tensors to pass as `extra_inputs` to `grad_graph`.
- This method handles tensors captured from cond_graph in grad_graph. It
- converts these to suitable input tensors from the outer graph.
+ The `grad_graph` may have external references to:
+ 1. Tensors in its outer graph, e.g. the incoming gradients. These references
+ are kept as is.
+ 2. Tensors in the forward pass graph. These tensors may not be "live" when
+ the gradient is being computed. We replace such references with their
+ corresponding tensor in the least common ancestor graph of `grad_graph` and
+ `cond_graph`. Since we export the intermediate tensors of all branch
+ functions, this is always possible.
Args:
- if_op: Operation. The forward-pass If op that uses cond_graph.
cond_graph: function._FuncGraph. The forward-pass function.
grad_graph: function._FuncGraph. The gradients function.
Returns:
A list of input tensors to be passed to grad_graph.
"""
- inputs = []
-
- # Maps placeholders in cond_graph -> input tensor in outer graph.
- forward_input_map = {v: k for k, v in cond_graph._captured.items()}
+ new_extra_inputs = []
for t in grad_graph.extra_inputs:
- if t.graph == ops.get_default_graph():
- # t is in the outer graph (e.g. one of the input gradients).
- inputs.append(t)
- elif t in forward_input_map:
- # t is an input placeholder in cond_graph. Get the corresponding input
- # tensor in the outer graph.
- assert t.graph == cond_graph
- assert forward_input_map[t].graph == ops.get_default_graph()
- inputs.append(forward_input_map[t])
- else:
- # t is an intermediate value in cond_graph. Get the corresponding output
- # of 'if_op' (note that all intermediate values are outputs).
- assert t.graph == cond_graph
- output_idx = cond_graph.outputs.index(t)
- inputs.append(if_op.outputs[output_idx])
-
- return inputs
+ if t.graph != grad_graph._outer_graph:
+ # `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
+ # tensor to the least common ancestor of the `cond_graph` and
+ # `grad_graph` so that it is "in-scope" for `grad_graph`.
+ # TODO(srbs): `_is_ancestor` calls may be expensive. Compute the least
+ # common ancestor once and re-use.
+ assert _is_ancestor(cond_graph, t.graph)
+ while not _is_ancestor(grad_graph, t.graph):
+ assert isinstance(t.graph, _function._FuncGraph)
+ if t in t.graph.extra_args:
+ # TODO(srbs): Consider building a map of extra_args -> extra_inputs
+ # instead of searching for `t` twice.
+ t = t.graph.extra_inputs[t.graph.extra_args.index(t)]
+ else:
+ # Note: All intermediate tensors are output by the If op.
+ # TODO(srbs): .index() calls may be expensive. Optimize.
+ t = t.graph._if.outputs[t.graph.outputs.index(t)]
+ assert _is_ancestor(grad_graph, t.graph)
+ new_extra_inputs.append(t)
+
+ return new_extra_inputs
def _create_new_tf_function(func_graph):
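
To make the resolution loop above concrete, here is a small self-contained sketch using toy stand-ins (ToyGraph, resolve, and the field names below are illustrative only, not TF's _FuncGraph API). A tensor captured from a deeper graph is repeatedly swapped for its counterpart one scope further out — the matching extra_input for a placeholder argument, or the matching If op output for an intermediate — until its graph is in scope for the gradient graph.

class ToyGraph(object):
  """Toy stand-in for a nested function graph (not TF's _FuncGraph)."""

  def __init__(self, outer=None):
    self.outer = outer            # mirrors _outer_graph
    self.extra_args = []          # placeholders inside this graph
    self.extra_inputs = []        # matching tensors in the outer graph
    self.outputs = []             # tensors exported through the If op
    self.if_outputs = []          # the If op's outputs, living in `outer`


def resolve(t, t_graph, in_scope_graphs):
  """Bubbles `t` outward until `t_graph` is visible from the grad graph."""
  while t_graph not in in_scope_graphs:
    if t in t_graph.extra_args:
      # Case 1: `t` is an input placeholder; use the captured outer tensor.
      t = t_graph.extra_inputs[t_graph.extra_args.index(t)]
    else:
      # Case 2: `t` is an intermediate; every intermediate is an If op output.
      t = t_graph.if_outputs[t_graph.outputs.index(t)]
    t_graph = t_graph.outer
  return t


# Two-level nesting: outer graph -> cond graph. The grad graph also lives in
# the outer graph, so anything in `outer` is in scope for it.
outer = ToyGraph()
cond = ToyGraph(outer=outer)
cond.outputs = ["intermediate_in_cond"]
cond.if_outputs = ["if_output_in_outer"]
assert resolve("intermediate_in_cond", cond, [outer]) == "if_output_in_outer"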
@@ -326,7 +346,8 @@ def _create_new_tf_function(func_graph):
# a new TF_Function that we add to the graph.
fdef = _function.function_def_from_tf_function(c_func)
defined_func = _function._from_definition(fdef)
- defined_func.add_to_graph(ops.get_default_graph())
+ defined_func._sub_functions = func_graph._functions
+ defined_func.add_to_graph(func_graph._outer_graph)
return func_graph.name
@@ -389,7 +410,8 @@ def _pad_params(true_graph, false_graph, true_params, false_params):
return new_true_params, new_false_inputs
-def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
+def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
+ false_extra_inputs):
"""Modifies true_graph and false_graph so they have the same input signature.
This method reorders and/or adds parameters to true_graph and false_graph so
@@ -400,9 +422,9 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
Args:
true_graph: function._FuncGraph
false_graph: function._FuncGraph
- true_inputs: a list of Tensors in the outer graph. The inputs for
+ true_extra_inputs: a list of Tensors in the outer graph. The inputs for
true_graph.
- false_inputs: a list of Tensors in the outer graph. The inputs for
+ false_extra_inputs: a list of Tensors in the outer graph. The inputs for
false_graph.
Returns:
@@ -411,12 +433,12 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
false_inputs.
"""
shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs(
- true_inputs, false_inputs)
+ true_extra_inputs, false_extra_inputs)
new_inputs = shared_inputs + true_only_inputs + false_only_inputs
- true_input_to_param = dict(zip(true_inputs, true_graph.inputs))
- false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
+ true_input_to_param = dict(zip(true_extra_inputs, true_graph.inputs))
+ false_input_to_param = dict(zip(false_extra_inputs, false_graph.inputs))
true_graph.inputs = (
[true_input_to_param[t] for t in shared_inputs] +
@@ -432,6 +454,9 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
true_graph.extra_inputs = new_inputs
false_graph.extra_inputs = new_inputs
+ true_graph.extra_args = true_graph.inputs
+ false_graph.extra_args = false_graph.inputs
+
true_graph._captured = dict(zip(new_inputs, true_graph.inputs))
false_graph._captured = dict(zip(new_inputs, false_graph.inputs))
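
The merge order used above (shared inputs first, then true-only, then false-only) can be illustrated with a tiny pure-Python sketch; separate_unique_inputs below is a hypothetical simplification of the module's _separate_unique_inputs helper, shown only to make the resulting input layout concrete.

def separate_unique_inputs(true_inputs, false_inputs):
  """Splits the two capture lists into shared / true-only / false-only."""
  true_set, false_set = set(true_inputs), set(false_inputs)
  shared = [t for t in true_inputs if t in false_set]
  true_only = [t for t in true_inputs if t not in false_set]
  false_only = [t for t in false_inputs if t not in true_set]
  return shared, true_only, false_only


shared, true_only, false_only = separate_unique_inputs(["a", "b"], ["b", "c"])
# Both branch functions end up with the same input signature: ["b", "a", "c"].
assert shared + true_only + false_only == ["b", "a", "c"]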
@@ -454,14 +479,30 @@ def _create_dummy_params(func_graph, template_tensors):
def _get_grad_fn_name(func_graph):
- """Returns a unique name to use for the grad function of `func_graph`."""
+ """Returns a unique name to use for the grad function of `func_graph`.
+
+ Ensures this name is unique in the entire hierarchy.
+
+ Args:
+ func_graph: The _FuncGraph.
+
+ Returns:
+ A string, the name to use for the gradient function.
+ """
name = "%s_grad" % func_graph.name
base_name = name
counter = 1
- if ops.get_default_graph()._is_function(name):
- name = "%s_%s" % (base_name, counter)
- counter += 1
+ has_conflict = True
+ while has_conflict:
+ curr_graph = func_graph._outer_graph
+ has_conflict = curr_graph._is_function(name)
+ while not has_conflict and isinstance(curr_graph, _function._FuncGraph):
+ curr_graph = curr_graph._outer_graph
+ has_conflict = curr_graph._is_function(name)
+ if has_conflict:
+ name = "%s_%s" % (base_name, counter)
+ counter += 1
return name
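
The rename-until-unique loop above amounts to the following standalone sketch (unique_grad_name and the set-based graph stand-ins are hypothetical, only mirroring the _is_function checks along the _outer_graph chain).

def unique_grad_name(base_name, function_names_per_graph):
  """`function_names_per_graph`: existing function names, innermost graph first."""
  name, counter = base_name, 1
  while any(name in names for names in function_names_per_graph):
    name = "%s_%s" % (base_name, counter)
    counter += 1
  return name


# "cond_grad" already exists somewhere in the hierarchy, so a suffix is added.
assert unique_grad_name("cond_grad", [{"cond_grad"}, set()]) == "cond_grad_1"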
@@ -477,3 +518,11 @@ def _check_same_outputs(true_graph, false_graph):
"arguments, got:\n"
" true_fn: %s\n"
" false_fn: %s" % (true_output_types, false_output_types))
+
+
+def _is_ancestor(graph, maybe_ancestor):
+ if maybe_ancestor == graph:
+ return True
+ if isinstance(graph, _function._FuncGraph):
+ return _is_ancestor(graph._outer_graph, maybe_ancestor)
+ return False