diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-02-10 17:07:38 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-10 17:23:01 -0800 |
commit | 46c8e8bd966286ad3f192948fa64f2e9ee8f0dcb (patch) | |
tree | afe6517e06a82ab66bc8e3179e3001a07df5d6fa /tensorflow/contrib/compiler | |
parent | f7ceabb5594a5f47d5db3e7d7e8c7fc1ec99be27 (diff) |
Add _XlaScope attribute to jit_scope to avoid fusing separate adjacent fused blocks.
Gradients get their own separate scope based on the scope of the forward op.
Provide proper XlaScope for Defuns as well (each Defun gets its own scope; their gradients get their own scope).
Also move jit scope gradient unit tests out of core gradients to contrib.compiler.
This is just the Python side that sets the attribute; the C++ changes will come
in a separate CL.
Change: 147216860
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r-- | tensorflow/contrib/compiler/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/compiler/jit.py | 41 | ||||
-rw-r--r-- | tensorflow/contrib/compiler/jit_test.py | 107 |
3 files changed, 141 insertions, 8 deletions
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 2ae33250f2..388d8e6ed6 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -36,6 +36,7 @@ cuda_py_test( "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", diff --git a/tensorflow/contrib/compiler/jit.py b/tensorflow/contrib/compiler/jit.py index 028b318e70..b4f96ecf9a 100644 --- a/tensorflow/contrib/compiler/jit.py +++ b/tensorflow/contrib/compiler/jit.py @@ -24,6 +24,17 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.framework import ops +_XLA_SCOPE_KEY = ("__xla_scope",) + + +class _XlaScope(object): + """Keeps track of previous XLA scope calls, and depth of current call.""" + + def __init__(self, count, depth): + self.count = count + self.depth = depth + + @contextlib.contextmanager def experimental_jit_scope(compile_ops=True): """Enable or disable JIT compilation of operators within the scope. @@ -54,18 +65,32 @@ def experimental_jit_scope(compile_ops=True): return attr_value_pb2.AttrValue(b=compile_ops(node_def)) else: xla_compile = attr_value_pb2.AttrValue(b=compile_ops) + attrs = {"_XlaCompile": xla_compile} - # TODO(ebrevdo): Keep a global XlaScope counter and here create a - # special scope that checks if already within a xla scope or creates - # a new one with a new scope string. Add a new attr _XlaScope - # taking this string. Modify the xla fusion to respect scope - # boundaries. Modify gradients_impl to either create a new gradient - # scope with a suffix from the fw scope or to try to fuse with - # the fw scope of the given op. Should be backwards compatible to - # avoid having to modify Defun compilation attributes. + # Find the singleton counter for the current scoped graph. 
If it + # doesn't exist, create one. + xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY) + if not xla_scope_counter: + xla_scope_counter = _XlaScope(0, 0) + ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter) + else: + xla_scope_counter = xla_scope_counter[0] + + if xla_scope_counter.depth == 0: + # If we're at the root xla scope, we can increase the counter so + # future calls to jit_scope use a different scope value. + # If we're already within a scope, we'll be fusing using the scope + # controlled by the parent. + attrs["_XlaScope"] = attr_value_pb2.AttrValue( + s=("jit_scope_%d" % xla_scope_counter.count).encode()) + xla_scope_counter.count += 1 + + xla_scope_counter.depth += 1 # pylint: disable=protected-access with ops.get_default_graph()._attr_scope(attrs): yield # pylint: enable=protected-access + + xla_scope_counter.depth -= 1 diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 9b7b74838f..6aa86c3286 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -29,10 +29,14 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): # pylint: disable=g-import-not-at-top from tensorflow.contrib.compiler import jit +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -93,6 +97,33 @@ class JITTest(test.TestCase): self.assertAllClose(v_true_1, v_true_2) self.assertAllClose(v_false_1, v_true_1) + def testJITXlaScope(self): + with self.test_session(graph=ops.Graph()): + with 
jit.experimental_jit_scope(True): + # XlaScope 0 + a1 = constant_op.constant(1) + with jit.experimental_jit_scope(True): + # XlaScope 1 + a2 = constant_op.constant(1) + with jit.experimental_jit_scope(True): + # XlaScope still 1, depth 1 + a3 = constant_op.constant(1) + with jit.experimental_jit_scope(True): + # XlaScope still 1, depth 2 + a4 = constant_op.constant(1) + # XlaScope still 1, depth 1 + a5 = constant_op.constant(1) + with jit.experimental_jit_scope(True): + # XlaScope now 2, depth 0 + a6 = constant_op.constant(1) + + self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) + self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) + self.assertEqual(b"jit_scope_1", a3.op.get_attr("_XlaScope")) + self.assertEqual(b"jit_scope_1", a4.op.get_attr("_XlaScope")) + self.assertEqual(b"jit_scope_1", a5.op.get_attr("_XlaScope")) + self.assertEqual(b"jit_scope_2", a6.op.get_attr("_XlaScope")) + def testJITVariableSeed(self): """Test that the stateful initializer is not marked for compilation. 
@@ -117,5 +148,81 @@ class JITTest(test.TestCase): self.assertAllClose(v_false_1, v_true_1) +class CompilationEnabledInGradientTest(test.TestCase): + + def testCompilationInGradient(self): + with self.test_session(): + x = constant_op.constant(3) + y_nc = math_ops.add(x, x, name="not_compiled") + with jit.experimental_jit_scope(): + y_c = math_ops.add(y_nc, y_nc, name="compiled") + x_grads = gradients.gradients([y_c], [x])[0] + operations = x_grads.graph.get_operations() + c_grad_ops = [ + op for op in operations if "gradients/compiled" in op.name] + nc_grad_ops = [ + op for op in operations if "gradients/not_compiled" in op.name] + self.assertGreater(len(c_grad_ops), 0) + self.assertGreater(len(nc_grad_ops), 0) + for cg in c_grad_ops: + self.assertEqual(True, cg.get_attr("_XlaCompile")) + for ncg in nc_grad_ops: + with self.assertRaisesRegexp(ValueError, "No attr named"): + ncg.get_attr("_XlaCompile") + + # d/dx (4 * x) + self.assertAllClose(4, x_grads.eval()) + + def testCompilationGradientScopeNames(self): + with self.test_session(graph=ops.Graph()): + with jit.experimental_jit_scope(True): + # XlaScope 0 + a1 = constant_op.constant(1) + a1t = a1 + a1 + with jit.experimental_jit_scope(True): + # XlaScope 1 + a2 = constant_op.constant(1) + a2t = a2 + a2 + + self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) + self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) + grad_a1 = gradients.gradients(a1t, a1, name="GA")[0] + grad_a2 = gradients.gradients(a2t, a2, name="GB")[0] + grad_a1 = grad_a1.op.inputs[0] + grad_a2 = grad_a2.op.inputs[0] + self.assertEqual(True, grad_a1.op.get_attr("_XlaCompile")) + self.assertEqual(True, grad_a2.op.get_attr("_XlaCompile")) + self.assertEqual(b"jit_scope_0_grad_GA", + grad_a1.op.get_attr("_XlaScope")) + self.assertEqual(b"jit_scope_1_grad_GB", + grad_a2.op.get_attr("_XlaScope")) + + def testPlaysNicelyWithDefun(self): + with self.test_session(graph=ops.Graph()) as sess: + with jit.experimental_jit_scope(True): # 
This should be ignored + @function.Defun(compiled=True, noinline=True) + def mulop(x1, x2): + return x1 * x2 + x = constant_op.constant(1.0) + r = mulop(x, x) + g_r = gradients.gradients(r, x, name="GA")[0] + + # Ensure the forward function is compiled + graph_def = r.graph.as_graph_def() + func_attrs = graph_def.library.function[0].attr + self.assertTrue(func_attrs["_XlaCompile"].b) + self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s) + + # Ensure the gradient (SymbolicGradient) is compiled + grad_op = g_r.op.inputs[0].op + self.assertTrue(grad_op.get_attr("_XlaCompile")) + self.assertEqual(b"function_mulop_grad_GA", + grad_op.get_attr("_XlaScope")) + + # Ensure the ops run + # grad(x1*x1) = 2*x1 + self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r])) + + if __name__ == "__main__": test.main() |