author Akshay Modi <nareshmodi@google.com> 2018-06-28 11:25:10 -0700
committer Gunhan Gulsoy <gunan@google.com> 2018-06-28 21:37:43 -0700
commit ab7b9eeb59bf5849650ebe853fb387b028756188 (patch)
tree 1088407e77e288b337c6db0c7a50e4e5c53cb31a /tensorflow/python/eager/backprop_test.py
parent f599f959a862797af923d4ce50e59133c8e89e46 (diff)
Include eager/graph mode in cache key so that one type of tensor doesn't spill into the other.

PiperOrigin-RevId: 202513508
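
Why the mode belongs in the cache key: the test name below suggests that backprop keeps a cache of zero tensors so it does not rebuild them for every missing output gradient. The following is a minimal sketch of the idea only; the names _zeros_cache and cached_zeros are illustrative stand-ins, not TensorFlow's actual internals:

import tensorflow as tf

# Illustrative stand-in (not TensorFlow's real implementation): a
# module-level cache of zero tensors keyed by shape and dtype.
_zeros_cache = {}

def cached_zeros(shape, dtype=tf.float32):
  # The bug: keying only on (shape, dtype) lets a tensor created while
  # building a graph be handed back later under eager execution, and
  # vice versa. Folding the execution mode into the key keeps the two
  # kinds of tensors apart.
  key = (tuple(shape), dtype, tf.executing_eagerly())
  if key not in _zeros_cache:
    _zeros_cache[key] = tf.zeros(shape, dtype=dtype)
  return _zeros_cache[key]

With the buggy two-element key, a zeros tensor cached while tracing a graph would later be returned during eager execution, which is the leak the new test guards against.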
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- tensorflow/python/eager/backprop_test.py | 27
1 file changed, 27 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index e129c2756a..ebbd3cd98e 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -900,6 +900,33 @@ class BackpropTest(test.TestCase):
         'did you forget to return a value from fn?'):
       val_and_grads_fn(x, y)
 
+  def testZerosCacheDoesntLeakAcrossModes(self):
+    with ops.Graph().as_default():
+      t = random_ops.random_normal(shape=[100, 2])
+      x = random_ops.random_normal(shape=[100, 4])
+      dy = random_ops.random_normal(shape=[100, 4])
+      with backprop.GradientTape() as gradient_tape:
+        gradient_tape.watch(x)
+        x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
+        y1 = x1 ** 2.
+        y = array_ops.concat([y1, t], axis=1)
+
+      dx = gradient_tape.gradient(y, x, output_gradients=dy)
+      with self.test_session() as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(dx)
+
+    t = random_ops.random_normal(shape=[100, 2])
+    x = random_ops.random_normal(shape=[100, 4])
+    dy = random_ops.random_normal(shape=[100, 4])
+    with backprop.GradientTape() as gradient_tape:
+      gradient_tape.watch(x)
+      x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
+      y1 = x1 ** 2.
+      y = array_ops.concat([y1, t], axis=1)
+
+    dx = gradient_tape.gradient(y, x, output_gradients=dy)
+
 
 if __name__ == '__main__':
   test.main()
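
Reading the new test: the first half builds the computation inside an explicit ops.Graph() and evaluates the gradient in a session (graph mode); the second half repeats the identical code outside that context, which in this test file presumably executes eagerly. Because t is concatenated into y but never watched, the tape has to materialize zeros for t's slice of the output gradient; if those zeros were cached under a key that ignored the execution mode, the second gradient computation would pick up the graph tensor cached by the first.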