diff options
author | 2018-04-12 11:54:21 -0700 | |
---|---|---|
committer | 2018-04-12 11:56:34 -0700 | |
commit | 583ee0eabfb1bebd0eb533d2ab7a5c17af7e664e (patch) | |
tree | a81346e0dfcb743104e72d8ab9c0dfdc2e8f84b4 | |
parent | 454a22aa29dc2dba355094aabe733cd8419f2788 (diff) |
Add testCompileTimeConstantsInDefun in xla
PiperOrigin-RevId: 192646199
-rw-r--r-- | tensorflow/compiler/tests/function_test.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 11d8a99ffe..fbc3c994d1 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -105,6 +105,28 @@ class FunctionTest(XLATestCase): result = sess.run(call_f) self.assertAllClose(result, expected, rtol=1e-3) + def testCompileTimeConstantsInDefun(self): + """Tests that XLA handles compile-time constants in defuns.""" + with self.test_session() as sess: + + @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32) + def Foo(a, c, d): + # c and d must be known at compile time + x = array_ops.slice(a, c, d) + return x + + a = array_ops.placeholder(dtypes.float32) + c = array_ops.placeholder(dtypes.int32, shape=[4]) + d = array_ops.placeholder(dtypes.int32, shape=[4]) + with self.test_scope(): + call_f = Foo(a, c, d) + result = sess.run(call_f, feed_dict={ + a: np.ones([1, 4, 4, 1]), + c: [0, 0, 0, 0], + d: [1, 2, 2, 1]}) + + self.assertAllEqual(np.ones([1, 2, 2, 1]), result) + # TODO(b/36139787): Re-enable this test when noinline works again. def DISABLED_testFunctionsNoInline(self): |