Diffstat (limited to 'tensorflow/compiler/tests/jit_test.py')
-rw-r--r--  tensorflow/compiler/tests/jit_test.py  459
1 file changed, 459 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
new file mode 100644
index 0000000000..8a568d6d58
--- /dev/null
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -0,0 +1,459 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for JIT compilation on the CPU and GPU devices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.compiler import jit
+from tensorflow.core.framework import function_pb2
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+jit_scope = jit.experimental_jit_scope
+
+
+def CompiledKernel(fn, *inputs, **kwargs):
+ """Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
+ name = kwargs.pop("name", None)
+ noinline = kwargs.pop("noinline", None)
+
+ @function.Defun(func_name=name, noinline=noinline, compiled=True)
+ def Compiled(*args):
+ return fn(*args)
+
+ return Compiled(*inputs)
+
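+# A minimal, hypothetical usage sketch of CompiledKernel (the placeholder 'p'
+# and the "CompiledAdd" name below are illustrative only, not part of the
+# tests):
+#
+#   p = array_ops.placeholder(dtypes.float32)
+#   add_op = CompiledKernel(math_ops.add, p, p, name="CompiledAdd")
+#   sess.run(add_op, {p: np.float32(3.0)})  # runs the Defun as one XLA kernel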
+
+def RunMetadataLabels(run_metadata):
+ """Returns all labels in run_metadata."""
+ labels = []
+ for dev_stats in run_metadata.step_stats.dev_stats:
+ for node_stats in dev_stats.node_stats:
+ labels.append(node_stats.timeline_label)
+ return labels
+
+
+def InLabels(labels, substr):
+ """Returns true iff one of the labels contains substr."""
+  return any(substr in x for x in labels)
+
+
+def MetadataHasXlaLaunch(run_metadata):
+ """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""
+
+ # TODO(phawkins): find a less hacky way to test whether a kernel ran.
+ return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")
+
+
+class JitLaunchTest(test.TestCase):
+
+ # Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
+  # Verifies that the outputs match and that XLA was invoked. 'fn' must take
+  # the same number of arguments as there are tensors in 'args', and must
+  # return a tuple of output tensors.
+ # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node
+ # actually ran. However, it is sometimes possible for _XlaLaunch ops to be
+ # constant-folded away, so the check is optional.
+ def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
+ with session_lib.Session() as sess:
+ placeholders = []
+ feeds = {}
+ for arg in args:
+ placeholder = array_ops.placeholder(
+ dtypes.as_dtype(arg.dtype), list(arg.shape))
+ placeholders.append(placeholder)
+ feeds[placeholder] = arg
+
+ compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline)
+ direct_op = fn(*placeholders)
+
+ run_metadata = config_pb2.RunMetadata()
+ compiled = sess.run(compiled_op,
+ feeds,
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ print("Compiled Result {}".format(compiled))
+
+ if require_kernel_launch:
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+
+ direct = sess.run(direct_op, feeds)
+ print("Direct Result {}".format(direct))
+
+      if (isinstance(compiled, (tuple, list)) and
+          isinstance(direct, (tuple, list))):
+ for (x, y) in zip(compiled, direct):
+ self.assertAllClose(x, y, rtol=1e-1)
+ else:
+ self.assertAllClose(compiled, direct)
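+
+  # A hypothetical illustration of the harness above (not itself a test):
+  # self._compare(math_ops.square, [np.array([2.0, 3.0], dtype=np.float32)])
+  # runs Square both directly and through an _XlaLaunch kernel and asserts
+  # that the two results agree.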
+
+ def testNoOutputs(self):
+ with session_lib.Session() as sess:
+ # Build a function with a single Const node, whose output is ignored.
+ fdef = function_pb2.FunctionDef()
+ fdef.signature.name = "KernelWithNoOutputs"
+ node = node_def_pb2.NodeDef()
+ node.op = "Const"
+ node.name = "ignored"
+ node.attr["dtype"].type = dtypes.int32.as_datatype_enum
+ tensor = tensor_util.make_tensor_proto([0], dtype=dtypes.int32, shape=[])
+ node.attr["value"].tensor.CopyFrom(tensor)
+ fdef.node_def.extend([node])
+
+ # Check that calling the result as a compiled kernel doesn't crash.
+ @function.Defun(compiled=True)
+ def KernelWithNoOutputs():
+ return constant_op.constant(100)
+
+      # Hack to override the definition. By accessing .definition, we
+      # force the internal _DefinedFunction to be initialized. Then, we
+      # replace its internal FunctionDef proto. We use this hack here
+      # because one typically can't construct a function like
+      # KernelWithNoOutputs via the Defun decorator directly.
+ _ = KernelWithNoOutputs.definition
+ foo = KernelWithNoOutputs
+ foo._definition = fdef
+ call = KernelWithNoOutputs()
+ sess.run(call, {})
+
+ def testAliasing(self):
+ """Regression test for compiled functions that return an aliased buffer.
+
+    XLA returns aliased buffers if outputs are identical. This tests that
+    we handle that case correctly.
+ """
+
+ def AddOnceReturnTwice(x):
+ y = math_ops.add(x, x)
+ return y, y
+
+    # Exercises compiling a function (say, Foo) that calls another
+    # function (say, Bar) that is not inlined. When the compiler compiles
+    # Foo, it needs to symbolically execute Bar correctly regardless of
+    # whether Bar is inlined.
+ #
+ # Tests compiled=True and noinline=True.
+ self._compare(
+ AddOnceReturnTwice, [np.array(
+ [[[0.5, -1.0]]], dtype=np.float32)],
+ noinline=True)
+ # Tests compiled=True and noinline=False.
+ self._compare(
+ AddOnceReturnTwice, [np.array(
+ [[[0.5, -1.0]]], dtype=np.float32)],
+ noinline=False)
+
+ def testOneConstOutput(self):
+ """Test consisting of a single constant return value."""
+
+ def OneConstOutput():
+ return constant_op.constant([-3, 44, 99])
+
+ self._compare(OneConstOutput, [], require_kernel_launch=False)
+
+ def testConstZeroElementOutput(self):
+ """Test consisting of a constant zero element return value."""
+
+ def ConstZeroElementOutput():
+ return array_ops.fill([7, 0], 3.0)
+
+ self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)
+
+ def testSomeConstOutputs(self):
+ """Test kernels that return a mixture of const and non-const outputs."""
+
+ def SomeConstOutputs(x):
+ return constant_op.constant(
+ [-2, 7]), array_ops.identity(x), constant_op.constant(3.5)
+
+ self._compare(
+ SomeConstOutputs, [np.array(
+ [[1, 2, 3], [4, 5, 6]], dtype=np.float32)])
+
+ def testInt32Input(self):
+ """Test an int32-typed input.
+
+ On a GPU, int32 tensors will be placed in host memory.
+ """
+
+ def AddToSelf(x):
+ return math_ops.add(x, x)
+
+ self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])
+
+ def testMandatoryConstantInput(self):
+ """Tests an operator that has a mandatory-constant shape input."""
+
+ def FillWithFloat(x):
+ return array_ops.fill(x, 9.5)
+
+ self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])
+
+ def testMnistForwardFunc(self):
+ """Compute inference function from MNIST beginners tutorial."""
+ batch_size = 16
+ image_size = 28 * 28
+ num_classes = 10
+
+ # Define a TensorFlow function to compute the forward pass.
+ def MnistForward(w, b, x):
+ return nn_ops.softmax(math_ops.matmul(x, w) + b)
+
+ w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
+    b = np.random.random_sample((num_classes,)).astype(np.float32)
+ x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
+ self._compare(MnistForward, [w, b, x])
+
+ def testExplicitMarking(self):
+ """Test explicit marking of operators to compile."""
+ batch_size = 16
+ image_size = 28 * 28
+ num_classes = 10
+
+ with ops.Graph().as_default():
+ x = array_ops.placeholder(dtypes.float32)
+ w = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ with jit_scope():
+ y1 = math_ops.matmul(x, w)
+ y2 = math_ops.add(y1, b)
+ with jit_scope():
+ y = math_ops.square(y2)
+
+ dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
+      db = np.random.random_sample((num_classes,)).astype(np.float32)
+ dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
+ with session_lib.Session() as sess:
+ run_metadata = config_pb2.RunMetadata()
+ output = sess.run(y, {x: dx,
+ w: dw,
+ b: db},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+
+ # TODO(phawkins): really we would like to test that there were exactly
+ # two kernel launches. However, we have no reliable way to determine
+ # that.
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+
+ expected = np.square(np.dot(dx, dw) + db)
+ self.assertAllClose(expected, output, rtol=1e-1)
+
+
+class XlaCompilationTest(test.TestCase):
+ """Tests for auto-compilation on CPU/GPU devices."""
+
+ def testReshape(self):
+ """Tests an operator with compile-time constant and non-constant inputs."""
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.float32)
+ y = array_ops.placeholder(dtypes.int32)
+ with jit_scope():
+ # Reshape's first argument is non-constant in the JIT, but its second
+ # (shape) argument will be treated as a compile-time constant for
+ # each JIT compilation.
+        # We do not use a tf.constant() argument since we want to ensure the
+ # shape is still a run-time argument to the JIT, and not
+ # statically known as part of the JIT compilation's input graph.
+ z = array_ops.reshape(x, y)
+ run_metadata = config_pb2.RunMetadata()
+ out = sess.run(z,
+ {x: np.array([1, 2, 3, 4, 5, 6], np.float32),
+ y: [-1, 3]},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
+
+ def testIgnoredArguments(self):
+ """Tests that JIT computations can ignore formal parameters."""
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.int32)
+ y = array_ops.placeholder(dtypes.int32)
+ with jit_scope():
+ z = math_ops.add(x, x)
+ w = math_ops.add(y, y)
+ # Pulls 'w' into the same compilation via control dependencies.
+ with ops.control_dependencies([w]):
+ n = control_flow_ops.no_op()
+ with ops.control_dependencies([n]):
+ t = math_ops.add(z, z)
+
+ run_metadata = config_pb2.RunMetadata()
+ out = sess.run(t, {x: np.int32(7),
+ y: np.int32(404)},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assertAllClose(28, out)
+
+ def testLoops(self):
+ """Tests that compilation accepts computations containing loops."""
+
+ with self.test_session() as session:
+ x = array_ops.placeholder(dtypes.float32)
+ with jit_scope():
+ c = lambda i, _: math_ops.less(i, 5)
+ b = lambda i, x: (i + 1, x * 2.0 + 1.0)
+ _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))
+
+ run_metadata = config_pb2.RunMetadata()
+ result = session.run(y, {x: np.float32(2)},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assertAllClose(result, np.float32(95), rtol=1e-1)
+
+ def testCond(self):
+ """Tests that compilation handles switch operators."""
+
+ with self.test_session() as session:
+ x = array_ops.placeholder(dtypes.float32)
+ y = array_ops.placeholder(dtypes.float32)
+ c = array_ops.placeholder(dtypes.bool)
+ with jit_scope():
+ z = x + 1.0
+ w = control_flow_ops.cond(c, lambda: z, lambda: y)
+ t = math_ops.add(z, w)
+
+ # If JIT compilation chooses to cluster z and t, then execution will
+ # deadlock.
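+      # (Presumably because the resulting cluster would both produce z for
+      # the cond and consume the cond's output w to compute t, forming a
+      # cycle.)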
+
+ run_metadata = config_pb2.RunMetadata()
+ result = session.run(t, {x: np.float32(2),
+ y: np.float32(4),
+ c: True},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assertAllClose(result, np.float32(6), rtol=1e-1)
+
+ def testNestedFunction(self):
+ g = ops.Graph()
+ with g.as_default():
+
+ @function.Defun(compiled=True)
+ def Bar(x, y):
+ return x + 2 * y
+
+ @function.Defun(compiled=True)
+ def Foo(x):
+ return Bar(x * x, x * x * x)
+
+ @function.Defun()
+ def Entry(x):
+ return Foo(x)
+
+ inp = array_ops.placeholder(dtypes.float32)
+ out = Entry(inp)
+
+ with self.test_session(graph=g, use_gpu=True) as sess:
+ run_metadata = config_pb2.RunMetadata()
+ val = sess.run(out,
+ feed_dict={inp: [2., 10.]},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assertAllClose(val, [20., 2100.])
+
+ def testLoopDeadlock(self):
+    """Regression test for a bug that caused deadlocks in graphs with loops."""
+
+ with self.test_session() as session:
+ x = array_ops.placeholder(dtypes.float32)
+ with jit_scope():
+ y = x + 1.0
+ c = lambda i, _x, _y: math_ops.less(i, 5)
+ b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
+ _, _, w = control_flow_ops.while_loop(c, b,
+ (constant_op.constant(0), y, x))
+ u = w + y
+ result = session.run(u, {x: np.float32(2)})
+ self.assertAllClose(result, np.float32(63), rtol=1e-1)
+
+ def testGradient(self):
+ """Tests that the backprop function is properly compiled."""
+
+ def _Run(compiled):
+
+ @function.Defun(compiled=compiled)
+ def Forward(x):
+ return math_ops.log(x)
+
+ g = ops.Graph()
+ with g.as_default():
+ x = array_ops.placeholder(dtypes.float32)
+ y = Forward(x)
+ dx, = gradients_impl.gradients(y, [x], 1.0)
+
+ cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L1,
+ do_function_inlining=True)))
+ with session_lib.Session(graph=g, config=cfg) as sess:
+ run_metadata = config_pb2.RunMetadata()
+ dx_val = sess.run(dx,
+ feed_dict={x: 100.},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assertAllClose(dx_val, 0.01)
+ return RunMetadataLabels(run_metadata)
+
+ # SymGrad[f=log(x)](x, dy) = 1/x * dy
+ #
+ # Note: we don't need to compute log(x) for dx due to graph pruning.
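+    #
+    # Sanity check of the expected value (given the feed x = 100 and dy = 1.0
+    # above): dx = dy * 1/x = 1.0 / 100 = 0.01, matching the
+    # assertAllClose(dx_val, 0.01) inside _Run.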
+
+ # Do not compile the backprop. We should see one Reciprocal and one Mul.
+ labels = _Run(compiled=False)
+ self.assertFalse(InLabels(labels, "Log"))
+ self.assertTrue(InLabels(labels, "Reciprocal"))
+ self.assertTrue(InLabels(labels, "Mul"))
+ self.assertFalse(InLabels(labels, "_XlaLaunch"))
+
+ # Compile the backprop. One _XlaLaunch.
+ labels = _Run(compiled=True)
+ self.assertFalse(InLabels(labels, "Log"))
+ self.assertFalse(InLabels(labels, "Reciprocal"))
+ self.assertFalse(InLabels(labels, "Mul"))
+ self.assertTrue(InLabels(labels, "_XlaLaunch"))
+
+
+if __name__ == "__main__":
+ test.main()