From f8b3ced20f7063b3c8efb0e691f28bef845a05f6 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Tue, 17 Oct 2017 09:43:46 -0700 Subject: Reworks the imperative_grad interface. PiperOrigin-RevId: 172477878 --- tensorflow/python/eager/backprop_test.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'tensorflow/python/eager/backprop_test.py') diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 9083e3a712..2645d542c0 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -22,6 +22,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import custom_gradient +from tensorflow.python.eager import imperative_grad from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op @@ -442,21 +443,21 @@ class BackpropTest(test.TestCase): # Reduce the aggregation limit, cause the backprop to do some # early aggregation. # pylint: disable=protected-access - old_cnt = backprop._MIN_AGGREGATE_COUNT - old_bytes = backprop._MIN_AGGREGATE_BYTES - backprop._MIN_AGGREGATE_COUNT = 10 - backprop._MIN_AGGREGATE_BYTES = 1 + old_cnt = imperative_grad._MIN_AGGREGATE_COUNT + old_bytes = imperative_grad._MIN_AGGREGATE_BYTES + imperative_grad._MIN_AGGREGATE_COUNT = 10 + imperative_grad._MIN_AGGREGATE_BYTES = 1 _ = backprop.implicit_grad(fn)() self.assertEqual(len(add_n), 6) del add_n[:] # Aggregation is also limited by the memory. - backprop._MIN_AGGREGATE_BYTES = 10000 + imperative_grad._MIN_AGGREGATE_BYTES = 10000 _ = backprop.implicit_grad(fn)() self.assertEqual(len(add_n), 2) - backprop._MIN_AGGREGATE_COUNT = old_cnt - backprop._MIN_AGGREGATE_BYTES = old_bytes + imperative_grad._MIN_AGGREGATE_COUNT = old_cnt + imperative_grad._MIN_AGGREGATE_BYTES = old_bytes # pylint: enable=protected-access context.context().clear_post_execution_callbacks() -- cgit v1.2.3