aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-04-12 11:54:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-12 11:56:34 -0700
commit583ee0eabfb1bebd0eb533d2ab7a5c17af7e664e (patch)
treea81346e0dfcb743104e72d8ab9c0dfdc2e8f84b4
parent454a22aa29dc2dba355094aabe733cd8419f2788 (diff)
Add testCompileTimeConstantsInDefun in xla
PiperOrigin-RevId: 192646199
-rw-r--r--tensorflow/compiler/tests/function_test.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py
index 11d8a99ffe..fbc3c994d1 100644
--- a/tensorflow/compiler/tests/function_test.py
+++ b/tensorflow/compiler/tests/function_test.py
@@ -105,6 +105,28 @@ class FunctionTest(XLATestCase):
result = sess.run(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
+ def testCompileTimeConstantsInDefun(self):
+ """Tests that XLA handles compile-time constants in defuns."""
+ with self.test_session() as sess:
+
+ @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32)
+ def Foo(a, c, d):
+ # c and d must be known at compile time
+ x = array_ops.slice(a, c, d)
+ return x
+
+ a = array_ops.placeholder(dtypes.float32)
+ c = array_ops.placeholder(dtypes.int32, shape=[4])
+ d = array_ops.placeholder(dtypes.int32, shape=[4])
+ with self.test_scope():
+ call_f = Foo(a, c, d)
+ result = sess.run(call_f, feed_dict={
+ a: np.ones([1, 4, 4, 1]),
+ c: [0, 0, 0, 0],
+ d: [1, 2, 2, 1]})
+
+ self.assertAllEqual(np.ones([1, 2, 2, 1]), result)
+
# TODO(b/36139787): Re-enable this test when noinline works again.
def DISABLED_testFunctionsNoInline(self):