aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/backprop_test.py
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-10-17 09:43:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-17 09:47:39 -0700
commitf8b3ced20f7063b3c8efb0e691f28bef845a05f6 (patch)
treedcac3a108af41986b988756e82d84ede8c8f90c7 /tensorflow/python/eager/backprop_test.py
parenta86a589c8b329176bfbb64552405644cb641d99e (diff)
Reworks the imperative_grad interface.
PiperOrigin-RevId: 172477878
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r--tensorflow/python/eager/backprop_test.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 9083e3a712..2645d542c0 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -22,6 +22,7 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
+from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
@@ -442,21 +443,21 @@ class BackpropTest(test.TestCase):
# Reduce the aggregation limit, cause the backprop to do some
# early aggregation.
# pylint: disable=protected-access
- old_cnt = backprop._MIN_AGGREGATE_COUNT
- old_bytes = backprop._MIN_AGGREGATE_BYTES
- backprop._MIN_AGGREGATE_COUNT = 10
- backprop._MIN_AGGREGATE_BYTES = 1
+ old_cnt = imperative_grad._MIN_AGGREGATE_COUNT
+ old_bytes = imperative_grad._MIN_AGGREGATE_BYTES
+ imperative_grad._MIN_AGGREGATE_COUNT = 10
+ imperative_grad._MIN_AGGREGATE_BYTES = 1
_ = backprop.implicit_grad(fn)()
self.assertEqual(len(add_n), 6)
del add_n[:]
# Aggregation is also limited by the memory.
- backprop._MIN_AGGREGATE_BYTES = 10000
+ imperative_grad._MIN_AGGREGATE_BYTES = 10000
_ = backprop.implicit_grad(fn)()
self.assertEqual(len(add_n), 2)
- backprop._MIN_AGGREGATE_COUNT = old_cnt
- backprop._MIN_AGGREGATE_BYTES = old_bytes
+ imperative_grad._MIN_AGGREGATE_COUNT = old_cnt
+ imperative_grad._MIN_AGGREGATE_BYTES = old_bytes
# pylint: enable=protected-access
context.context().clear_post_execution_callbacks()