Diffstat (limited to 'tensorflow/python/eager/imperative_grad.py')
-rw-r--r--  tensorflow/python/eager/imperative_grad.py | 196
1 file changed, 190 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 837cad974a..c87719f84a 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -20,13 +20,114 @@ from __future__ import print_function
import collections
-from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.framework import errors
+from tensorflow.python.eager import tape as tape_module
+
+
+# Terminology:
+#
+# - op: a possibly composite operation, which has an entry in the tape
+# - target: dy in dy/dx
+# - source: dx in dy/dx
+# - tensor: one of the many inputs or outputs of an operation
+#
+# The gradient algorithm implemented below works as follows:
+#
+# First we filter the tape to just the subset of operations we want to
+# differentiate. In the process of doing so we count how many times each Tensor
+# is used as an input to an op (so we know when we're done computing gradients
+# for that Tensor). We also count, for each tape entry, how many of its output
+# Tensors need gradients to be computed (Tensors which are not used do not need
+# any gradients to be computed).
+#
+# Next, we seed a backprop stack with the set of tape entries for which all
+# output gradients are already available. These are usually the entries that
+# produce targets, though not every such entry qualifies: a target that is
+# itself consumed by another op on the tape has no gradient available
+# initially.
+#
+# Then we repeatedly pop an entry from the stack, run its backprop, and
+# accumulate the gradients of its inputs. Once all gradients for an input have
+# been accumulated we mark it as done, which may in turn make the entry that
+# produced it ready (all of its outputs now done) and push that entry onto the
+# stack.
+#
+# When the stack is empty we have gradients for all tensors we're interested in.
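As a toy illustration of the dependency counting described in the comment above (plain-Python stand-ins, not part of this change):

import collections

# Each op maps to (input tensor ids, output tensor ids); w is the target.
ops = {
    "mul": (["x", "y"], ["z"]),   # z = x * y
    "add": (["z", "x"], ["w"]),   # w = z + x
}
produced_by = {"z": "mul", "w": "add"}

# How many ops consume each tensor as an input.
usage = collections.Counter(t for ins, _ in ops.values() for t in ins)

# How many of each op's outputs still need a gradient before its backprop runs.
missing = collections.defaultdict(int)
for t in usage:
    if t in produced_by:
        missing[produced_by[t]] += 1

ready = [op for op in ops if missing[op] == 0]
print(ready)  # ['add']: nothing consumes w, so "add" backprops first, while
              # "mul" waits until the gradient flowing into z is complete.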
+def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources):
+ """Filters the tape to only include relevant entries and counts tensor usages.
+
+ Args:
+ vspace: information about the space we're differentiating in.
+ target: the tensors whose gradients are being computed.
+ tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
+ op_to_entry: Map from op id to a tape.TapeEntry object.
+ id_sources: the ids of the sources with respect to which the gradient
+ is being taken.
+
+ Returns:
+ tensor_usage_counts: map from tensor id to how many entries in the
+ filtered tape use that tensor as an input.
+ op_to_entry_map: entry map (a filtered tape, with only the relevant
+ entries).
+ op_missing_tensor_counts: map from op id to how many of its output
+ gradients still need to be computed before its backprop can run.
+ """
+ tensor_stack = [vspace.tensor_id(x) for x in target]
+ tensor_usage_counts = {}
+ o_to_e = {} # Copy of just the bits we need from op_to_entry
+ while tensor_stack:
+ t = tensor_stack.pop()
+ op = tensor_to_op.get(t, None)
+ # op is None or -1 if the tensor is a source (i.e. was watched directly)
+ if op is None or op == -1 or op in o_to_e:
+ continue
+ op_trace = tape_module.TapeEntry(*op_to_entry[op])
+ o_to_e[op] = op_trace
+ for it in op_trace.input_ids:
+ if it in tensor_usage_counts:
+ tensor_usage_counts[it] += 1
+ else:
+ tensor_usage_counts[it] = 1
+ if it not in id_sources and it in tensor_to_op:
+ tensor_stack.append(it)
+ op_missing_tensor_counts = collections.defaultdict(int)
+ for t in tensor_usage_counts:
+ if t in tensor_to_op and tensor_to_op[t] is not None:
+ op_missing_tensor_counts[tensor_to_op[t]] += 1
+ return tensor_usage_counts, o_to_e, op_missing_tensor_counts
+
+
+def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
+ """Returns the set of tape entries which are available for backprop."""
+ ready_ops = []
+ for op in op_to_entry:
+ if op not in op_missing_tensor:
+ ready_ops.append(op)
+ return ready_ops
+
+
+def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts):
+ """Computes the initial gradients for each Tensor."""
+ # Seed the gradients of the target tensors.
+ gradients = collections.defaultdict(list)
+ for i, t in enumerate(target):
+ if vspace.tensor_id(t) in tensor_usage_counts:
+ # This target is consumed by another op on the tape, so its full
+ # gradient is not yet known; the caller cannot supply one for it.
+ assert output_gradients is None or output_gradients[i] is None
+ else:
+ if output_gradients is None or output_gradients[i] is None:
+ out_grad = vspace.ones_like(t)
+ else:
+ out_grad = output_gradients[i]
+ gradients[vspace.tensor_id(t)].append(out_grad)
+ return gradients
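A small illustration of the seeding rule above; FakeVSpace and the toy targets are hypothetical stand-ins, exercising only the two fields _initial_gradients touches:

import collections

FakeVSpace = collections.namedtuple("FakeVSpace", ["tensor_id", "ones_like"])
fake = FakeVSpace(tensor_id=id, ones_like=lambda t: 1.0)

a, b = object(), object()   # two target "tensors"
usage = {id(a): 1}          # a is consumed by another op on the tape

grads = _initial_gradients(fake, [a, b], None, usage)
print(grads[id(b)])         # [1.0]: b is a pure output, so it gets a ones seed
print(id(a) in grads)       # False: a's gradient must flow back through the tape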
VSpace = collections.namedtuple(
"VSpace",
- ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"])
+ ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones_like"])
+
+
+# If more than _MIN_AGGREGATE_COUNT gradients are pending for a tensor and
+# their total memory consumption exceeds _MIN_AGGREGATE_BYTES, aggregate them
+# eagerly so the constituent gradient tensors can be released to save memory.
+_MIN_AGGREGATE_COUNT = 4
+_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
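For example, assuming float32 gradients (the factor of 4 bytes per element used in the loop below), five pending gradients of a 10000x1000 tensor cross both thresholds:

count, elems = 5, 10000 * 1000
assert count >= _MIN_AGGREGATE_COUNT             # 5 >= 4
assert count * elems * 4 > _MIN_AGGREGATE_BYTES  # 200 MB > 128 MiB
# These five partial gradients would therefore be summed into one tensor early.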
@@ -60,6 +161,89 @@ def imperative_grad(
or if only non-differentiable functions of the source were used in the
computation of target.
"""
- with errors.raise_exception_on_not_ok_status() as status:
- return pywrap_tensorflow.TFE_Py_TapeGradient(
- tape._tape, vspace, target, sources, output_gradients, status) # pylint: disable=protected-access
+ tensor_to_op, op_to_entry = tape.export()
+ # The _prepare_backprop call below rebinds op_to_entry; the filtered copy
+ # drops references to trace entries that are irrelevant to this gradient
+ # computation, allowing their memory to be released.
+ id_sources = [vspace.tensor_id(t) for t in sources]
+ tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
+ vspace, target, tensor_to_op, op_to_entry, id_sources)
+ ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
+ gradients = _initial_gradients(vspace, target, output_gradients,
+ tensor_usage_counts)
+ gradients_size = dict()
+ # Now exhaust the backprop stack
+ while ready_ops:
+ op = ready_ops.pop()
+ op_trace = op_to_entry.pop(op)
+ out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
+
+ # Cache the last used zero tensor. We reuse it if the next one
+ # we need is of the same shape and dtype. This is very helpful in
+ # large splits and should have negligible overhead in other cases.
+ last_shape_and_dtype = None
+ last_zeros = None
+ for i in range(len(out_gradients)):
+ if out_gradients[i] is None:
+ # TODO(apassos) this should be on the right device
+ none_indices = _grad_fn_accepts_none_for_indices.get(
+ op_trace.op_type, None)
+ if none_indices is None or i not in none_indices:
+ shape_and_dtype = op_trace.output_shape_and_dtype[i]
+ if shape_and_dtype != last_shape_and_dtype:
+ last_shape_and_dtype = shape_and_dtype
+ last_zeros = vspace.zeros(*shape_and_dtype)
+ out_gradients[i] = last_zeros
+ else:
+ out_gradients[i] = vspace.aggregate_fn(out_gradients[i])
+
+ in_gradients = op_trace.backward_function(*(out_gradients))
+ for i, t in enumerate(op_trace.input_ids):
+ if in_gradients[i] is not None:
+ t_grads = gradients.setdefault(t, [])
+ t_grads.append(in_gradients[i])
+ if len(t_grads) >= _MIN_AGGREGATE_COUNT:
+ if t not in gradients_size:
+ gradients_size[t] = vspace.num_elements_fn(t_grads[-1])
+ size = gradients_size[t]
+
+ if len(t_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
+ t_grads[:] = [vspace.aggregate_fn(t_grads)]
+ if tensor_usage_counts.get(t, 0) > 0:
+ tensor_usage_counts[t] -= 1
+ if (t in tensor_to_op
+ and tensor_usage_counts[t] == 0
+ and t not in id_sources):
+ in_op = tensor_to_op[t]
+ if in_op is None or in_op == -1:
+ continue
+ if op_missing_tensor.get(in_op, 0) > 0:
+ op_missing_tensor[in_op] -= 1
+ if op_missing_tensor.get(in_op, 0) == 0:
+ ready_ops.append(in_op)
+ result = []
+ for i, s in enumerate(sources):
+ g = gradients.get(vspace.tensor_id(s), None)
+ if g is None:
+ result.append(None)
+ else:
+ result.append(vspace.aggregate_fn(g))
+ return result
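For orientation, a hedged sketch of the VSpace a caller might supply; the lambdas are illustrative stand-ins, not the actual definitions used by the eager runtime:

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

_example_vspace = VSpace(
    aggregate_fn=math_ops.add_n,                       # sums a list of partial grads
    num_elements_fn=lambda t: t.shape.num_elements(),  # for the memory heuristic
    tensor_id=lambda t: id(t),                         # stable id per tensor
    zeros=array_ops.zeros,                             # called as zeros(shape, dtype)
    ones_like=array_ops.ones_like,                     # seeds target gradients
)

# With a hypothetical tape that recorded loss = f(w):
#   grads = imperative_grad(_example_vspace, some_tape, [loss], [w])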
+
+
+# TODO(agarwal): use an automatic mechanism for handling None arguments to
+# gradient functions.
+# Some gradient functions can accept None arguments for gradients. The following
+# maps the operation name to the indices at which the corresponding gradient
+# function can accept None values.
+# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
+# during backprop. However, the gradient function uses only the first of those
+# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
+# indicates that only the gradient corresponding to index 0 is used, and the
+# gradient values at indices 1-4 are ignored (and hence can be None). The
+# backprop algorithm can then leverage this by not constructing zeros to
+# pass for those indices.
+_grad_fn_accepts_none_for_indices = {
+ "SoftmaxCrossEntropyWithLogits": [1],
+ "FusedBatchNorm": [1, 2, 3, 4]
+}
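To make the table concrete, here is roughly how the loop above consults it, with stand-in gradient values:

none_ok = _grad_fn_accepts_none_for_indices.get("FusedBatchNorm")
out_gradients = ["dy", None, None, None, None]  # stand-ins for incoming grads
zeros_built = 0
for i, g in enumerate(out_gradients):
    if g is None and (none_ok is None or i not in none_ok):
        zeros_built += 1  # a real zeros tensor would be materialized here
print(zeros_built)  # 0: indices 1-4 may stay None for FusedBatchNorm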