From 583ee0eabfb1bebd0eb533d2ab7a5c17af7e664e Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Thu, 12 Apr 2018 11:54:21 -0700 Subject: Add testCompileTimeConstantsInDefun in xla PiperOrigin-RevId: 192646199 --- tensorflow/compiler/tests/function_test.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 11d8a99ffe..fbc3c994d1 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -105,6 +105,28 @@ class FunctionTest(XLATestCase): result = sess.run(call_f) self.assertAllClose(result, expected, rtol=1e-3) + def testCompileTimeConstantsInDefun(self): + """Tests that XLA handles compile-time constants in defuns.""" + with self.test_session() as sess: + + @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32) + def Foo(a, c, d): + # c and d must be known at compile time + x = array_ops.slice(a, c, d) + return x + + a = array_ops.placeholder(dtypes.float32) + c = array_ops.placeholder(dtypes.int32, shape=[4]) + d = array_ops.placeholder(dtypes.int32, shape=[4]) + with self.test_scope(): + call_f = Foo(a, c, d) + result = sess.run(call_f, feed_dict={ + a: np.ones([1, 4, 4, 1]), + c: [0, 0, 0, 0], + d: [1, 2, 2, 1]}) + + self.assertAllEqual(np.ones([1, 2, 2, 1]), result) + # TODO(b/36139787): Re-enable this test when noinline works again. def DISABLED_testFunctionsNoInline(self): -- cgit v1.2.3