author    | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-13 11:19:01 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org>  | 2017-03-13 12:32:23 -0700
commit    | a75b6df69e1b6965fcbac5df68e89dc3cbe9931e (patch)
tree      | 4b45c29d17a3da62c2410fa162a4c2654ff1e02a /tensorflow/contrib/compiler
parent    | 26c5a01eaff4d1960e675129ea89d56ce0d11e90 (diff)
[TF:XLA] Add separate_compiled_gradients to control gradient scopes.
Change: 149973410
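
For orientation, here is a minimal usage sketch of the new flag. This is not part of the patch; it assumes a TF 1.x build where `tensorflow.contrib` is importable, and it mirrors the docstring example added to jit.py below.

```python
import tensorflow as tf
from tensorflow.contrib.compiler import jit

a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
b = tf.constant([[5.0, 6.0], [7.0, 8.0]])

# With separate_compiled_gradients=True, each gradient subgraph is compiled
# in its own XLA scope, separate from the forward computation and from
# other gradients.
with jit.experimental_jit_scope(
    compile_ops=True, separate_compiled_gradients=True):
  f = tf.matmul(a, b)

g = tf.gradients([f], [a, b], name="mygrads1")  # its own compilation scope
h = tf.gradients([f], [a, b], name="mygrads2")  # another separate scope
```

Keeping each gradient in its own cluster gives finer control over what XLA compiles as a single unit, which the docstring notes may help performance on some graphs.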
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r-- | tensorflow/contrib/compiler/jit.py      | 26
-rw-r--r-- | tensorflow/contrib/compiler/jit_test.py | 78
2 files changed, 88 insertions, 16 deletions
diff --git a/tensorflow/contrib/compiler/jit.py b/tensorflow/contrib/compiler/jit.py
index b4f96ecf9a..c516ab658d 100644
--- a/tensorflow/contrib/compiler/jit.py
+++ b/tensorflow/contrib/compiler/jit.py
@@ -36,7 +36,7 @@ class _XlaScope(object):
 
 
 @contextlib.contextmanager
-def experimental_jit_scope(compile_ops=True):
+def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
   """Enable or disable JIT compilation of operators within the scope.
 
   NOTE: This is an experimental feature.
@@ -52,10 +52,27 @@ def experimental_jit_scope(compile_ops=True):
         compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
       e = tf.matmul(a, b) + d  # matmul is compiled, the addition is not.
 
+  Example of separate_compiled_gradients:
+    # In the example below, the computations for f, g and h will all be compiled
+    # in separate scopes.
+    with tf.contrib.compiler.experimental_jit_scope(
+        separate_compiled_gradients=True):
+      f = tf.matmul(a, b)
+    g = tf.gradients([f], [a, b], name='mygrads1')
+    h = tf.gradients([f], [a, b], name='mygrads2')
+
   Args:
     compile_ops: Whether to enable or disable compilation in the scope.
       Either a Python bool, or a callable that accepts the parameter
       `node_def` and returns a python bool.
+    separate_compiled_gradients: If true put each gradient subgraph into a
+      separate compilation scope. This gives fine-grained control over which
+      portions of the graph will be compiled as a single unit. Compiling
+      gradients separately may yield better performance for some graphs.
+      The scope is named based on the scope of the forward computation as well
+      as the name of the gradients. As a result, the gradients will be compiled
+      in a scope that is separate from both the forward computation, and from
+      other gradients.
 
   Yields:
     The current scope, enabling or disabling compilation.
@@ -66,7 +83,12 @@ def experimental_jit_scope(compile_ops=True):
   else:
     xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
 
-  attrs = {"_XlaCompile": xla_compile}
+  attrs = {
+      "_XlaCompile":
+          xla_compile,
+      "_XlaSeparateCompiledGradients":
+          attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
+  }
 
   # Find the singleton counter for the current scoped graph.  If it
   # doesn't exist, create one.
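
The attrs dict above is what actually lands on each op built inside the scope. A small inspection sketch (mine, not from the patch; assumes TF 1.x and a fresh graph, so the singleton counter names the first scope `jit_scope_0` as the tests below expect):

```python
import tensorflow as tf
from tensorflow.contrib.compiler import jit

with tf.Graph().as_default():
  with jit.experimental_jit_scope(
      compile_ops=True, separate_compiled_gradients=True):
    c = tf.constant(1.0) + tf.constant(2.0)

  # Every op created inside the scope carries the XLA attributes set by
  # the attrs dict in jit.py, plus the _XlaScope name from the counter.
  print(c.op.get_attr("_XlaCompile"))                    # True
  print(c.op.get_attr("_XlaSeparateCompiledGradients"))  # True
  print(c.op.get_attr("_XlaScope"))                      # b'jit_scope_0'
```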
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 6aa86c3286..2130f32f85 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -78,17 +78,17 @@ class JITTest(test.TestCase):
     v_true_1_t, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
     _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
     v_all_true_t, _ = self.compute(True, create_ops)
-    self.assertEqual(False, v_false_1_t.op.get_attr("_XlaCompile"))
+    self.assertFalse(v_false_1_t.op.get_attr("_XlaCompile"))
     v_true_1_t_sampler_op = v_true_1_t.graph.get_operation_by_name(
         "root/random_uniform/RandomUniform")
     v_all_true_t_sampler_op = v_all_true_t.graph.get_operation_by_name(
         "root/random_uniform/RandomUniform")
-    self.assertEqual(False, v_true_1_t_sampler_op.get_attr("_XlaCompile"))
-    self.assertEqual(True, v_all_true_t_sampler_op.get_attr("_XlaCompile"))
+    self.assertFalse(v_true_1_t_sampler_op.get_attr("_XlaCompile"))
+    self.assertTrue(v_all_true_t_sampler_op.get_attr("_XlaCompile"))
 
-    self.assertEqual(True, v_true_1_t.op.get_attr("_XlaCompile"))
-    self.assertEqual(True, v_all_true_t.op.get_attr("_XlaCompile"))
+    self.assertTrue(v_true_1_t.op.get_attr("_XlaCompile"))
+    self.assertTrue(v_all_true_t.op.get_attr("_XlaCompile"))
 
     # Additionally ensure that where no JIT compilation happens on the
     # random_uniform op, the output values are identical to the case
@@ -165,7 +165,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
       self.assertGreater(len(c_grad_ops), 0)
       self.assertGreater(len(nc_grad_ops), 0)
       for cg in c_grad_ops:
-        self.assertEqual(True, cg.get_attr("_XlaCompile"))
+        self.assertTrue(cg.get_attr("_XlaCompile"))
       for ncg in nc_grad_ops:
         with self.assertRaisesRegexp(ValueError, "No attr named"):
           ncg.get_attr("_XlaCompile")
@@ -175,11 +175,33 @@ class CompilationEnabledInGradientTest(test.TestCase):
 
   def testCompilationGradientScopeNames(self):
     with self.test_session(graph=ops.Graph()):
-      with jit.experimental_jit_scope(True):
+      with jit.experimental_jit_scope():
         # XlaScope 0
         a1 = constant_op.constant(1)
         a1t = a1 + a1
-      with jit.experimental_jit_scope(True):
+      with jit.experimental_jit_scope():
+        # XlaScope 1
+        a2 = constant_op.constant(1)
+        a2t = a2 + a2
+
+      self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
+      self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
+      grad_a1 = gradients.gradients(a1t, a1, name="GA")[0]
+      grad_a2 = gradients.gradients(a2t, a2, name="GB")[0]
+      grad_a1 = grad_a1.op.inputs[0]
+      grad_a2 = grad_a2.op.inputs[0]
+      self.assertTrue(grad_a1.op.get_attr("_XlaCompile"))
+      self.assertTrue(grad_a2.op.get_attr("_XlaCompile"))
+      self.assertEqual(b"jit_scope_0", grad_a1.op.get_attr("_XlaScope"))
+      self.assertEqual(b"jit_scope_1", grad_a2.op.get_attr("_XlaScope"))
+
+  def testCompilationSeparateGradientScopeNames(self):
+    with self.test_session(graph=ops.Graph()):
+      with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
+        # XlaScope 0
+        a1 = constant_op.constant(1)
+        a1t = a1 + a1
+      with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
         # XlaScope 1
         a2 = constant_op.constant(1)
         a2t = a2 + a2
@@ -190,8 +212,8 @@ class CompilationEnabledInGradientTest(test.TestCase):
       grad_a2 = gradients.gradients(a2t, a2, name="GB")[0]
       grad_a1 = grad_a1.op.inputs[0]
       grad_a2 = grad_a2.op.inputs[0]
-      self.assertEqual(True, grad_a1.op.get_attr("_XlaCompile"))
-      self.assertEqual(True, grad_a2.op.get_attr("_XlaCompile"))
+      self.assertTrue(grad_a1.op.get_attr("_XlaCompile"))
+      self.assertTrue(grad_a2.op.get_attr("_XlaCompile"))
       self.assertEqual(b"jit_scope_0_grad_GA",
                        grad_a1.op.get_attr("_XlaScope"))
       self.assertEqual(b"jit_scope_1_grad_GB",
@@ -207,20 +229,48 @@ class CompilationEnabledInGradientTest(test.TestCase):
         r = mulop(x, x)
         g_r = gradients.gradients(r, x, name="GA")[0]
 
-      # Ensure the forward function is compiled
+      # Ensure the forward function is compiled.
+      graph_def = r.graph.as_graph_def()
+      func_attrs = graph_def.library.function[0].attr
+      self.assertTrue(func_attrs["_XlaCompile"].b)
+      self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)
+
+      # Ensure the gradient (SymbolicGradient) is compiled, with the same
+      # _XlaScope as the function itself.
+      grad_op = g_r.op.inputs[0].op
+      self.assertTrue(grad_op.get_attr("_XlaCompile"))
+      self.assertEqual(b"function_mulop", grad_op.get_attr("_XlaScope"))
+
+      # Ensure the ops run: grad(x1*x1) = 2*x1
+      self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))
+
+  def testPlaysNicelyWithDefunSeparateGradientScope(self):
+    with self.test_session(graph=ops.Graph()) as sess:
+      with jit.experimental_jit_scope(True):  # This should be ignored
+
+        @function.Defun(
+            compiled=True, noinline=True, separate_compiled_gradients=True)
+        def mulop(x1, x2):
+          return x1 * x2
+
+        x = constant_op.constant(1.0)
+        r = mulop(x, x)
+        g_r = gradients.gradients(r, x, name="GA")[0]
+
+      # Ensure the forward function is compiled.
       graph_def = r.graph.as_graph_def()
       func_attrs = graph_def.library.function[0].attr
       self.assertTrue(func_attrs["_XlaCompile"].b)
       self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)
 
-      # Ensure the gradient (SymbolicGradient) is compiled
+      # Ensure the gradient (SymbolicGradient) is compiled, with a different
+      # _XlaScope from the function itself.
       grad_op = g_r.op.inputs[0].op
       self.assertTrue(grad_op.get_attr("_XlaCompile"))
       self.assertEqual(b"function_mulop_grad_GA",
                        grad_op.get_attr("_XlaScope"))
 
-      # Ensure the ops run
-      # grad(x1*x1) = 2*x1
+      # Ensure the ops run: grad(x1*x1) = 2*x1
       self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))
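
To close the loop on the Defun interaction exercised by the new test: with `separate_compiled_gradients=True`, the SymbolicGradient op is placed in a scope named from both the function and the gradient name, rather than sharing the function's `_XlaScope`. A hedged end-to-end sketch (assumes TF 1.x; it simply replays the test above outside the test harness):

```python
import tensorflow as tf
from tensorflow.python.framework import function

@function.Defun(
    compiled=True, noinline=True, separate_compiled_gradients=True)
def mulop(x1, x2):
  return x1 * x2

with tf.Graph().as_default():
  x = tf.constant(1.0)
  r = mulop(x, x)
  g_r = tf.gradients(r, x, name="GA")[0]

  # Per the assertions above, the SymbolicGradient op gets its own scope,
  # named from the function ("mulop") and the gradient name ("GA").
  grad_op = g_r.op.inputs[0].op
  print(grad_op.get_attr("_XlaScope"))  # b'function_mulop_grad_GA'

  with tf.Session() as sess:
    # grad(x1*x1) = 2*x1
    print(sess.run([x, r, g_r]))  # [1.0, 1.0, 2.0]
```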