diff options
author | Igor Ganichev <iga@google.com> | 2018-10-09 15:07:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 15:12:12 -0700 |
commit | 5f69248a692f7b47ea11930621f4f19d0397fe8c (patch) | |
tree | e6ae69c17d798afc96ba83644bf2ce6656181856 /tensorflow/python/eager/backprop_test.py | |
parent | c1093a3757224257fed0f7a1959d0fc99d5c757f (diff) |
Make defun work under distributed strategies.
The core of the change is to have the gradient tape capture
distributed variables instead of plain ResourceVariables.
In other words, we move the distribution awareness from defun
down to tape and rely on distributed variable magic to provide us
with the right variable at runtime.
In tower context, we always watch the container (e.g. MirroredVariable).
In cross tower context, we always watch all the components.
PiperOrigin-RevId: 216430530
Diffstat (limited to 'tensorflow/python/eager/backprop_test.py')
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 7e5c9f3cb6..b1b20fafd2 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -258,6 +258,30 @@ class BackpropTest(test.TestCase): loss += v * v self.assertAllEqual(t.gradient(loss, v), 2.0) + def testAutomaticWatchedVariables(self): + with backprop.GradientTape() as t: + self.assertEqual(0, len(t.watched_variables())) + v = resource_variable_ops.ResourceVariable(1.0) + loss = v * v + self.assertAllEqual([v], t.watched_variables()) + + t.reset() + self.assertEqual(0, len(t.watched_variables())) + loss += v * v + self.assertAllEqual([v], t.watched_variables()) + + def testExplicitWatchedVariables(self): + with backprop.GradientTape() as t: + self.assertEqual(0, len(t.watched_variables())) + v = resource_variable_ops.ResourceVariable(1.0) + t.watch(v) + self.assertAllEqual([v], t.watched_variables()) + + t.reset() + self.assertEqual(0, len(t.watched_variables())) + t.watch(v) + self.assertAllEqual([v], t.watched_variables()) + @test_util.assert_no_new_tensors def testGradientNone(self): |