aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-09-20 15:37:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 15:45:36 -0700
commit4d39844c1dafb6b74ad49b231bc949a2e026f5ea (patch)
tree0f3043c6551f6be17026ceafe62b8dcffc026da8 /tensorflow/compiler/tests
parentd78b3484d4b98790c2d3a7c0d861487e2fcdefdf (diff)
Split XlaLaunch into XlaCompile and XlaRun; NFC
This CL splits the functionality in XlaLaunch into two separate operations: - XlaCompile, responsible for compiling a TF function into a LocalExecutable - XlaRun, responsible for executing a LocalExecutable created by XlaCompile This CL is a stepping stone towards implementing lazy compilation for TF/XLA. The XlaCompile op is spec'ed to return a boolean indicating whether the compilation was successful. Right now that boolean is always set to true by XlaCompile and its value is otherwise ignored, but in the future it will be used to indicate whether the TF function was compiled or not, and thus whether we should execute XlaRun or just directly call the TF function. XlaLaunch still exists, and will be created by create_xla_launch_op.cc. In the future we may consider removing it altogether. build_xla_launch_ops.cc, now renamed to build_xla_ops.cc, creates a XlaCompile/XlaRun pair instead of XlaLaunch. This CL is organized as follows: - jit/ops/xla_ops.cc gets two new XLA-specific operations, XlaCompile and XlaRun, described above. XlaRun redundantly takes the must-be-constant inputs to the TensorFlow cluster to keep the implementation simple (simple in the sense of similar to XlaLaunch), but I will remove this in a subsequent cleanup CL. - jit/kernels/xla_ops.cc implements XlaCompile and XlaRun in a fairly straightforward manner. XlaCompile compiles the TF function, puts it in a process-global storage, XlaExecutableClosureStore, and produces a int64 key. XlaRun uses the key to read out the LocalExecutable and execute it. I'm not sure if XlaExecutableClosureStore should be a resource like XlaCompilationCache; I did not immediately see any reason to make it so. - There are changes to the various _device files to register XlaCompile and XlaRun for the XLA_* devices. - Finally, I had to fix some tests that were expecting XlaLaunch in the execution timeline. PiperOrigin-RevId: 213895405
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/dense_layer_test.py25
-rw-r--r--tensorflow/compiler/tests/jit_test.py48
2 files changed, 42 insertions, 31 deletions
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 0af74c2d8f..9390870e07 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -45,17 +45,21 @@ def InLabels(labels, substr):
return any([substr in x for x in labels])
-def XlaLaunchOpCount(labels):
- """Count how many XlaLaunch labels are present."""
- return sum("XlaLaunch(" in x for x in labels)
+class DenseLayerTest(test.TestCase):
+ def countXlaOps(self, labels):
+ """Count how many XlaCompile/XlaRun labels are present."""
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+ return xla_run_count
-class DenseLayerTest(test.TestCase):
def testDenseLayerAutoJit(self):
"""Tests dense layer compilation in auto-jit mode.
- Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
+ Dense layer should be compiled into a single XlaCompile/XlaRun op pair in
+ auto-jit mode.
"""
os.environ["TF_XLA_FLAGS"] = (
@@ -77,14 +81,14 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
Dense layer with static shape input tensor should be compiled into a single
- XlaLaunch op by XLA.
+ XlaCompile/XlaRun op pair by XLA.
"""
with self.cached_session() as sess:
@@ -101,7 +105,7 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
# No need to check whether ListDiff is compiled or not because ListDiff op
# is not used when input tensor shape is fully defined.
@@ -111,7 +115,8 @@ class DenseLayerTest(test.TestCase):
Dense layer uses shape op to get shape of input tensor if its shape is not
fully defined. XLA does not cluster shape op with other operators. But in
experimental_jit_scope, XLA is forced to compile shape op into its own
- cluster, causing dense layer to be split into TWO XlaLaunch ops.
+ cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op
+ pairs.
"""
with self.cached_session() as sess:
@@ -128,7 +133,7 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(2, XlaLaunchOpCount(labels))
+ self.assertEqual(2, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 0839fb123e..de68ff0e32 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -77,11 +77,11 @@ def InLabels(labels, substr):
return any([substr in x for x in labels])
-def MetadataHasXlaLaunch(run_metadata):
- """Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
+def MetadataHasXlaOp(run_metadata):
+ """Returns true if there are XlaRun kernels in run_metadata's timeline."""
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
- return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
+ return InLabels(RunMetadataLabels(run_metadata), "XlaRun")
class JitLaunchTest(test.TestCase):
@@ -90,9 +90,10 @@ class JitLaunchTest(test.TestCase):
# Verifies that the outputs match and that XLA was invoked. 'fn' must take
# the same number of tensors as arguments that are in 'args', and must return
# a tuple of output tensors.
- # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
- # actually ran. However, it is sometimes possible for XlaLaunch ops to be
- # constant-folded away, so the check is optional.
+ #
+ # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
+ # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
+ # ops to be constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
placeholders = []
@@ -115,7 +116,7 @@ class JitLaunchTest(test.TestCase):
print("Compiled Result {}".format(compiled))
if require_kernel_launch:
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
direct = sess.run(direct_op, feeds)
print("Direct Result {}".format(direct))
@@ -149,10 +150,10 @@ class JitLaunchTest(test.TestCase):
y = math_ops.add(x, x)
return y, y
- # Exercises compling a function (say, Foo) which calls another
- # function (say, Bar) which is not inlined. When the compiler compiles
- # Foo, it needs to symbolic execute Bar correctly regardless whether
- # Bar is inlined or not.
+ # Exercises compiling a function (say, Foo) which calls another function
+ # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
+ # to symbolically execute Bar correctly regardless of whether Bar is inlined
+ # or not.
# TODO(b/36139787): Re-enable this test when noinline works again.
# Tests compiled=True and noinline=True.
@@ -259,7 +260,7 @@ class JitLaunchTest(test.TestCase):
# TODO(phawkins): really we would like to test that there were exactly
# two kernel launches. However, we have no reliable way to determine
# that.
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
expected = np.square(np.dot(dx, dw) + db)
self.assertAllClose(expected, output, rtol=1e-1)
@@ -289,7 +290,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
def testIgnoredArguments(self):
@@ -313,7 +314,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(28, out)
def testLoops(self):
@@ -331,7 +332,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(95), rtol=1e-1)
def testCond(self):
@@ -356,7 +357,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(6), rtol=1e-1)
def testNestedFunction(self):
@@ -441,14 +442,16 @@ class XlaCompilationTest(test.TestCase):
self.assertFalse(InLabels(labels, "Log"))
self.assertTrue(InLabels(labels, "Reciprocal"))
self.assertTrue(InLabels(labels, "Mul"))
- self.assertFalse(InLabels(labels, "XlaLaunch"))
+ self.assertFalse(InLabels(labels, "XlaCompile"))
+ self.assertFalse(InLabels(labels, "XlaRun"))
- # Compile the backprop. One XlaLaunch.
+ # Compile the backprop. One XlaCompile/XlaRun pair.
labels = _Run(compiled=True)
self.assertFalse(InLabels(labels, "Log"))
self.assertFalse(InLabels(labels, "Reciprocal"))
self.assertFalse(InLabels(labels, "Mul"))
- self.assertTrue(InLabels(labels, "XlaLaunch"))
+ self.assertTrue(InLabels(labels, "XlaCompile"))
+ self.assertTrue(InLabels(labels, "XlaRun"))
class ElementWiseFusionTest(test.TestCase):
@@ -482,9 +485,12 @@ class ElementWiseFusionTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = RunMetadataLabels(run_metadata)
- count = sum("XlaLaunch(" in x for x in labels)
- return output, count
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+
+ return output, xla_run_count
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)