diff options
author | 2017-10-17 09:43:46 -0700 | |
---|---|---|
committer | 2017-10-17 09:47:39 -0700 | |
commit | f8b3ced20f7063b3c8efb0e691f28bef845a05f6 (patch) | |
tree | dcac3a108af41986b988756e82d84ede8c8f90c7 /tensorflow/python/eager/backprop_test.py | |
parent | a86a589c8b329176bfbb64552405644cb641d99e (diff) |
Reworks the imperative_grad interface.
PiperOrigin-RevId: 172477878
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 9083e3a712..2645d542c0 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -22,6 +22,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import custom_gradient +from tensorflow.python.eager import imperative_grad from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op @@ -442,21 +443,21 @@ class BackpropTest(test.TestCase): # Reduce the aggregation limit, cause the backprop to do some # early aggregation. # pylint: disable=protected-access - old_cnt = backprop._MIN_AGGREGATE_COUNT - old_bytes = backprop._MIN_AGGREGATE_BYTES - backprop._MIN_AGGREGATE_COUNT = 10 - backprop._MIN_AGGREGATE_BYTES = 1 + old_cnt = imperative_grad._MIN_AGGREGATE_COUNT + old_bytes = imperative_grad._MIN_AGGREGATE_BYTES + imperative_grad._MIN_AGGREGATE_COUNT = 10 + imperative_grad._MIN_AGGREGATE_BYTES = 1 _ = backprop.implicit_grad(fn)() self.assertEqual(len(add_n), 6) del add_n[:] # Aggregation is also limited by the memory. - backprop._MIN_AGGREGATE_BYTES = 10000 + imperative_grad._MIN_AGGREGATE_BYTES = 10000 _ = backprop.implicit_grad(fn)() self.assertEqual(len(add_n), 2) - backprop._MIN_AGGREGATE_COUNT = old_cnt - backprop._MIN_AGGREGATE_BYTES = old_bytes + imperative_grad._MIN_AGGREGATE_COUNT = old_cnt + imperative_grad._MIN_AGGREGATE_BYTES = old_bytes # pylint: enable=protected-access context.context().clear_post_execution_callbacks() |