aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-05-16 12:50:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-16 12:53:29 -0700
commit37ec69856d18b669a1d63d0a39f78f22f97b1148 (patch)
tree2ed910a26866a19092e3f96f4fc6e1609cdb97dd
parentea3f7d1947c8a379557387b948affd918f186c41 (diff)
Add a test for compiled tfe.defun in GradientTape
PiperOrigin-RevId: 196873235
-rw-r--r--tensorflow/compiler/tests/eager_test.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 5ab1585f8c..311f2ada15 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -234,6 +234,23 @@ class EagerFunctionTest(XLATestCase):
self.assertAllEqual([[1.]], c.numpy())
self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
+ def testDefunInGradientTape(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+
+ @function.defun(compiled=True)
+ def f(x):
+ x = v0 * v0 * x
+ return x
+
+ x = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ y = f(x)
+ dy = tape.gradient(y, v0)
+
+ self.assertEqual(75, y.numpy())
+ self.assertEqual(30, dy.numpy())
+
if __name__ == '__main__':
ops.enable_eager_execution(