diff options
author | 2018-05-16 12:50:41 -0700 | |
---|---|---|
committer | 2018-05-16 12:53:29 -0700 | |
commit | 37ec69856d18b669a1d63d0a39f78f22f97b1148 (patch) | |
tree | 2ed910a26866a19092e3f96f4fc6e1609cdb97dd | |
parent | ea3f7d1947c8a379557387b948affd918f186c41 (diff) |
Add a test for compiled tfe.defun in GradientTape
PiperOrigin-RevId: 196873235
-rw-r--r-- | tensorflow/compiler/tests/eager_test.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 5ab1585f8c..311f2ada15 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -234,6 +234,23 @@ class EagerFunctionTest(XLATestCase): self.assertAllEqual([[1.]], c.numpy()) self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy()) + def testDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun(compiled=True) + def f(x): + x = v0 * v0 * x + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + if __name__ == '__main__': ops.enable_eager_execution( |