diff options
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 7e5c9f3cb6..b1b20fafd2 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -258,6 +258,30 @@ class BackpropTest(test.TestCase): loss += v * v self.assertAllEqual(t.gradient(loss, v), 2.0) + def testAutomaticWatchedVariables(self): + with backprop.GradientTape() as t: + self.assertEqual(0, len(t.watched_variables())) + v = resource_variable_ops.ResourceVariable(1.0) + loss = v * v + self.assertAllEqual([v], t.watched_variables()) + + t.reset() + self.assertEqual(0, len(t.watched_variables())) + loss += v * v + self.assertAllEqual([v], t.watched_variables()) + + def testExplicitWatchedVariables(self): + with backprop.GradientTape() as t: + self.assertEqual(0, len(t.watched_variables())) + v = resource_variable_ops.ResourceVariable(1.0) + t.watch(v) + self.assertAllEqual([v], t.watched_variables()) + + t.reset() + self.assertEqual(0, len(t.watched_variables())) + t.watch(v) + self.assertAllEqual([v], t.watched_variables()) + @test_util.assert_no_new_tensors def testGradientNone(self): |