Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
 tensorflow/python/eager/backprop_test.py | 57 ++++++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 44 insertions(+), 13 deletions(-)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index ec9a185b73..ed54b8e12e 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -24,11 +24,11 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
+from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -41,6 +41,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import training
+from tensorflow.python.util import compat


class BackpropTest(test.TestCase):
@@ -102,18 +103,6 @@ class BackpropTest(test.TestCase):
    grad_fn = backprop.gradients_function(f)
    self.assertAllEqual(2., grad_fn(1., dy=2.)[0])

-  def testErrors(self):
-
-    @custom_gradient.custom_gradient
-    def f(x):
-      def grad(_):
-        raise RuntimeError('x')
-      return x, grad
-
-    # TODO(apassos) raise the right error here
-    with self.assertRaises(errors_impl.InternalError):
-      backprop.gradients_function(f)(constant_op.constant(1.0))
-
  def testImplicitGradOverEmbeddingLookup(self):
    batch_size = 8
    embedding_size = 512
@@ -494,6 +483,53 @@
        initial_value=1., name='testSameObjectForMultipleArguments.Variable')
    self.assertAllEqual([1., 1.], np_g(v, v))

+  def testEarlyGradAggregation(self):
+    # Needs to be a list so mutations by the callback affect this function.
+    add_n = []
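+    # Post-execution callbacks run after every eagerly executed op, so
+    # counting AddN executions exposes when gradients get aggregated.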
+    def callback(op_type, unused_1, unused_2, unused_3, unused_4):
+      if compat.as_bytes(op_type) == compat.as_bytes('AddN'):
+        add_n.append(1)
+    context.context().add_post_execution_callback(callback)
+
+    v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0),
+                                               name='v')
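+    # Every loop iteration below multiplies v by a constant, so the backward
+    # pass must aggregate 20 gradient contributions for v.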
+    def fn():
+      outputs = []
+      for _ in range(20):
+        outputs.append(v * constant_op.constant(2.0))
+      return math_ops.add_n(outputs)
+
+    # With the default limits, only two AddN ops run.
+    _ = backprop.implicit_grad(fn)()[0][1]
+    self.assertEqual(len(add_n), 2)
+    del add_n[:]
+
+    # Reduce the aggregation limits to force the backprop code to do some
+    # early aggregation.
+    # pylint: disable=protected-access
+    old_cnt = imperative_grad._MIN_AGGREGATE_COUNT
+    old_bytes = imperative_grad._MIN_AGGREGATE_BYTES
+    imperative_grad._MIN_AGGREGATE_COUNT = 10
+    imperative_grad._MIN_AGGREGATE_BYTES = 1
+    _ = backprop.implicit_grad(fn)()
+    self.assertEqual(len(add_n), 6)
+    del add_n[:]
+
+    # Aggregation is also limited by the memory threshold.
+    imperative_grad._MIN_AGGREGATE_BYTES = 10000
+    _ = backprop.implicit_grad(fn)()
+    self.assertEqual(len(add_n), 2)
+
+    imperative_grad._MIN_AGGREGATE_COUNT = old_cnt
+    imperative_grad._MIN_AGGREGATE_BYTES = old_bytes
+    # pylint: enable=protected-access
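+    # Remove the callback so it does not leak into other tests.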
+    context.context().clear_post_execution_callbacks()
+
  def testImplicitGradientsCustomGradientAndCachedVariableValue(self):
    @custom_gradient.custom_gradient