diff options
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 57 |
1 files changed, 44 insertions, 13 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index ec9a185b73..ed54b8e12e 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -24,11 +24,11 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import custom_gradient +from tensorflow.python.eager import imperative_grad from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -41,6 +41,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.training import training +from tensorflow.python.util import compat class BackpropTest(test.TestCase): @@ -102,18 +103,6 @@ class BackpropTest(test.TestCase): grad_fn = backprop.gradients_function(f) self.assertAllEqual(2., grad_fn(1., dy=2.)[0]) - def testErrors(self): - - @custom_gradient.custom_gradient - def f(x): - def grad(_): - raise RuntimeError('x') - return x, grad - - # TODO(apassos) raise the right error here - with self.assertRaises(errors_impl.InternalError): - backprop.gradients_function(f)(constant_op.constant(1.0)) - def testImplicitGradOverEmbeddingLookup(self): batch_size = 8 embedding_size = 512 @@ -494,6 +483,48 @@ class BackpropTest(test.TestCase): initial_value=1., name='testSameObjectForMultipleArguments.Variable') self.assertAllEqual([1., 1.], np_g(v, v)) + def testEarlyGradAggregation(self): + # Needs to be a list so mutations by the callback affect this function. + add_n = [] + def callback(op_type, unused_1, unused_2, unused_3, unused_4): + if compat.as_bytes(op_type) == compat.as_bytes('AddN'): + add_n.append(1) + context.context().add_post_execution_callback(callback) + + v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0), + name='v') + def fn(): + outputs = [] + for _ in range(20): + outputs.append(v * constant_op.constant(2.0)) + return math_ops.add_n(outputs) + + # By default the aggregation count is 2. + _ = backprop.implicit_grad(fn)()[0][1] + self.assertEqual(len(add_n), 2) + del add_n[:] + + # Reduce the aggregation limit, cause the backprop to do some + # early aggregation. + # pylint: disable=protected-access + old_cnt = imperative_grad._MIN_AGGREGATE_COUNT + old_bytes = imperative_grad._MIN_AGGREGATE_BYTES + imperative_grad._MIN_AGGREGATE_COUNT = 10 + imperative_grad._MIN_AGGREGATE_BYTES = 1 + _ = backprop.implicit_grad(fn)() + self.assertEqual(len(add_n), 6) + del add_n[:] + + # Aggregation is also limited by the memory. + imperative_grad._MIN_AGGREGATE_BYTES = 10000 + _ = backprop.implicit_grad(fn)() + self.assertEqual(len(add_n), 2) + + imperative_grad._MIN_AGGREGATE_COUNT = old_cnt + imperative_grad._MIN_AGGREGATE_BYTES = old_bytes + # pylint: enable=protected-access + context.context().clear_post_execution_callbacks() + def testImplicitGradientsCustomGradientAndCachedVariableValue(self): @custom_gradient.custom_gradient |