Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r--  tensorflow/python/eager/backprop_test.py  56
1 file changed, 30 insertions, 26 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index bdda200ff6..3d3f54b9c4 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -96,6 +96,19 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(grads_and_vars[0][0], 1.0)
     self.assertAllEqual(id(grads_and_vars[0][1]), id(x))
 
+  def testGradientInsideLoop(self):
+    with ops.Graph().as_default():
+      v = resource_variable_ops.ResourceVariable(1.0)
+
+      def body(_):
+        _ = v + 1.0  # This reads the variable inside the loop context
+        with backprop.GradientTape() as t:
+          result = v * 2
+        self.assertTrue(t.gradient(result, v) is not None)
+        return 1.0
+
+      control_flow_ops.while_loop(lambda i: False, body, [1.0])
+
   def testWhereGradient(self):
     # Note: where is special because only some of its arguments are of
     # differentiable dtypes.
@@ -912,32 +925,23 @@ class BackpropTest(test.TestCase):
                                  'did you forget to return a value from fn?'):
       val_and_grads_fn(x, y)
 
-  def testZerosCacheDoesntLeakAcrossModes(self):
-    with ops.Graph().as_default():
-      t = random_ops.random_normal(shape=[100, 2])
-      x = random_ops.random_normal(shape=[100, 4])
-      dy = random_ops.random_normal(shape=[100, 4])
-      with backprop.GradientTape() as gradient_tape:
-        gradient_tape.watch(x)
-        x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
-        y1 = x1 ** 2.
-        y = array_ops.concat([y1, t], axis=1)
-
-      dx = gradient_tape.gradient(y, x, output_gradients=dy)
-      with self.test_session() as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(dx)
-
-    t = random_ops.random_normal(shape=[100, 2])
-    x = random_ops.random_normal(shape=[100, 4])
-    dy = random_ops.random_normal(shape=[100, 4])
-    with backprop.GradientTape() as gradient_tape:
-      gradient_tape.watch(x)
-      x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
-      y1 = x1 ** 2.
-      y = array_ops.concat([y1, t], axis=1)
-
-    dx = gradient_tape.gradient(y, x, output_gradients=dy)
+  def testZerosCacheDoesntLeakAcrossGraphs(self):
+    with context.graph_mode():
+      def get_grad():
+        with ops.Graph().as_default(), self.test_session():
+          t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
+          x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
+          with backprop.GradientTape() as gt:
+            tape.watch(x)
+            x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
+            y1 = x1**2
+            y = array_ops.concat([y1, t], axis=1)
+          return self.evaluate(gt.gradient(y, x))
+
+      grad1 = get_grad()
+      grad2 = get_grad()
+
+      self.assertAllEqual(grad1, grad2)
 
 
 if __name__ == '__main__':