author     Alexandre Passos <apassos@google.com>           2017-10-17 09:43:46 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2017-10-17 09:47:39 -0700
commit     f8b3ced20f7063b3c8efb0e691f28bef845a05f6 (patch)
tree       dcac3a108af41986b988756e82d84ede8c8f90c7
parent     a86a589c8b329176bfbb64552405644cb641d99e (diff)
Reworks the imperative_grad interface.
PiperOrigin-RevId: 172477878
-rw-r--r--  tensorflow/python/eager/backprop.py         | 47
-rw-r--r--  tensorflow/python/eager/backprop_test.py    | 15
-rw-r--r--  tensorflow/python/eager/imperative_grad.py  | 19
3 files changed, 33 insertions(+), 48 deletions(-)
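The heart of the change: the early-aggregation policy (accumulate per-tensor gradients in a list, then collapse the list with aggregate_fn once a rough byte estimate exceeds a threshold) moves out of backprop.py's _add_new_grads and into the imperative_grad loop itself, so the VSpace only has to provide a num_elements_fn. Below is a minimal standalone sketch of that policy with plain Python values standing in for gradient tensors; accumulate, num_elements and aggregate are illustrative names, not part of the TensorFlow API.

_MIN_AGGREGATE_COUNT = 4
_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024


def accumulate(gradients, gradients_size, tid, grad, num_elements, aggregate):
  """Appends grad to gradients[tid]; aggregates early once the estimate is big."""
  t_grads = gradients.setdefault(tid, [])
  t_grads.append(grad)
  if len(t_grads) < _MIN_AGGREGATE_COUNT:
    return
  if tid not in gradients_size:
    gradients_size[tid] = num_elements(t_grads[-1])
  size = gradients_size[tid]
  # As in the original code, assume 4 bytes per element.
  if len(t_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
    t_grads[:] = [aggregate(t_grads)]


grads, sizes = {}, {}
for g in range(10):
  accumulate(grads, sizes, tid=0, grad=float(g),
             num_elements=lambda _: 64 * 1024 * 1024,  # pretend 64M elements
             aggregate=sum)
print(grads[0])  # a short list of partial sums, not ten separate gradients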
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 1819fba4cb..61c905f31e 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -681,48 +681,17 @@ def _aggregate_grads(gradients):
return ops.IndexedSlices(values, indices, dense_shape)
-# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
-# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
-# so as to release the gradient tensor to save memory.
-_MIN_AGGREGATE_COUNT = 4
-_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
-
-
-def _add_new_grads(gradients, gradients_size, tid, grad):
- """Adds a new gradient and maybe aggregate the gradients.
-
- Args:
- gradients: A dict map from tensor id to list of gradients.
- gradients_size: A dict map from tensor id to its total units. Might
- not be initialized.
- tid: Tensor id.
- grad: New gradient for the `tid`, either a Tensor or IndexedSlices.
-
- Raises:
- ValueError: if `grad` is neight Tensor nor IndexedSlices.
- """
- tensor_grads = gradients[tid]
- tensor_grads.append(grad)
- if len(tensor_grads) < _MIN_AGGREGATE_COUNT:
- return
- elif tid not in gradients_size:
- if isinstance(grad, ops.Tensor):
- size = functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
- elif isinstance(grad, ops.IndexedSlices):
- size = functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
- else:
- raise ValueError("Unexpected gradient type: %s" % type(grad))
- gradients_size[tid] = size
- else:
- size = gradients_size[tid]
-
- # For simplicity, assume each element to be 4 bytes now.
- if len(tensor_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
- gradients[tid] = [_aggregate_grads(tensor_grads)]
+def _num_elements(grad):
+ """The number of elements in the `grad` tensor."""
+ if isinstance(grad, ops.Tensor):
+ return functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
+ if isinstance(grad, ops.IndexedSlices):
+ return functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
+ raise ValueError("`grad` not a Tensor or IndexedSlices.")
_default_vspace = imperative_grad.VSpace(
- add_new_grads_fn=_add_new_grads,
+ num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
tensor_id=ops.tensor_id,
zeros=array_ops.zeros,
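The new _num_elements helper above is just a product over the tensor's static shape. A tiny framework-free sketch of the same idiom, with plain tuples standing in for _shape_tuple() (illustrative only, not the TensorFlow API):

import functools
import operator


def num_elements_from_shape(shape):
  # Product of all dimensions; 1 for a scalar (empty shape tuple).
  return functools.reduce(operator.mul, shape, 1)


assert num_elements_from_shape(()) == 1         # scalar
assert num_elements_from_shape((3, 4)) == 12    # 3x4 matrix
assert num_elements_from_shape((2, 0, 5)) == 0  # empty tensor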
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 9083e3a712..2645d542c0 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -22,6 +22,7 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
+from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
@@ -442,21 +443,21 @@ class BackpropTest(test.TestCase):
# Reduce the aggregation limit, cause the backprop to do some
# early aggregation.
# pylint: disable=protected-access
- old_cnt = backprop._MIN_AGGREGATE_COUNT
- old_bytes = backprop._MIN_AGGREGATE_BYTES
- backprop._MIN_AGGREGATE_COUNT = 10
- backprop._MIN_AGGREGATE_BYTES = 1
+ old_cnt = imperative_grad._MIN_AGGREGATE_COUNT
+ old_bytes = imperative_grad._MIN_AGGREGATE_BYTES
+ imperative_grad._MIN_AGGREGATE_COUNT = 10
+ imperative_grad._MIN_AGGREGATE_BYTES = 1
_ = backprop.implicit_grad(fn)()
self.assertEqual(len(add_n), 6)
del add_n[:]
# Aggregation is also limited by the memory.
- backprop._MIN_AGGREGATE_BYTES = 10000
+ imperative_grad._MIN_AGGREGATE_BYTES = 10000
_ = backprop.implicit_grad(fn)()
self.assertEqual(len(add_n), 2)
- backprop._MIN_AGGREGATE_COUNT = old_cnt
- backprop._MIN_AGGREGATE_BYTES = old_bytes
+ imperative_grad._MIN_AGGREGATE_COUNT = old_cnt
+ imperative_grad._MIN_AGGREGATE_BYTES = old_bytes
# pylint: enable=protected-access
context.context().clear_post_execution_callbacks()
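The test above now reaches into imperative_grad to override the module-level thresholds and restores them by hand. Purely as an illustration (not how the test is written), the same pattern could be wrapped in a context manager so the constants are restored even if the body raises; this only assumes the _MIN_AGGREGATE_* names introduced in this change.

import contextlib

from tensorflow.python.eager import imperative_grad


@contextlib.contextmanager
def aggregate_limits(count, num_bytes):
  # Temporarily override the early-aggregation thresholds, then restore them.
  # pylint: disable=protected-access
  old_count = imperative_grad._MIN_AGGREGATE_COUNT
  old_bytes = imperative_grad._MIN_AGGREGATE_BYTES
  imperative_grad._MIN_AGGREGATE_COUNT = count
  imperative_grad._MIN_AGGREGATE_BYTES = num_bytes
  try:
    yield
  finally:
    imperative_grad._MIN_AGGREGATE_COUNT = old_count
    imperative_grad._MIN_AGGREGATE_BYTES = old_bytes
  # pylint: enable=protected-access

# Hypothetical usage inside a test:
#   with aggregate_limits(10, 1):
#     backprop.implicit_grad(fn)()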
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index d30d124040..ce58e661d7 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -120,7 +120,14 @@ def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts):
VSpace = collections.namedtuple(
"VSpace",
- ["add_new_grads_fn", "aggregate_fn", "tensor_id", "zeros", "ones_like"])
+ ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones_like"])
+
+
+# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
+# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
+# so as to release the gradient tensor to save memory.
+_MIN_AGGREGATE_COUNT = 4
+_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
def imperative_grad(
@@ -193,7 +200,15 @@ def imperative_grad(
in_gradients = op_trace.backward_function(*(out_gradients))
for i, t in enumerate(op_trace.input_ids):
if in_gradients[i] is not None:
- vspace.add_new_grads_fn(gradients, gradients_size, t, in_gradients[i])
+ t_grads = gradients.setdefault(t, [])
+ t_grads.append(in_gradients[i])
+ if len(t_grads) >= _MIN_AGGREGATE_COUNT:
+ if t not in gradients_size:
+ gradients_size[t] = vspace.num_elements_fn(t_grads[-1])
+ size = gradients_size[t]
+
+ if len(t_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
+ t_grads[:] = [vspace.aggregate_fn(t_grads)]
if tensor_usage_counts.get(t, 0) > 0:
tensor_usage_counts[t] -= 1
if (t in tensor_to_op