author    Eugene Brevdo <ebrevdo@google.com>  2017-01-30 13:42:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-01-30 14:04:51 -0800
commit    696dad03880e9b3367b6c4b4c3903d6aa723d7e5 (patch)
tree      84486bcd7e4d2bb46b2bac24e62a48be28da21d4 /tensorflow/contrib/compiler
parent    44642329df6bc1627c54f92c5f6850e5882da991 (diff)
Allow a callable in the hidden attr_scope; allow a callable in experimental_jit_scope that determines whether an op should be compiled.
Add unit tests of the jit scope checking that compilation can be controlled using this new mechanism.
Change: 146034274
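A minimal usage sketch of the callable form introduced by this change (the tensors `a` and `b` are illustrative placeholders, not part of the patch; assumes the contrib layout as of this commit):

    import tensorflow as tf
    from tensorflow.contrib.compiler import jit

    a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    b = tf.constant([[1.0, 0.0], [0.0, 1.0]])

    # The callable receives each op's NodeDef as the op is created and
    # decides, per op, whether to mark it for XLA compilation.
    with jit.experimental_jit_scope(
        compile_ops=lambda node_def: node_def.op == "MatMul"):
      c = tf.matmul(a, b)  # a MatMul op: created with _XlaCompile=True
      d = c + 1.0          # an Add op: created with _XlaCompile=False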
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r--  tensorflow/contrib/compiler/BUILD         25
-rw-r--r--  tensorflow/contrib/compiler/jit.py        30
-rw-r--r--  tensorflow/contrib/compiler/jit_test.py  121
3 files changed, 171 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index bc3b60ef28..2ae33250f2 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -8,6 +8,8 @@ package_group(
    packages = ["//tensorflow/..."],
)
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
py_library(
    name = "compiler_py",
    srcs = [
@@ -21,6 +23,29 @@ py_library(
    ],
)
+cuda_py_test(
+    name = "jit_test",
+    size = "small",
+    srcs = ["jit_test.py"],
+    additional_deps = [
+        ":compiler_py",
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:random_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+    ],
+    xla_enabled = True,
+)
+
filegroup(
    name = "all_files",
    srcs = glob(
diff --git a/tensorflow/contrib/compiler/jit.py b/tensorflow/contrib/compiler/jit.py
index f13c1af885..028b318e70 100644
--- a/tensorflow/contrib/compiler/jit.py
+++ b/tensorflow/contrib/compiler/jit.py
@@ -33,18 +33,38 @@ def experimental_jit_scope(compile_ops=True):
  The compilation is a hint and only supported on a best-effort basis.

  Example usage:
-    with tf.contrib.framework.experimental_jit_scope():
+    with tf.contrib.compiler.experimental_jit_scope():
       c = tf.matmul(a, b)  # compiled
-    with tf.contrib.framework.experimental_jit_scope(compile_ops=False):
-      d = tf.matmul(a, c) # not compiled
+    with tf.contrib.compiler.experimental_jit_scope(compile_ops=False):
+      d = tf.matmul(a, c)  # not compiled
+    with tf.contrib.compiler.experimental_jit_scope(
+        compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
+      e = tf.matmul(a, b) + d  # matmul is compiled, the addition is not.

  Args:
-    compile_ops: boolean, whether to enable or disable compilation in the scope.
+    compile_ops: Whether to enable or disable compilation in the scope.
+      Either a Python bool, or a callable that accepts the parameter
+      `node_def` and returns a python bool.

  Yields:
    The current scope, enabling or disabling compilation.
  """
-  attrs = {"_XlaCompile": attr_value_pb2.AttrValue(b=compile_ops)}
+  if callable(compile_ops):
+    def xla_compile(node_def):
+      return attr_value_pb2.AttrValue(b=compile_ops(node_def))
+  else:
+    xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
+  attrs = {"_XlaCompile": xla_compile}
+
+  # TODO(ebrevdo): Keep a global XlaScope counter and here create a
+  # special scope that checks if already within a xla scope or creates
+  # a new one with a new scope string. Add a new attr _XlaScope
+  # taking this string. Modify the xla fusion to respect scope
+  # boundaries. Modify gradients_impl to either create a new gradient
+  # scope with a suffix from the fw scope or to try to fuse with
+  # the fw scope of the given op. Should be backwards compatible to
+  # avoid having to modify Defun compilation attributes.
+
  # pylint: disable=protected-access
  with ops.get_default_graph()._attr_scope(attrs):
    yield
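To see that the callable is evaluated per NodeDef and surfaces as the per-op `_XlaCompile` attr, a short self-contained sketch (a hypothetical check, not part of the patch; it mirrors what jit_test.py below verifies):

    import tensorflow as tf
    from tensorflow.contrib.compiler import jit

    with tf.Graph().as_default():
      with jit.experimental_jit_scope(
          compile_ops=lambda node_def: "matmul" in node_def.op.lower()):
        x = tf.ones((2, 2))   # a Const op: the predicate returns False
        y = tf.matmul(x, x)   # a MatMul op: the predicate returns True
      print(x.op.get_attr("_XlaCompile"))  # False
      print(y.op.get_attr("_XlaCompile"))  # True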
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
new file mode 100644
index 0000000000..9b7b74838f
--- /dev/null
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -0,0 +1,121 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for contrib.compiler.jit."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+
+# TODO(keveman): #6568 Remove this hack that makes dlopen() not crash.
+if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
+  import ctypes  # pylint: disable=g-import-not-at-top
+  sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
+
+
+# pylint: disable=g-import-not-at-top
+from tensorflow.contrib.compiler import jit
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+# pylint: enable=g-import-not-at-top
+
+
+_REGISTERED_OPS = op_def_registry.get_registered_ops()
+
+
+def enable_jit_nonstateful(node_def):
+  try:
+    return not _REGISTERED_OPS[node_def.op].is_stateful
+  except KeyError:
+    raise ValueError("Unregistered op being created: %s" % node_def)
+
+
+class JITTest(test.TestCase):
+
+  def compute(self, use_jit, compute_fn):
+    random_seed.set_random_seed(1234)
+    with self.test_session(graph=ops.Graph()) as sess:
+      with jit.experimental_jit_scope(use_jit):
+        r = compute_fn()
+        sess.run(variables.global_variables_initializer())
+        return (r, sess.run(r))
+
+  def testJITCreateOpsLambda(self):
+    """Test several ways of customizing the compilation attribute."""
+    def create_ops():
+      with variable_scope.variable_scope(
+          "root",
+          initializer=init_ops.random_uniform_initializer(
+              -0.1, 0.1, seed=2)):
+        inputs = random_ops.random_uniform((1,), seed=1)
+      return inputs
+    v_false_1_t, v_false_1 = self.compute(False, create_ops)
+    _, v_false_2 = self.compute(False, create_ops)
+    v_true_1_t, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
+    _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
+    v_all_true_t, _ = self.compute(True, create_ops)
+    self.assertEqual(False, v_false_1_t.op.get_attr("_XlaCompile"))
+    v_true_1_t_sampler_op = v_true_1_t.graph.get_operation_by_name(
+        "root/random_uniform/RandomUniform")
+    v_all_true_t_sampler_op = v_all_true_t.graph.get_operation_by_name(
+        "root/random_uniform/RandomUniform")
+
+    self.assertEqual(False, v_true_1_t_sampler_op.get_attr("_XlaCompile"))
+    self.assertEqual(True, v_all_true_t_sampler_op.get_attr("_XlaCompile"))
+
+    self.assertEqual(True, v_true_1_t.op.get_attr("_XlaCompile"))
+    self.assertEqual(True, v_all_true_t.op.get_attr("_XlaCompile"))
+
+    # Additionally ensure that where no JIT compilation happens on the
+    # random_uniform op, the output values are identical to the case
+    # where no JIT compilation happens anywhere.
+    self.assertAllClose(v_false_1, v_false_2)
+    self.assertAllClose(v_true_1, v_true_2)
+    self.assertAllClose(v_false_1, v_true_1)
+
+  def testJITVariableSeed(self):
+    """Test that the stateful initializer is not marked for compilation.
+
+    XLA does not currently support seeded initialization and XLA initializers
+    therefore return different values than non-XLA counterparts. Here we
+    ensure that we can disable JIT compilation for the initializers and
+    get the same variable values as if no JIT compilation happened.
+    """
+    def create_ops():
+      with variable_scope.variable_scope(
+          "root",
+          initializer=init_ops.random_uniform_initializer(
+              -0.1, 0.1, seed=2)):
+        inputs = variable_scope.get_variable("var", (1,))
+      return inputs
+    _, v_false_1 = self.compute(False, create_ops)
+    _, v_false_2 = self.compute(False, create_ops)
+    _, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
+    _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
+    self.assertAllClose(v_false_1, v_false_2)
+    self.assertAllClose(v_true_1, v_true_2)
+    self.assertAllClose(v_false_1, v_true_1)
+
+
+if __name__ == "__main__":
+  test.main()