author     2017-10-17 09:43:46 -0700
committer  2017-10-17 09:47:39 -0700
commit     f8b3ced20f7063b3c8efb0e691f28bef845a05f6 (patch)
tree       dcac3a108af41986b988756e82d84ede8c8f90c7
parent     a86a589c8b329176bfbb64552405644cb641d99e (diff)
Reworks the imperative_grad interface.
PiperOrigin-RevId: 172477878
-rw-r--r--  tensorflow/python/eager/backprop.py         | 47
-rw-r--r--  tensorflow/python/eager/backprop_test.py    | 15
-rw-r--r--  tensorflow/python/eager/imperative_grad.py  | 19
3 files changed, 33 insertions(+), 48 deletions(-)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 1819fba4cb..61c905f31e 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -681,48 +681,17 @@ def _aggregate_grads(gradients):
     return ops.IndexedSlices(values, indices, dense_shape)
 
 
-# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
-# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
-# so as to release the gradient tensor to save memory.
-_MIN_AGGREGATE_COUNT = 4
-_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
-
-
-def _add_new_grads(gradients, gradients_size, tid, grad):
-  """Adds a new gradient and maybe aggregate the gradients.
-
-  Args:
-    gradients: A dict map from tensor id to list of gradients.
-    gradients_size: A dict map from tensor id to its total units. Might
-      not be initialized.
-    tid: Tensor id.
-    grad: New gradient for the `tid`, either a Tensor or IndexedSlices.
-
-  Raises:
-    ValueError: if `grad` is neight Tensor nor IndexedSlices.
-  """
-  tensor_grads = gradients[tid]
-  tensor_grads.append(grad)
-  if len(tensor_grads) < _MIN_AGGREGATE_COUNT:
-    return
-  elif tid not in gradients_size:
-    if isinstance(grad, ops.Tensor):
-      size = functools.reduce(operator.mul, grad._shape_tuple(), 1)  # pylint: disable=protected-access
-    elif isinstance(grad, ops.IndexedSlices):
-      size = functools.reduce(operator.mul, grad.values._shape_tuple(), 1)  # pylint: disable=protected-access
-    else:
-      raise ValueError("Unexpected gradient type: %s" % type(grad))
-    gradients_size[tid] = size
-  else:
-    size = gradients_size[tid]
-
-  # For simplicity, assume each element to be 4 bytes now.
-  if len(tensor_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
-    gradients[tid] = [_aggregate_grads(tensor_grads)]
+def _num_elements(grad):
+  """The number of elements in the `grad` tensor."""
+  if isinstance(grad, ops.Tensor):
+    return functools.reduce(operator.mul, grad._shape_tuple(), 1)  # pylint: disable=protected-access
+  if isinstance(grad, ops.IndexedSlices):
+    return functools.reduce(operator.mul, grad.values._shape_tuple(), 1)  # pylint: disable=protected-access
+  raise ValueError("`grad` not a Tensor or IndexedSlices.")
 
 
 _default_vspace = imperative_grad.VSpace(
-    add_new_grads_fn=_add_new_grads,
+    num_elements_fn=_num_elements,
     aggregate_fn=_aggregate_grads,
     tensor_id=ops.tensor_id,
     zeros=array_ops.zeros,
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 9083e3a712..2645d542c0 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -22,6 +22,7 @@ from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import custom_gradient
+from tensorflow.python.eager import imperative_grad
 from tensorflow.python.eager import tape
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
@@ -442,21 +443,21 @@ class BackpropTest(test.TestCase):
     # Reduce the aggregation limit, cause the backprop to do some
     # early aggregation.
     # pylint: disable=protected-access
-    old_cnt = backprop._MIN_AGGREGATE_COUNT
-    old_bytes = backprop._MIN_AGGREGATE_BYTES
-    backprop._MIN_AGGREGATE_COUNT = 10
-    backprop._MIN_AGGREGATE_BYTES = 1
+    old_cnt = imperative_grad._MIN_AGGREGATE_COUNT
+    old_bytes = imperative_grad._MIN_AGGREGATE_BYTES
+    imperative_grad._MIN_AGGREGATE_COUNT = 10
+    imperative_grad._MIN_AGGREGATE_BYTES = 1
     _ = backprop.implicit_grad(fn)()
     self.assertEqual(len(add_n), 6)
     del add_n[:]
 
     # Aggregation is also limited by the memory.
-    backprop._MIN_AGGREGATE_BYTES = 10000
+    imperative_grad._MIN_AGGREGATE_BYTES = 10000
     _ = backprop.implicit_grad(fn)()
     self.assertEqual(len(add_n), 2)
 
-    backprop._MIN_AGGREGATE_COUNT = old_cnt
-    backprop._MIN_AGGREGATE_BYTES = old_bytes
+    imperative_grad._MIN_AGGREGATE_COUNT = old_cnt
+    imperative_grad._MIN_AGGREGATE_BYTES = old_bytes
     # pylint: enable=protected-access
 
     context.context().clear_post_execution_callbacks()
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index d30d124040..ce58e661d7 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -120,7 +120,14 @@ def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts):
 
 VSpace = collections.namedtuple(
     "VSpace",
-    ["add_new_grads_fn", "aggregate_fn", "tensor_id", "zeros", "ones_like"])
+    ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones_like"])
+
+
+# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
+# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
+# so as to release the gradient tensor to save memory.
+_MIN_AGGREGATE_COUNT = 4
+_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
 
 
 def imperative_grad(
@@ -193,7 +200,15 @@ def imperative_grad(
       in_gradients = op_trace.backward_function(*(out_gradients))
       for i, t in enumerate(op_trace.input_ids):
         if in_gradients[i] is not None:
-          vspace.add_new_grads_fn(gradients, gradients_size, t, in_gradients[i])
+          t_grads = gradients.setdefault(t, [])
+          t_grads.append(in_gradients[i])
+          if len(t_grads) >= _MIN_AGGREGATE_COUNT:
+            if t not in gradients_size:
+              gradients_size[t] = vspace.num_elements_fn(t_grads[-1])
+            size = gradients_size[t]
+
+            if len(t_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
+              t_grads[:] = [vspace.aggregate_fn(t_grads)]
         if tensor_usage_counts.get(t, 0) > 0:
           tensor_usage_counts[t] -= 1
           if (t in tensor_to_op
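
The early-aggregation heuristic itself is unchanged by this patch; it just moves from backprop.py into imperative_grad.py. Gradients for a tensor id are buffered in a list, and once at least _MIN_AGGREGATE_COUNT of them have accumulated and their estimated footprint (element count times an assumed 4 bytes each) exceeds _MIN_AGGREGATE_BYTES, the buffer is collapsed into a single aggregated gradient so the individual tensors can be freed. Below is a minimal standalone sketch of that heuristic; the add_new_grad helper, the threshold parameters, and the plain-Python stand-ins for vspace.aggregate_fn and vspace.num_elements_fn are illustrative only and not part of the patch.

# Sketch only: add_new_grad is a hypothetical helper mirroring the logic
# inlined into imperative_grad() above; the real code operates on
# tf.Tensor / tf.IndexedSlices values supplied through the VSpace.
def add_new_grad(gradients, gradients_size, tid, grad,
                 aggregate_fn, num_elements_fn,
                 min_count=4, min_bytes=128 * 1024 * 1024):
  """Buffers `grad` under `tid`, aggregating early to bound memory use."""
  t_grads = gradients.setdefault(tid, [])
  t_grads.append(grad)
  if len(t_grads) >= min_count:
    # Cache the per-gradient element count the first time the limit is hit.
    if tid not in gradients_size:
      gradients_size[tid] = num_elements_fn(t_grads[-1])
    # Estimate memory as count * elements * 4 bytes, as the patch does.
    if len(t_grads) * gradients_size[tid] * 4 > min_bytes:
      t_grads[:] = [aggregate_fn(t_grads)]


# Toy usage: lists of floats stand in for gradient tensors, and min_bytes
# is lowered to force aggregation, much as the updated unit test does by
# temporarily shrinking imperative_grad._MIN_AGGREGATE_BYTES.
gradients, gradients_size = {}, {}
for _ in range(8):
  add_new_grad(gradients, gradients_size, tid=0, grad=[1.0, 2.0],
               aggregate_fn=lambda gs: [sum(col) for col in zip(*gs)],
               num_elements_fn=len, min_bytes=1)
assert len(gradients[0]) == 2  # aggregated at the 4th and 7th append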