author     Eugene Brevdo <ebrevdo@google.com>              2017-02-10 17:07:38 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org> 2017-02-10 17:23:01 -0800
commit     46c8e8bd966286ad3f192948fa64f2e9ee8f0dcb
tree       afe6517e06a82ab66bc8e3179e3001a07df5d6fa /tensorflow/contrib/compiler
parent     f7ceabb5594a5f47d5db3e7d7e8c7fc1ec99be27
Add _XlaScope attribute to jit_scope to avoid fusing separate adjacent fused blocks.
Gradients get their own separate scope based on the scope of the forward op. Provide proper XlaScope for Defuns as well (each Defun gets its own scope; their gradients get their own scope). Also move jit scope gradient unit tests out of core gradients to contrib.compiler. This is just the python side that sets the attribute; the C++ changes will come in a separate CL. Change: 147216860
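As a rough illustration of the attribute values this patch produces (a minimal sketch assuming the TF 1.x graph-mode API and the contrib jit module as changed here; the tensor names are invented for the example):

import tensorflow as tf
from tensorflow.contrib.compiler import jit

with tf.Graph().as_default():
  with jit.experimental_jit_scope():
    a = tf.constant(1.0)      # gets _XlaScope == b"jit_scope_0"
    with jit.experimental_jit_scope():
      b = a * 2.0             # nested scope: inherits b"jit_scope_0"
  with jit.experimental_jit_scope():
    c = tf.constant(2.0)      # sibling scope: gets b"jit_scope_1"

  print(a.op.get_attr("_XlaScope"))  # b"jit_scope_0"
  print(b.op.get_attr("_XlaScope"))  # b"jit_scope_0"
  print(c.op.get_attr("_XlaScope"))  # b"jit_scope_1"

Gradient ops then receive a derived scope such as jit_scope_0_grad_GA (the forward scope plus the name passed to gradients()), and a compiled Defun gets function_<name>, as exercised by the new tests below.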
Diffstat (limited to 'tensorflow/contrib/compiler')
 -rw-r--r--  tensorflow/contrib/compiler/BUILD        |   1
 -rw-r--r--  tensorflow/contrib/compiler/jit.py       |  41
 -rw-r--r--  tensorflow/contrib/compiler/jit_test.py  | 107
 3 files changed, 141 insertions, 8 deletions
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index 2ae33250f2..388d8e6ed6 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -36,6 +36,7 @@ cuda_py_test(
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
diff --git a/tensorflow/contrib/compiler/jit.py b/tensorflow/contrib/compiler/jit.py
index 028b318e70..b4f96ecf9a 100644
--- a/tensorflow/contrib/compiler/jit.py
+++ b/tensorflow/contrib/compiler/jit.py
@@ -24,6 +24,17 @@ from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.framework import ops
 
 
+_XLA_SCOPE_KEY = ("__xla_scope",)
+
+
+class _XlaScope(object):
+  """Keeps track of previous XLA scope calls, and depth of current call."""
+
+  def __init__(self, count, depth):
+    self.count = count
+    self.depth = depth
+
+
 @contextlib.contextmanager
 def experimental_jit_scope(compile_ops=True):
   """Enable or disable JIT compilation of operators within the scope.
@@ -54,18 +65,32 @@ def experimental_jit_scope(compile_ops=True):
       return attr_value_pb2.AttrValue(b=compile_ops(node_def))
   else:
     xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
+
   attrs = {"_XlaCompile": xla_compile}
-  # TODO(ebrevdo): Keep a global XlaScope counter and here create a
-  # special scope that checks if already within a xla scope or creates
-  # a new one with a new scope string. Add a new attr _XlaScope
-  # taking this string. Modify the xla fusion to respect scope
-  # boundaries. Modify gradients_impl to either create a new gradient
-  # scope with a suffix from the fw scope or to try to fuse with
-  # the fw scope of the given op. Should be backwards compatible to
-  # avoid having to modify Defun compilation attributes.
+  # Find the singleton counter for the current scoped graph. If it
+  # doesn't exist, create one.
+  xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY)
+  if not xla_scope_counter:
+    xla_scope_counter = _XlaScope(0, 0)
+    ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter)
+  else:
+    xla_scope_counter = xla_scope_counter[0]
+
+  if xla_scope_counter.depth == 0:
+    # If we're at the root xla scope, we can increase the counter so
+    # future calls to jit_scope use a different scope value.
+    # If we're already within a scope, we'll be fusing using the scope
+    # controlled by the parent.
+    attrs["_XlaScope"] = attr_value_pb2.AttrValue(
+        s=("jit_scope_%d" % xla_scope_counter.count).encode())
+    xla_scope_counter.count += 1
+
+  xla_scope_counter.depth += 1
   # pylint: disable=protected-access
   with ops.get_default_graph()._attr_scope(attrs):
     yield
   # pylint: enable=protected-access
+
+  xla_scope_counter.depth -= 1
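The bookkeeping added above amounts to a per-graph singleton counter kept in a collection. The following stripped-down sketch of the same pattern (plain Python with hypothetical names, independent of TensorFlow, and using try/finally where the patched code simply decrements after the yield) shows how root-level scopes pick fresh names while nested scopes defer to their parent:

import contextlib

# Hypothetical stand-ins for the graph collection used above.
_COLLECTIONS = {}          # maps a graph key to its singleton counter
_SCOPE_KEY = "__xla_scope"


class _Counter(object):
  """Tracks how many root scopes were opened (count) and nesting (depth)."""

  def __init__(self):
    self.count = 0
    self.depth = 0


@contextlib.contextmanager
def scope(graph_key="default"):
  counter = _COLLECTIONS.setdefault((graph_key, _SCOPE_KEY), _Counter())
  scope_name = None
  if counter.depth == 0:
    # Only a root-level scope picks a fresh name; nested scopes reuse
    # the name chosen by their outermost ancestor.
    scope_name = "jit_scope_%d" % counter.count
    counter.count += 1
  counter.depth += 1
  try:
    yield scope_name
  finally:
    counter.depth -= 1


# Usage: sibling root scopes get distinct names; nested ones yield None,
# meaning "inherit the enclosing scope's name".
with scope() as s0:
  with scope() as inner:
    assert inner is None
with scope() as s1:
  pass
assert (s0, s1) == ("jit_scope_0", "jit_scope_1")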
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 9b7b74838f..6aa86c3286 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -29,10 +29,14 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
 # pylint: disable=g-import-not-at-top
 from tensorflow.contrib.compiler import jit
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import function
 from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import gradients
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
@@ -93,6 +97,33 @@ class JITTest(test.TestCase):
       self.assertAllClose(v_true_1, v_true_2)
       self.assertAllClose(v_false_1, v_true_1)
 
+  def testJITXlaScope(self):
+    with self.test_session(graph=ops.Graph()):
+      with jit.experimental_jit_scope(True):
+        # XlaScope 0
+        a1 = constant_op.constant(1)
+      with jit.experimental_jit_scope(True):
+        # XlaScope 1
+        a2 = constant_op.constant(1)
+        with jit.experimental_jit_scope(True):
+          # XlaScope still 1, depth 1
+          a3 = constant_op.constant(1)
+          with jit.experimental_jit_scope(True):
+            # XlaScope still 1, depth 2
+            a4 = constant_op.constant(1)
+          # XlaScope still 1, depth 1
+          a5 = constant_op.constant(1)
+      with jit.experimental_jit_scope(True):
+        # XlaScope now 2, depth 0
+        a6 = constant_op.constant(1)
+
+      self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
+      self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
+      self.assertEqual(b"jit_scope_1", a3.op.get_attr("_XlaScope"))
+      self.assertEqual(b"jit_scope_1", a4.op.get_attr("_XlaScope"))
+      self.assertEqual(b"jit_scope_1", a5.op.get_attr("_XlaScope"))
+      self.assertEqual(b"jit_scope_2", a6.op.get_attr("_XlaScope"))
+
   def testJITVariableSeed(self):
     """Test that the stateful initializer is not marked for compilation.
@@ -117,5 +148,81 @@ class JITTest(test.TestCase):
       self.assertAllClose(v_false_1, v_true_1)
 
 
+class CompilationEnabledInGradientTest(test.TestCase):
+
+  def testCompilationInGradient(self):
+    with self.test_session():
+      x = constant_op.constant(3)
+      y_nc = math_ops.add(x, x, name="not_compiled")
+      with jit.experimental_jit_scope():
+        y_c = math_ops.add(y_nc, y_nc, name="compiled")
+      x_grads = gradients.gradients([y_c], [x])[0]
+      operations = x_grads.graph.get_operations()
+      c_grad_ops = [
+          op for op in operations if "gradients/compiled" in op.name]
+      nc_grad_ops = [
+          op for op in operations if "gradients/not_compiled" in op.name]
+      self.assertGreater(len(c_grad_ops), 0)
+      self.assertGreater(len(nc_grad_ops), 0)
+      for cg in c_grad_ops:
+        self.assertEqual(True, cg.get_attr("_XlaCompile"))
+      for ncg in nc_grad_ops:
+        with self.assertRaisesRegexp(ValueError, "No attr named"):
+          ncg.get_attr("_XlaCompile")
+
+      # d/dx (4 * x)
+      self.assertAllClose(4, x_grads.eval())
+
+  def testCompilationGradientScopeNames(self):
+    with self.test_session(graph=ops.Graph()):
+      with jit.experimental_jit_scope(True):
+        # XlaScope 0
+        a1 = constant_op.constant(1)
+        a1t = a1 + a1
+      with jit.experimental_jit_scope(True):
+        # XlaScope 1
+        a2 = constant_op.constant(1)
+        a2t = a2 + a2
+
+      self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
+      self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
+      grad_a1 = gradients.gradients(a1t, a1, name="GA")[0]
+      grad_a2 = gradients.gradients(a2t, a2, name="GB")[0]
+      grad_a1 = grad_a1.op.inputs[0]
+      grad_a2 = grad_a2.op.inputs[0]
+      self.assertEqual(True, grad_a1.op.get_attr("_XlaCompile"))
+      self.assertEqual(True, grad_a2.op.get_attr("_XlaCompile"))
+      self.assertEqual(b"jit_scope_0_grad_GA",
+                       grad_a1.op.get_attr("_XlaScope"))
+      self.assertEqual(b"jit_scope_1_grad_GB",
+                       grad_a2.op.get_attr("_XlaScope"))
+
+  def testPlaysNicelyWithDefun(self):
+    with self.test_session(graph=ops.Graph()) as sess:
+      with jit.experimental_jit_scope(True):  # This should be ignored
+        @function.Defun(compiled=True, noinline=True)
+        def mulop(x1, x2):
+          return x1 * x2
+        x = constant_op.constant(1.0)
+        r = mulop(x, x)
+        g_r = gradients.gradients(r, x, name="GA")[0]
+
+      # Ensure the forward function is compiled
+      graph_def = r.graph.as_graph_def()
+      func_attrs = graph_def.library.function[0].attr
+      self.assertTrue(func_attrs["_XlaCompile"].b)
+      self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)
+
+      # Ensure the gradient (SymbolicGradient) is compiled
+      grad_op = g_r.op.inputs[0].op
+      self.assertTrue(grad_op.get_attr("_XlaCompile"))
+      self.assertEqual(b"function_mulop_grad_GA",
+                       grad_op.get_attr("_XlaScope"))
+
+      # Ensure the ops run
+      # grad(x1*x1) = 2*x1
+      self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))
+
+
 if __name__ == "__main__":
   test.main()