author      Akshay Modi <nareshmodi@google.com>            2018-07-24 13:50:42 -0700
committer   TensorFlower Gardener <gardener@tensorflow.org> 2018-07-24 13:59:53 -0700
commit      57d051e7b156313c0beef6eb1fd9e6ca955a568a (patch)
tree        9a47274bb3c455a0edda7046d60bd48750badbad /tensorflow/python/eager/backprop_test.py
parent      ee0bd6ef450b388fadea63b31b65b13bd12f17d6 (diff)
Don't cache zero tensors in graph at all
PiperOrigin-RevId: 205885372
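
Why the cache is unsafe: a graph-mode tensor belongs to the tf.Graph that created it, so a zeros cache keyed only by (shape, dtype) can hand code building one graph a tensor created in another. A minimal sketch of the failure mode, using a hypothetical cached_zeros helper (not TensorFlow's actual implementation) and the TF1-style graph API:

import tensorflow as tf

_zeros_cache = {}  # hypothetical cache keyed by (shape, dtype) only

def cached_zeros(shape, dtype):
  # Ignores which graph is current -- this is the bug being removed.
  key = (shape, dtype)
  if key not in _zeros_cache:
    _zeros_cache[key] = tf.zeros(shape, dtype=dtype)
  return _zeros_cache[key]

g1 = tf.Graph()
with g1.as_default():
  z = cached_zeros((10, 4), tf.float32)  # zeros op created in g1

g2 = tf.Graph()
with g2.as_default():
  z = cached_zeros((10, 4), tf.float32)  # cache hit: tensor still belongs to g1
  x = tf.ones((10, 4))
  y = x + z  # raises ValueError: ... must be from the same graph as ...

The replacement test below exercises exactly this pattern: two fresh graphs, the same computation, and the gradients must match.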
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r--  tensorflow/python/eager/backprop_test.py  43
1 file changed, 17 insertions(+), 26 deletions(-)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 95a3a8b629..3d3f54b9c4 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -925,32 +925,23 @@ class BackpropTest(test.TestCase):
                                  'did you forget to return a value from fn?'):
       val_and_grads_fn(x, y)
 
-  def testZerosCacheDoesntLeakAcrossModes(self):
-    with ops.Graph().as_default():
-      t = random_ops.random_normal(shape=[100, 2])
-      x = random_ops.random_normal(shape=[100, 4])
-      dy = random_ops.random_normal(shape=[100, 4])
-      with backprop.GradientTape() as gradient_tape:
-        gradient_tape.watch(x)
-        x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
-        y1 = x1 ** 2.
-        y = array_ops.concat([y1, t], axis=1)
-
-      dx = gradient_tape.gradient(y, x, output_gradients=dy)
-      with self.test_session() as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(dx)
-
-    t = random_ops.random_normal(shape=[100, 2])
-    x = random_ops.random_normal(shape=[100, 4])
-    dy = random_ops.random_normal(shape=[100, 4])
-    with backprop.GradientTape() as gradient_tape:
-      gradient_tape.watch(x)
-      x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
-      y1 = x1 ** 2.
-      y = array_ops.concat([y1, t], axis=1)
-
-    dx = gradient_tape.gradient(y, x, output_gradients=dy)
+  def testZerosCacheDoesntLeakAcrossGraphs(self):
+    with context.graph_mode():
+      def get_grad():
+        with ops.Graph().as_default(), self.test_session():
+          t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
+          x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
+          with backprop.GradientTape() as gt:
+            gt.watch(x)
+            x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
+            y1 = x1**2
+            y = array_ops.concat([y1, t], axis=1)
+          return self.evaluate(gt.gradient(y, x))
+
+      grad1 = get_grad()
+      grad2 = get_grad()
+
+      self.assertAllEqual(grad1, grad2)
 
 
 if __name__ == '__main__':
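
The direction of the fix, going by the commit title: skip the cache whenever execution is not eager, so graph mode always materializes fresh zeros in the current graph. A rough sketch of that guard, with approximate names (_zeros and _cached_eager_zeros are illustrative, not the verbatim backprop.py change):

from tensorflow.python.eager import context
from tensorflow.python.ops import array_ops

_eager_zeros_cache = {}  # hypothetical eager-side cache

def _cached_eager_zeros(shape, dtype):
  # Eager tensors carry no graph affinity, so caching them is safe.
  key = (shape, dtype)
  if key not in _eager_zeros_cache:
    _eager_zeros_cache[key] = array_ops.zeros(shape, dtype)
  return _eager_zeros_cache[key]

def _zeros(shape, dtype):
  if not context.executing_eagerly():
    # Graph mode: always build a fresh zeros op in the current graph;
    # a cached tensor would be bound to whichever graph was current
    # when it was first created.
    return array_ops.zeros(shape, dtype)
  return _cached_eager_zeros(shape, dtype)

Only the graph-mode path has to lose the cache; the eager path can keep reusing zeros by shape and dtype.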