path: root/tensorflow/contrib/compiler
author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-03-13 11:19:01 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-03-13 12:32:23 -0700
commit     a75b6df69e1b6965fcbac5df68e89dc3cbe9931e
tree       4b45c29d17a3da62c2410fa162a4c2654ff1e02a /tensorflow/contrib/compiler
parent     26c5a01eaff4d1960e675129ea89d56ce0d11e90
[TF:XLA] Add separate_compiled_gradients to control gradient scopes.
Change: 149973410
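
For context, a minimal sketch of how the new flag is intended to be used from graph-mode Python. The tensors `a` and `b` and the surrounding setup are illustrative, not part of this change, and assume a TF 1.x build with contrib available:

    import tensorflow as tf
    from tensorflow.contrib.compiler import jit

    a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    b = tf.constant([[5.0, 6.0], [7.0, 8.0]])

    # With separate_compiled_gradients=True, the forward matmul and each
    # tf.gradients() call get their own _XlaScope, so XLA clusters them
    # independently rather than as one unit.
    with jit.experimental_jit_scope(compile_ops=True,
                                    separate_compiled_gradients=True):
      f = tf.matmul(a, b)
      g = tf.gradients([f], [a, b], name="mygrads1")
      h = tf.gradients([f], [a, b], name="mygrads2")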
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r--  tensorflow/contrib/compiler/jit.py      | 26
-rw-r--r--  tensorflow/contrib/compiler/jit_test.py | 78
2 files changed, 88 insertions(+), 16 deletions(-)
diff --git a/tensorflow/contrib/compiler/jit.py b/tensorflow/contrib/compiler/jit.py
index b4f96ecf9a..c516ab658d 100644
--- a/tensorflow/contrib/compiler/jit.py
+++ b/tensorflow/contrib/compiler/jit.py
@@ -36,7 +36,7 @@ class _XlaScope(object):
@contextlib.contextmanager
-def experimental_jit_scope(compile_ops=True):
+def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
"""Enable or disable JIT compilation of operators within the scope.
NOTE: This is an experimental feature.
@@ -52,10 +52,27 @@ def experimental_jit_scope(compile_ops=True):
compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
e = tf.matmul(a, b) + d # matmul is compiled, the addition is not.
+ Example of separate_compiled_gradients:
+ # In the example below, the computations for f, g and h will all be compiled
+ # in separate scopes.
+ with tf.contrib.compiler.experimental_jit_scope(
+ separate_compiled_gradients=True):
+ f = tf.matmul(a, b)
+ g = tf.gradients([f], [a, b], name='mygrads1')
+ h = tf.gradients([f], [a, b], name='mygrads2')
+
Args:
compile_ops: Whether to enable or disable compilation in the scope.
Either a Python bool, or a callable that accepts the parameter
`node_def` and returns a python bool.
+ separate_compiled_gradients: If true put each gradient subgraph into a
+ separate compilation scope. This gives fine-grained control over which
+ portions of the graph will be compiled as a single unit. Compiling
+ gradients separately may yield better performance for some graphs.
+ The scope is named based on the scope of the forward computation as well
+ as the name of the gradients. As a result, the gradients will be compiled
+ in a scope that is separate from both the forward computation, and from
+ other gradients.
Yields:
The current scope, enabling or disabling compilation.
@@ -66,7 +83,12 @@ def experimental_jit_scope(compile_ops=True):
else:
xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
- attrs = {"_XlaCompile": xla_compile}
+ attrs = {
+ "_XlaCompile":
+ xla_compile,
+ "_XlaSeparateCompiledGradients":
+ attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
+ }
# Find the singleton counter for the current scoped graph. If it
# doesn't exist, create one.
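
The two attributes written above can be checked directly on ops created under the scope. A small sketch (graph-mode, TF 1.x contrib assumed; the names `x` and `y` are illustrative):

    import tensorflow as tf
    from tensorflow.contrib.compiler import jit

    with tf.Graph().as_default():
      with jit.experimental_jit_scope(compile_ops=True,
                                      separate_compiled_gradients=True):
        x = tf.constant(1.0)
        y = x + x
      # Both attributes set by the scope are now present on the op.
      print(y.op.get_attr("_XlaCompile"))                    # True
      print(y.op.get_attr("_XlaSeparateCompiledGradients"))  # True
      print(y.op.get_attr("_XlaScope"))                      # b'jit_scope_0'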
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 6aa86c3286..2130f32f85 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -78,17 +78,17 @@ class JITTest(test.TestCase):
v_true_1_t, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
_, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
v_all_true_t, _ = self.compute(True, create_ops)
- self.assertEqual(False, v_false_1_t.op.get_attr("_XlaCompile"))
+ self.assertFalse(v_false_1_t.op.get_attr("_XlaCompile"))
v_true_1_t_sampler_op = v_true_1_t.graph.get_operation_by_name(
"root/random_uniform/RandomUniform")
v_all_true_t_sampler_op = v_all_true_t.graph.get_operation_by_name(
"root/random_uniform/RandomUniform")
- self.assertEqual(False, v_true_1_t_sampler_op.get_attr("_XlaCompile"))
- self.assertEqual(True, v_all_true_t_sampler_op.get_attr("_XlaCompile"))
+ self.assertFalse(v_true_1_t_sampler_op.get_attr("_XlaCompile"))
+ self.assertTrue(v_all_true_t_sampler_op.get_attr("_XlaCompile"))
- self.assertEqual(True, v_true_1_t.op.get_attr("_XlaCompile"))
- self.assertEqual(True, v_all_true_t.op.get_attr("_XlaCompile"))
+ self.assertTrue(v_true_1_t.op.get_attr("_XlaCompile"))
+ self.assertTrue(v_all_true_t.op.get_attr("_XlaCompile"))
# Additionally ensure that where no JIT compilation happens on the
# random_uniform op, the output values are identical to the case
@@ -165,7 +165,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
self.assertGreater(len(c_grad_ops), 0)
self.assertGreater(len(nc_grad_ops), 0)
for cg in c_grad_ops:
- self.assertEqual(True, cg.get_attr("_XlaCompile"))
+ self.assertTrue(cg.get_attr("_XlaCompile"))
for ncg in nc_grad_ops:
with self.assertRaisesRegexp(ValueError, "No attr named"):
ncg.get_attr("_XlaCompile")
@@ -175,11 +175,33 @@ class CompilationEnabledInGradientTest(test.TestCase):
def testCompilationGradientScopeNames(self):
with self.test_session(graph=ops.Graph()):
- with jit.experimental_jit_scope(True):
+ with jit.experimental_jit_scope():
# XlaScope 0
a1 = constant_op.constant(1)
a1t = a1 + a1
- with jit.experimental_jit_scope(True):
+ with jit.experimental_jit_scope():
+ # XlaScope 1
+ a2 = constant_op.constant(1)
+ a2t = a2 + a2
+
+ self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
+ self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
+ grad_a1 = gradients.gradients(a1t, a1, name="GA")[0]
+ grad_a2 = gradients.gradients(a2t, a2, name="GB")[0]
+ grad_a1 = grad_a1.op.inputs[0]
+ grad_a2 = grad_a2.op.inputs[0]
+ self.assertTrue(grad_a1.op.get_attr("_XlaCompile"))
+ self.assertTrue(grad_a2.op.get_attr("_XlaCompile"))
+ self.assertEqual(b"jit_scope_0", grad_a1.op.get_attr("_XlaScope"))
+ self.assertEqual(b"jit_scope_1", grad_a2.op.get_attr("_XlaScope"))
+
+ def testCompilationSeparateGradientScopeNames(self):
+ with self.test_session(graph=ops.Graph()):
+ with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
+ # XlaScope 0
+ a1 = constant_op.constant(1)
+ a1t = a1 + a1
+ with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
# XlaScope 1
a2 = constant_op.constant(1)
a2t = a2 + a2
@@ -190,8 +212,8 @@ class CompilationEnabledInGradientTest(test.TestCase):
grad_a2 = gradients.gradients(a2t, a2, name="GB")[0]
grad_a1 = grad_a1.op.inputs[0]
grad_a2 = grad_a2.op.inputs[0]
- self.assertEqual(True, grad_a1.op.get_attr("_XlaCompile"))
- self.assertEqual(True, grad_a2.op.get_attr("_XlaCompile"))
+ self.assertTrue(grad_a1.op.get_attr("_XlaCompile"))
+ self.assertTrue(grad_a2.op.get_attr("_XlaCompile"))
self.assertEqual(b"jit_scope_0_grad_GA",
grad_a1.op.get_attr("_XlaScope"))
self.assertEqual(b"jit_scope_1_grad_GB",
@@ -207,20 +229,48 @@ class CompilationEnabledInGradientTest(test.TestCase):
r = mulop(x, x)
g_r = gradients.gradients(r, x, name="GA")[0]
- # Ensure the forward function is compiled
+ # Ensure the forward function is compiled.
+ graph_def = r.graph.as_graph_def()
+ func_attrs = graph_def.library.function[0].attr
+ self.assertTrue(func_attrs["_XlaCompile"].b)
+ self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)
+
+ # Ensure the gradient (SymbolicGradient) is compiled, with the same
+ # _XlaScope as the function itself.
+ grad_op = g_r.op.inputs[0].op
+ self.assertTrue(grad_op.get_attr("_XlaCompile"))
+ self.assertEqual(b"function_mulop", grad_op.get_attr("_XlaScope"))
+
+ # Ensure the ops run: grad(x1*x1) = 2*x1
+ self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))
+
+ def testPlaysNicelyWithDefunSeparateGradientScope(self):
+ with self.test_session(graph=ops.Graph()) as sess:
+ with jit.experimental_jit_scope(True): # This should be ignored
+
+ @function.Defun(
+ compiled=True, noinline=True, separate_compiled_gradients=True)
+ def mulop(x1, x2):
+ return x1 * x2
+
+ x = constant_op.constant(1.0)
+ r = mulop(x, x)
+ g_r = gradients.gradients(r, x, name="GA")[0]
+
+ # Ensure the forward function is compiled.
graph_def = r.graph.as_graph_def()
func_attrs = graph_def.library.function[0].attr
self.assertTrue(func_attrs["_XlaCompile"].b)
self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)
- # Ensure the gradient (SymbolicGradient) is compiled
+ # Ensure the gradient (SymbolicGradient) is compiled, with a different
+ # _XlaScope from the function itself.
grad_op = g_r.op.inputs[0].op
self.assertTrue(grad_op.get_attr("_XlaCompile"))
self.assertEqual(b"function_mulop_grad_GA",
grad_op.get_attr("_XlaScope"))
- # Ensure the ops run
- # grad(x1*x1) = 2*x1
+ # Ensure the ops run: grad(x1*x1) = 2*x1
self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))
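
Beyond the assertions above, the effect of separate_compiled_gradients can also be observed by grouping a graph's ops by their _XlaScope attribute. A sketch using a hypothetical helper `ops_by_xla_scope` (TF 1.x graph mode, contrib assumed):

    import collections
    import tensorflow as tf
    from tensorflow.contrib.compiler import jit

    def ops_by_xla_scope(graph):
      # Group op names by their _XlaScope attribute (illustrative helper).
      scopes = collections.defaultdict(list)
      for op in graph.get_operations():
        try:
          scopes[op.get_attr("_XlaScope")].append(op.name)
        except ValueError:
          pass  # Op was created outside any jit scope.
      return scopes

    with tf.Graph().as_default() as g:
      with jit.experimental_jit_scope(separate_compiled_gradients=True):
        a = tf.constant(2.0)
        b = a * a
      grads = tf.gradients(b, a, name="GA")
      # Forward ops land in b'jit_scope_0'; gradient ops derived from them
      # land in b'jit_scope_0_grad_GA', matching the scope names asserted above.
      for scope, names in sorted(ops_by_xla_scope(g).items()):
        print(scope, names)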