aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/backprop_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r--tensorflow/python/eager/backprop_test.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 7e5c9f3cb6..b1b20fafd2 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -258,6 +258,30 @@ class BackpropTest(test.TestCase):
loss += v * v
self.assertAllEqual(t.gradient(loss, v), 2.0)
+ def testAutomaticWatchedVariables(self):
+ with backprop.GradientTape() as t:
+ self.assertEqual(0, len(t.watched_variables()))
+ v = resource_variable_ops.ResourceVariable(1.0)
+ loss = v * v
+ self.assertAllEqual([v], t.watched_variables())
+
+ t.reset()
+ self.assertEqual(0, len(t.watched_variables()))
+ loss += v * v
+ self.assertAllEqual([v], t.watched_variables())
+
+ def testExplicitWatchedVariables(self):
+ with backprop.GradientTape() as t:
+ self.assertEqual(0, len(t.watched_variables()))
+ v = resource_variable_ops.ResourceVariable(1.0)
+ t.watch(v)
+ self.assertAllEqual([v], t.watched_variables())
+
+ t.reset()
+ self.assertEqual(0, len(t.watched_variables()))
+ t.watch(v)
+ self.assertAllEqual([v], t.watched_variables())
+
@test_util.assert_no_new_tensors
def testGradientNone(self):