path: root/tensorflow/python/eager/backprop_test.py
author Alexandre Passos <apassos@google.com> 2017-11-08 13:44:26 -0800
committer Andrew Selle <aselle@andyselle.com> 2017-11-10 16:14:37 -0800
commit 2545c4e93b7c1ee21ddb3666580ff4922630d974 (patch)
tree ed3da37ca4f30f365822785f2e8b3aa2bf26388f /tensorflow/python/eager/backprop_test.py
parent fd52578963fdc3474be30c38fa9027c1c407301b (diff)
Moves imperative_grad to C
Neutral-to-positive on all benchmarks. Also reduces overhead of should_record.

PiperOrigin-RevId: 175057104
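The Python-facing entry points appear unchanged by this move: callers still go through backprop.gradients_function (or implicit_grad), and only the gradient accumulation now runs in C. Below is a minimal sketch of that usage, modeled on the gradients_function calls already present in this test file; the class and method names are illustrative only, and it assumes the eager test harness (test.main()) so that eager execution is enabled.

from tensorflow.python.eager import backprop
from tensorflow.python.eager import test


class GradientsFunctionSketch(test.TestCase):

  def testSquareGradient(self):
    # f is an ordinary Python function of its (tensor-converted) arguments.
    def f(x):
      return x * x

    # gradients_function(f) returns a callable computing df/dx for each
    # positional argument; after this change the accumulation is handled by
    # the C-backed imperative_grad.
    grad_fn = backprop.gradients_function(f)
    self.assertAllEqual(6., grad_fn(3.)[0])


if __name__ == '__main__':
  test.main()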
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- tensorflow/python/eager/backprop_test.py | 57
1 file changed, 13 insertions(+), 44 deletions(-)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index ed54b8e12e..ec9a185b73 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -24,11 +24,11 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
-from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -41,7 +41,6 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import training
-from tensorflow.python.util import compat
class BackpropTest(test.TestCase):
@@ -103,6 +102,18 @@ class BackpropTest(test.TestCase):
grad_fn = backprop.gradients_function(f)
self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
+ def testErrors(self):
+
+ @custom_gradient.custom_gradient
+ def f(x):
+ def grad(_):
+ raise RuntimeError('x')
+ return x, grad
+
+ # TODO(apassos) raise the right error here
+ with self.assertRaises(errors_impl.InternalError):
+ backprop.gradients_function(f)(constant_op.constant(1.0))
+
def testImplicitGradOverEmbeddingLookup(self):
batch_size = 8
embedding_size = 512
@@ -483,48 +494,6 @@ class BackpropTest(test.TestCase):
initial_value=1., name='testSameObjectForMultipleArguments.Variable')
self.assertAllEqual([1., 1.], np_g(v, v))
- def testEarlyGradAggregation(self):
- # Needs to be a list so mutations by the callback affect this function.
- add_n = []
- def callback(op_type, unused_1, unused_2, unused_3, unused_4):
- if compat.as_bytes(op_type) == compat.as_bytes('AddN'):
- add_n.append(1)
- context.context().add_post_execution_callback(callback)
-
- v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0),
- name='v')
- def fn():
- outputs = []
- for _ in range(20):
- outputs.append(v * constant_op.constant(2.0))
- return math_ops.add_n(outputs)
-
- # By default the aggregation count is 2.
- _ = backprop.implicit_grad(fn)()[0][1]
- self.assertEqual(len(add_n), 2)
- del add_n[:]
-
- # Reduce the aggregation limit, cause the backprop to do some
- # early aggregation.
- # pylint: disable=protected-access
- old_cnt = imperative_grad._MIN_AGGREGATE_COUNT
- old_bytes = imperative_grad._MIN_AGGREGATE_BYTES
- imperative_grad._MIN_AGGREGATE_COUNT = 10
- imperative_grad._MIN_AGGREGATE_BYTES = 1
- _ = backprop.implicit_grad(fn)()
- self.assertEqual(len(add_n), 6)
- del add_n[:]
-
- # Aggregation is also limited by the memory.
- imperative_grad._MIN_AGGREGATE_BYTES = 10000
- _ = backprop.implicit_grad(fn)()
- self.assertEqual(len(add_n), 2)
-
- imperative_grad._MIN_AGGREGATE_COUNT = old_cnt
- imperative_grad._MIN_AGGREGATE_BYTES = old_bytes
- # pylint: enable=protected-access
- context.context().clear_post_execution_callbacks()
-
def testImplicitGradientsCustomGradientAndCachedVariableValue(self):
@custom_gradient.custom_gradient