aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-08-27 15:29:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 15:34:18 -0700
commitfc492c08d64f05b1d68beedf11a3b55bf8066a8b (patch)
treee893aaf05e305f5389141735d132e6d50bd9e5ea /tensorflow/compiler/tests
parentdf6c8721f8be706291a8151d725de4435942f7e2 (diff)
Support returning resource handles from function in XLA
There are a couple of reasons to do this: - resource handle are regular tensors part of a public API that can potentially be returned from a function. - When tfe.defun is executed under GradientTape, it generates a function returning resource handles in certain cases. This CL adds support for returning resource handles from an XLA compiled function. These resource handles must have been passed as arguments to the function. In other words, we don't yet support returning resources created inside the function. tfe.defun never makes functions that create resources. PiperOrigin-RevId: 210442856
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/eager_test.py98
1 files changed, 98 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index e32f3d4b7f..63cee550fd 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -351,6 +351,38 @@ class EagerFunctionTest(xla_test.XLATestCase):
var = f(v)
self.assertEqual(2.0, var.numpy())
+ def testReturnResourceHandle(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])
+
+ def f(v):
+ return v.handle
+
+ f = function.defun(f)
+ handle = f(v)
+ self.assertAllEqual(v.numpy(),
+ resource_variable_ops.read_variable_op(
+ handle, dtypes.float32).numpy())
+
+ def testReturnMultipleResourceHandles(self):
+ with self.test_scope():
+ v1 = resource_variable_ops.ResourceVariable(1.25)
+ v2 = resource_variable_ops.ResourceVariable(2.0)
+
+ def f(v):
+ return v.handle, 3.0 * v, v2.handle, v + v2
+
+ f = function.defun(f)
+ v1_handle, v1_times_3, v2_handle, variable_sum = f(v1)
+ self.assertAllEqual(v1.numpy(),
+ resource_variable_ops.read_variable_op(
+ v1_handle, dtypes.float32).numpy())
+ self.assertEqual(3.75, v1_times_3.numpy())
+ self.assertAllEqual(v2.numpy(),
+ resource_variable_ops.read_variable_op(
+ v2_handle, dtypes.float32).numpy())
+ self.assertEqual(3.25, variable_sum.numpy())
+
def testAllArgumentKinds(self):
"""Test a complex function that takes different argument kinds.
@@ -457,6 +489,72 @@ class EagerFunctionTest(xla_test.XLATestCase):
y = two_x_plus_1(x)
self.assertAllEqual([5, 7, 9], y.numpy())
+ def testNestedDefunWithVariable(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+
+ @function.defun
+ def g(x):
+ x = v0 * x
+ return x
+
+ @function.defun
+ def f(x):
+ x = g(v0 * x)
+ return x
+
+ x = constant_op.constant(3.0)
+ y = f(x)
+
+ self.assertEqual(75, y.numpy())
+
+ def testNestedDefunInGradientTape(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+
+ @function.defun
+ def g(x):
+ x = v0 * x
+ return x
+
+ @function.defun
+ def f(x):
+ x = g(v0 * x)
+ return x
+
+ x = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ y = f(x)
+ dy = tape.gradient(y, v0)
+
+ self.assertEqual(75, y.numpy())
+ self.assertEqual(30, dy.numpy())
+
+ def testNestedDefunInGradientTapeDifferentVars(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+ v1 = resource_variable_ops.ResourceVariable(3.0)
+
+ @function.defun
+ def g(x):
+ x = v1 * x
+ return x
+
+ @function.defun
+ def f(x):
+ x = g(v0 * x)
+ return x
+
+ x = constant_op.constant(3.0)
+ with backprop.GradientTape(persistent=True) as tape:
+ y = f(x)
+ dy_v0 = tape.gradient(y, v0)
+ dy_v1 = tape.gradient(y, v1)
+
+ self.assertEqual(45, y.numpy())
+ self.assertEqual(9, dy_v0.numpy())
+ self.assertEqual(15, dy_v1.numpy())
+
class ExcessivePaddingTest(xla_test.XLATestCase):
"""Test that eager execution works with TPU flattened tensors.