diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-07-24 13:50:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-24 13:59:53 -0700 |
commit | 57d051e7b156313c0beef6eb1fd9e6ca955a568a (patch) | |
tree | 9a47274bb3c455a0edda7046d60bd48750badbad /tensorflow/python/eager/backprop_test.py | |
parent | ee0bd6ef450b388fadea63b31b65b13bd12f17d6 (diff) |
Don't cache zero tensors in graph at all
PiperOrigin-RevId: 205885372
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 43 |
1 files changed, 17 insertions, 26 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 95a3a8b629..3d3f54b9c4 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -925,32 +925,23 @@ class BackpropTest(test.TestCase): 'did you forget to return a value from fn?'): val_and_grads_fn(x, y) - def testZerosCacheDoesntLeakAcrossModes(self): - with ops.Graph().as_default(): - t = random_ops.random_normal(shape=[100, 2]) - x = random_ops.random_normal(shape=[100, 4]) - dy = random_ops.random_normal(shape=[100, 4]) - with backprop.GradientTape() as gradient_tape: - gradient_tape.watch(x) - x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) - y1 = x1 ** 2. - y = array_ops.concat([y1, t], axis=1) - - dx = gradient_tape.gradient(y, x, output_gradients=dy) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(dx) - - t = random_ops.random_normal(shape=[100, 2]) - x = random_ops.random_normal(shape=[100, 4]) - dy = random_ops.random_normal(shape=[100, 4]) - with backprop.GradientTape() as gradient_tape: - gradient_tape.watch(x) - x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) - y1 = x1 ** 2. - y = array_ops.concat([y1, t], axis=1) - - dx = gradient_tape.gradient(y, x, output_gradients=dy) + def testZerosCacheDoesntLeakAcrossGraphs(self): + with context.graph_mode(): + def get_grad(): + with ops.Graph().as_default(), self.test_session(): + t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4)) + x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4)) + with backprop.GradientTape() as gt: + tape.watch(x) + x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) + y1 = x1**2 + y = array_ops.concat([y1, t], axis=1) + return self.evaluate(gt.gradient(y, x)) + + grad1 = get_grad() + grad2 = get_grad() + + self.assertAllEqual(grad1, grad2) if __name__ == '__main__': |