diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-01-30 13:42:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-30 14:04:51 -0800 |
commit | 696dad03880e9b3367b6c4b4c3903d6aa723d7e5 (patch) | |
tree | 84486bcd7e4d2bb46b2bac24e62a48be28da21d4 /tensorflow/contrib/compiler | |
parent | 44642329df6bc1627c54f92c5f6850e5882da991 (diff) |
Allow callable in hidden attr_scope; allow callable in experimental_jit_scope that can determine when an op should be compiled.
Add some unit tests of jit scope that check that compilation can be controlled using this new mechanism.
Change: 146034274
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r-- | tensorflow/contrib/compiler/BUILD | 25 | ||||
-rw-r--r-- | tensorflow/contrib/compiler/jit.py | 30 | ||||
-rw-r--r-- | tensorflow/contrib/compiler/jit_test.py | 121 |
3 files changed, 171 insertions, 5 deletions
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index bc3b60ef28..2ae33250f2 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -8,6 +8,8 @@ package_group( packages = ["//tensorflow/..."], ) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + py_library( name = "compiler_py", srcs = [ @@ -21,6 +23,29 @@ py_library( ], ) +cuda_py_test( + name = "jit_test", + size = "small", + srcs = ["jit_test.py"], + additional_deps = [ + ":compiler_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], + xla_enabled = True, +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/compiler/jit.py b/tensorflow/contrib/compiler/jit.py index f13c1af885..028b318e70 100644 --- a/tensorflow/contrib/compiler/jit.py +++ b/tensorflow/contrib/compiler/jit.py @@ -33,18 +33,38 @@ def experimental_jit_scope(compile_ops=True): The compilation is a hint and only supported on a best-effort basis. Example usage: - with tf.contrib.framework.experimental_jit_scope(): + with tf.contrib.compiler.experimental_jit_scope(): c = tf.matmul(a, b) # compiled - with tf.contrib.framework.experimental_jit_scope(compile_ops=False): - d = tf.matmul(a, c) # not compiled + with tf.contrib.compiler.experimental_jit_scope(compile_ops=False): + d = tf.matmul(a, c) # not compiled + with tf.contrib.compiler.experimental_jit_scope( + compile_ops=lambda node_def: 'matmul' in node_def.op.lower()): + e = tf.matmul(a, b) + d # matmul is compiled, the addition is not. Args: - compile_ops: boolean, whether to enable or disable compilation in the scope. + compile_ops: Whether to enable or disable compilation in the scope. + Either a Python bool, or a callable that accepts the parameter + `node_def` and returns a python bool. Yields: The current scope, enabling or disabling compilation. """ - attrs = {"_XlaCompile": attr_value_pb2.AttrValue(b=compile_ops)} + if callable(compile_ops): + def xla_compile(node_def): + return attr_value_pb2.AttrValue(b=compile_ops(node_def)) + else: + xla_compile = attr_value_pb2.AttrValue(b=compile_ops) + attrs = {"_XlaCompile": xla_compile} + + # TODO(ebrevdo): Keep a global XlaScope counter and here create a + # special scope that checks if already within a xla scope or creates + # a new one with a new scope string. Add a new attr _XlaScope + # taking this string. Modify the xla fusion to respect scope + # boundaries. Modify gradients_impl to either create a new gradient + # scope with a suffix from the fw scope or to try to fuse with + # the fw scope of the given op. Should be backwards compatible to + # avoid having to modify Defun compilation attributes. + # pylint: disable=protected-access with ops.get_default_graph()._attr_scope(attrs): yield diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py new file mode 100644 index 0000000000..9b7b74838f --- /dev/null +++ b/tensorflow/contrib/compiler/jit_test.py @@ -0,0 +1,121 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.compiler.jit.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + + +# TODO(keveman): #6568 Remove this hack that makes dlopen() not crash. +if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): + import ctypes # pylint: disable=g-import-not-at-top + sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) + + +# pylint: disable=g-import-not-at-top +from tensorflow.contrib.compiler import jit +from tensorflow.python.framework import op_def_registry +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +# pylint: enable=g-import-not-at-top + + +_REGISTERED_OPS = op_def_registry.get_registered_ops() + + +def enable_jit_nonstateful(node_def): + try: + return not _REGISTERED_OPS[node_def.op].is_stateful + except KeyError: + raise ValueError("Unregistered op being created: %s" % node_def) + + +class JITTest(test.TestCase): + + def compute(self, use_jit, compute_fn): + random_seed.set_random_seed(1234) + with self.test_session(graph=ops.Graph()) as sess: + with jit.experimental_jit_scope(use_jit): + r = compute_fn() + sess.run(variables.global_variables_initializer()) + return (r, sess.run(r)) + + def testJITCreateOpsLambda(self): + """Test several ways of customizing the compilation attribute.""" + def create_ops(): + with variable_scope.variable_scope( + "root", + initializer=init_ops.random_uniform_initializer( + -0.1, 0.1, seed=2)): + inputs = random_ops.random_uniform((1,), seed=1) + return inputs + v_false_1_t, v_false_1 = self.compute(False, create_ops) + _, v_false_2 = self.compute(False, create_ops) + v_true_1_t, v_true_1 = self.compute(enable_jit_nonstateful, create_ops) + _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops) + v_all_true_t, _ = self.compute(True, create_ops) + self.assertEqual(False, v_false_1_t.op.get_attr("_XlaCompile")) + v_true_1_t_sampler_op = v_true_1_t.graph.get_operation_by_name( + "root/random_uniform/RandomUniform") + v_all_true_t_sampler_op = v_all_true_t.graph.get_operation_by_name( + "root/random_uniform/RandomUniform") + + self.assertEqual(False, v_true_1_t_sampler_op.get_attr("_XlaCompile")) + self.assertEqual(True, v_all_true_t_sampler_op.get_attr("_XlaCompile")) + + self.assertEqual(True, v_true_1_t.op.get_attr("_XlaCompile")) + self.assertEqual(True, v_all_true_t.op.get_attr("_XlaCompile")) + + # Additionally ensure that where no JIT compilation happens on the + # random_uniform op, the output values are identical to the case + # where no JIT compilation happens anywhere. + self.assertAllClose(v_false_1, v_false_2) + self.assertAllClose(v_true_1, v_true_2) + self.assertAllClose(v_false_1, v_true_1) + + def testJITVariableSeed(self): + """Test that the stateful initializer is not marked for compilation. + + XLA does not currently support seeded initialization and XLA initializers + therefore return different values than non-XLA counterparts. Here + we ensure that if we can disable JIT compilation for the initializers and + get the same variable values as if no JIT compilation happened. + """ + def create_ops(): + with variable_scope.variable_scope( + "root", + initializer=init_ops.random_uniform_initializer( + -0.1, 0.1, seed=2)): + inputs = variable_scope.get_variable("var", (1,)) + return inputs + _, v_false_1 = self.compute(False, create_ops) + _, v_false_2 = self.compute(False, create_ops) + _, v_true_1 = self.compute(enable_jit_nonstateful, create_ops) + _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops) + self.assertAllClose(v_false_1, v_false_2) + self.assertAllClose(v_true_1, v_true_2) + self.assertAllClose(v_false_1, v_true_1) + + +if __name__ == "__main__": + test.main() |