diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-06-28 11:25:10 -0700 |
---|---|---|
committer | Gunhan Gulsoy <gunan@google.com> | 2018-06-28 21:37:43 -0700 |
commit | ab7b9eeb59bf5849650ebe853fb387b028756188 (patch) | |
tree | 1088407e77e288b337c6db0c7a50e4e5c53cb31a /tensorflow/python/eager/backprop_test.py | |
parent | f599f959a862797af923d4ce50e59133c8e89e46 (diff) |
Include eager/graph mode in cache key so that one type of tensor doesn't spill
into the other.
PiperOrigin-RevId: 202513508
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index e129c2756a..ebbd3cd98e 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -900,6 +900,33 @@ class BackpropTest(test.TestCase): 'did you forget to return a value from fn?'): val_and_grads_fn(x, y) + def testZerosCacheDoesntLeakAcrossModes(self): + with ops.Graph().as_default(): + t = random_ops.random_normal(shape=[100, 2]) + x = random_ops.random_normal(shape=[100, 4]) + dy = random_ops.random_normal(shape=[100, 4]) + with backprop.GradientTape() as gradient_tape: + gradient_tape.watch(x) + x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) + y1 = x1 ** 2. + y = array_ops.concat([y1, t], axis=1) + + dx = gradient_tape.gradient(y, x, output_gradients=dy) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(dx) + + t = random_ops.random_normal(shape=[100, 2]) + x = random_ops.random_normal(shape=[100, 4]) + dy = random_ops.random_normal(shape=[100, 4]) + with backprop.GradientTape() as gradient_tape: + gradient_tape.watch(x) + x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) + y1 = x1 ** 2. + y = array_ops.concat([y1, t], axis=1) + + dx = gradient_tape.gradient(y, x, output_gradients=dy) + if __name__ == '__main__': test.main() |