path: root/tensorflow/python/ops/gradients_impl.py
author    Skye Wanderman-Milne <skyewm@google.com>    2018-06-29 14:18:09 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-29 14:20:16 -0700
commit    834da2c3fddab1bbbce742db572cfe65dd320fcd (patch)
tree      be313294dc4a90254a0b55c08da40bf3a81d8ff2 /tensorflow/python/ops/gradients_impl.py
parent    a4e79e23bace78e3d89d8273828f9d82ad6f1b95 (diff)
Allow gradients() calls from inside a function wrt captured tensors.
The overall approach is to teach the gradients code how to traverse the implicit edges between captured external tensors and ops inside the function body.
PiperOrigin-RevId: 202705929
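
For context, here is a minimal usage sketch (not part of this commit) of the pattern the change enables, written against the TF 1.x Defun API that builds the function._FuncGraphs this code traverses. The function name Square and the constant values are illustrative assumptions:

import tensorflow as tf
from tensorflow.python.framework import function

x = tf.constant(2.0)  # lives in the outer graph

@function.Defun(tf.float32)
def Square(y):
  # `x` is captured from the outer graph; inside the function body it is
  # represented by a placeholder recorded in the _FuncGraph's `_captured` map.
  # With this change, gradients() can follow that implicit capture edge back
  # to the external `x` instead of failing to reach it.
  return tf.gradients(y * y * x, x)[0]

with tf.Session() as sess:
  # d/dx (y * y * x) = y * y, so this should evaluate to 9.0 for y = 3.0.
  print(sess.run(Square(tf.constant(3.0))))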
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r--    tensorflow/python/ops/gradients_impl.py    129
1 file changed, 106 insertions(+), 23 deletions(-)
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 889a00190e..713a8ab2cc 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -31,6 +31,7 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -113,12 +114,14 @@ ops.register_tensor_conversion_function(ops.IndexedSlices,
_IndexedSlicesToTensor)
-def _MarkReachedOps(from_ops, reached_ops):
+def _MarkReachedOps(from_ops, reached_ops, func_graphs):
"""Mark all ops reached from "from_ops".
Args:
from_ops: list of Operations.
reached_ops: set of Operations.
+ func_graphs: list of function._FuncGraphs. This method will traverse through
+ these functions if they capture from_ops or any reachable ops.
"""
queue = collections.deque()
queue.extend(from_ops)
@@ -128,10 +131,11 @@ def _MarkReachedOps(from_ops, reached_ops):
reached_ops.add(op)
for output in op.outputs:
if _IsBackpropagatable(output):
- queue.extend(output.consumers())
+ queue.extend(_Consumers(output, func_graphs))
-def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
+def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
+ xs):
"""Initialize the pending count for ops between two lists of Operations.
'pending_count[op]' indicates the number of backprop inputs
@@ -141,6 +145,11 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
to_ops: list of Operations.
from_ops: list of Operations.
colocate_gradients_with_ops: Python bool. See docstring of gradients().
+ func_graphs: list of function._FuncGraphs. This method will traverse through
+ these functions if they capture from_ops or any reachable ops. This is
+ useful if to_ops occur in a function and from_ops are in an outer function
+ or graph.
+ xs: list of Tensors.
Returns:
A tuple containing: (1) the subset of to_ops reachable from from_ops by a
@@ -151,7 +160,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
"""
# Mark reachable ops from from_ops.
reached_ops = set()
- _MarkReachedOps(from_ops, reached_ops)
+ _MarkReachedOps(from_ops, reached_ops, func_graphs)
# X in reached_ops iff X is reachable from from_ops by a path of zero or more
# backpropagatable tensors.
@@ -170,7 +179,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
between_op_list.append(op)
# Clear the boolean so we won't add the inputs again.
reached_ops.remove(op)
- for inp in op.inputs:
+ for inp in _Inputs(op, xs):
queue.append(inp.op)
# X in between_ops iff X is on a path of zero or more backpropagatable tensors
# between from_ops and to_ops
@@ -182,7 +191,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
# Initialize pending count for between ops.
pending_count = collections.defaultdict(int)
for op in between_op_list:
- for x in op.inputs:
+ for x in _Inputs(op, xs):
if x.op in between_ops:
pending_count[x.op] += 1
@@ -303,7 +312,7 @@ def _VerifyGeneratedGradients(grads, op):
"inputs %d" % (len(grads), op.node_def, len(op.inputs)))
-def _StopOps(from_ops, stop_gradient_ops, pending_count):
+def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
"""The set of ops that terminate the gradient computation.
This computes the frontier of the forward graph *before* which backprop
@@ -319,6 +328,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
from_ops: list of Operations.
stop_gradient_ops: list of Operations never to backprop through.
pending_count: mapping from operation to number of backprop inputs.
+ xs: list of Tensors.
Returns:
The set of operations.
@@ -326,7 +336,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
stop_ops = set()
for op in from_ops:
is_stop_op = True
- for inp in op.inputs:
+ for inp in _Inputs(op, xs):
if pending_count[inp.op] > 0:
is_stop_op = False
break
@@ -346,10 +356,10 @@ def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pyli
yield
-def _SymGrad(op, out_grads):
+def _SymGrad(op, out_grads, xs):
"""Backprop through a function call node op given its outputs' gradients."""
- f_in = [x for x in op.inputs] + out_grads
- f_types = [x.dtype for x in op.inputs]
+ f_in = [x for x in _Inputs(op, xs)] + out_grads
+ f_types = [x.dtype for x in _Inputs(op, xs)]
f = attr_value_pb2.NameAttrList()
f.name = op.type
for k in op.node_def.attr:
@@ -399,7 +409,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
return grad_fn()
-def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
"""Raises an error if we backprop through a loop var."""
# Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
# message.
@@ -413,7 +423,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
if curr_op in from_ops:
target_op = curr_op
break
- queue.extend(t.op for t in curr_op.inputs)
+ queue.extend(t.op for t in _Inputs(curr_op, xs))
assert target_op
raise ValueError(
"Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -423,6 +433,68 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
% target_op.name)
+def _MaybeCaptured(t):
+ """If t is a captured value placeholder, returns the original captured value.
+
+ Args:
+ t: Tensor
+
+ Returns:
+ A tensor, potentially from a different Graph/function._FuncGraph.
+ """
+ # pylint: disable=protected-access
+ if isinstance(t.op.graph, function._FuncGraph) and t.op.type == "Placeholder":
+ for input_t, placeholder_t in t.op.graph._captured.items():
+ if t == placeholder_t:
+ return _MaybeCaptured(input_t)
+ # pylint: enable=protected-access
+ return t
+
+
+# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
+# _GradientsHelper a class with xs as a member variable.
+def _Inputs(op, xs):
+ """Returns the inputs of op, crossing closure boundaries where necessary.
+
+ Args:
+ op: Operation
+ xs: list of Tensors we are differentiating w.r.t.
+
+ Returns:
+ A list of tensors. The tensors may be from multiple
+ Graph/function._FuncGraphs if op is in a function._FuncGraph and has
+ captured inputs.
+ """
+ if isinstance(op.graph, function._FuncGraph): # pylint: disable=protected-access
+ # If we're differentiating w.r.t. `t`, do not attempt to traverse through it
+ # to a captured value. The algorithm needs to "see" `t` in this case, even
+ # if it's a function input for a captured value, whereas usually we'd like
+ # to traverse through these closures as if the captured value was the direct
+ # input to op.
+ return [t if (t in xs) else _MaybeCaptured(t) for t in op.inputs]
+ else:
+ return op.inputs
+
+
+def _Consumers(t, func_graphs):
+ """Returns the consumers of t, crossing closure boundaries where necessary.
+
+ Args:
+ t: Tensor
+ func_graphs: a list of function._FuncGraphs that may have captured t.
+
+ Returns:
+ A list of tensors. The tensors will be from the current graph and/or
+ func_graphs.
+ """
+ consumers = t.consumers()
+ for func in func_graphs:
+ for input_t, placeholder in func._captured.items(): # pylint: disable=protected-access
+ if input_t == t:
+ consumers.extend(_Consumers(placeholder, func_graphs))
+ return consumers
+
+
@tf_export("gradients")
def gradients(ys,
xs,
@@ -532,6 +604,14 @@ def _GradientsHelper(ys,
if src_graph is None:
src_graph = ops.get_default_graph()
+ # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
+ # ancestor graphs. This is necessary for correctly handling captured values.
+ func_graphs = []
+ curr_graph = src_graph
+ while isinstance(curr_graph, function._FuncGraph): # pylint: disable=protected-access
+ func_graphs.append(curr_graph)
+ curr_graph = curr_graph._outer_graph # pylint: disable=protected-access
+
ys = _AsList(ys)
xs = _AsList(xs)
stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
@@ -566,12 +646,13 @@ def _GradientsHelper(ys,
# Initialize the pending count for ops in the connected subgraph from ys
# to the xs.
if len(ys) > 1:
- ys = [array_ops.identity(y) if y.consumers() else y for y in ys]
+ ys = [array_ops.identity(y) if _Consumers(y, func_graphs) else y
+ for y in ys]
to_ops = [t.op for t in ys]
from_ops = [t.op for t in xs]
stop_gradient_ops = [t.op for t in stop_gradients]
reachable_to_ops, pending_count, loop_state = _PendingCount(
- to_ops, from_ops, colocate_gradients_with_ops)
+ to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
# Iterate over the collected ops.
#
@@ -605,7 +686,7 @@ def _GradientsHelper(ys,
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
queue.append(y.op)
- stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
+ stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
@@ -654,7 +735,7 @@ def _GradientsHelper(ys,
op._control_flow_context.IsWhileContext() and
op._control_flow_context ==
ops.get_default_graph()._get_control_flow_context()):
- _RaiseNoGradWrtInitialLoopValError(op, from_ops)
+ _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
# pylint: enable=protected-access
if (grad_fn or is_func_call) and has_out_grads:
@@ -686,7 +767,7 @@ def _GradientsHelper(ys,
# For function call ops, we add a 'SymbolicGradient'
# node to the graph to compute gradients.
in_grads = _MaybeCompile(grad_scope, op, func_call,
- lambda: _SymGrad(op, out_grads))
+ lambda: _SymGrad(op, out_grads, xs))
in_grads = _AsList(in_grads)
_VerifyGeneratedGradients(in_grads, op)
if gate_gradients and len([x for x in in_grads
@@ -701,8 +782,8 @@ def _GradientsHelper(ys,
else:
# If no grad_fn is defined or none of out_grads is available,
# just propagate a list of None backwards.
- in_grads = [None] * len(op.inputs)
- for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)):
+ in_grads = [None] * len(_Inputs(op, xs))
+ for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
if in_grad is not None:
if (isinstance(in_grad, ops.Tensor) and
t_in.dtype != dtypes.resource):
@@ -720,7 +801,8 @@ def _GradientsHelper(ys,
loop_state.ExitGradWhileContext(op, before=False)
# Update pending count for the inputs of op and enqueue ready ops.
- _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state)
+ _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
+ xs)
if loop_state:
loop_state.PostProcessing()
@@ -739,9 +821,10 @@ def _HasAnyNotNoneGrads(grads, op):
return False
-def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
+def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
+ xs):
"""Update pending count for the inputs of op and enqueue ready ops."""
- for x in op.inputs:
+ for x in _Inputs(op, xs):
pending_count[x.op] -= 1
ready = (pending_count[x.op] == 0)
if loop_state and not ready:
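
The new helpers above all hinge on the `_captured` mapping that each function._FuncGraph keeps from an external tensor to the placeholder standing in for it inside the function body. As a rough, standalone illustration of that traversal (toy classes, not TensorFlow internals), assuming only the behavior documented in the diff:

class ToyGraph(object):
  def __init__(self, outer=None):
    self.outer = outer    # enclosing graph, or None for the root graph
    self.captured = {}    # external tensor -> internal placeholder

class ToyTensor(object):
  def __init__(self, graph, name):
    self.graph = graph
    self.name = name
    self.consumer_ops = []  # names of ops consuming this tensor

def maybe_captured(t):
  # Analogue of _MaybeCaptured: map a capture placeholder back to the
  # original external tensor, recursing through nested function graphs.
  if t.graph.outer is not None:
    for external, placeholder in t.graph.captured.items():
      if placeholder is t:
        return maybe_captured(external)
  return t

def consumers(t, func_graphs):
  # Analogue of _Consumers: consumers of t, plus consumers of any placeholder
  # that captures t inside one of func_graphs.
  result = list(t.consumer_ops)
  for g in func_graphs:
    for external, placeholder in g.captured.items():
      if external is t:
        result.extend(consumers(placeholder, func_graphs))
  return result

root = ToyGraph()
fn = ToyGraph(outer=root)
x = ToyTensor(root, "x")             # tensor in the outer graph
x_ph = ToyTensor(fn, "x_captured")   # placeholder standing in for x inside fn
fn.captured[x] = x_ph
x_ph.consumer_ops.append("mul_inside_fn")

print(maybe_captured(x_ph).name)  # "x": crosses the closure boundary outward
print(consumers(x, [fn]))         # ["mul_inside_fn"]: crosses it inward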