diff options
author | 2017-10-16 18:06:20 -0700 | |
---|---|---|
committer | 2017-10-16 18:10:53 -0700 | |
commit | ecaa2eee832bd5b4286377f0f853c961c6ac2ab2 (patch) | |
tree | 892c9cc2b6cb8901dcb00ed1ea4c51b31ba6af8f /tensorflow/contrib/compiler | |
parent | 5c5dc8d5641b7c915f681109921dfb2b3e082a9b (diff) |
math_grad: Fast path for when broadcasting is not needed.
PiperOrigin-RevId: 172407754
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r-- | tensorflow/contrib/compiler/jit_test.py | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 94aff13a49..2108e42bce 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -173,12 +173,12 @@ class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): with self.test_session(): - x = constant_op.constant(3) - y_nc = math_ops.add(x, x, name="not_compiled") + x = constant_op.constant([[3]]) + y_nc = math_ops.matmul(x, x, name="not_compiled") with jit.experimental_jit_scope(): - y_c = math_ops.add(y_nc, y_nc, name="compiled") + y_c = math_ops.matmul(y_nc, y_nc, name="compiled") x_grads = gradients.gradients([y_c], [x])[0] - operations = x_grads.graph.get_operations() + operations = x.graph.get_operations() c_grad_ops = [ op for op in operations if "gradients/compiled" in op.name] nc_grad_ops = [ @@ -191,19 +191,19 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "No attr named"): ncg.get_attr("_XlaCompile") - # d/dx (4 * x) - self.assertAllClose(4, x_grads.eval()) + # d/dx (x ** 4) = 4 * (x ** 3) + self.assertAllClose([[108]], x_grads.eval()) def testCompilationGradientScopeNames(self): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(): # XlaScope 0 - a1 = constant_op.constant(1) - a1t = a1 + a1 + a1 = constant_op.constant([[1]]) + a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(): # XlaScope 1 - a2 = constant_op.constant(1) - a2t = a2 + a2 + a2 = constant_op.constant([[1]]) + a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) @@ -220,12 +220,12 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 0 - a1 = constant_op.constant(1) - a1t = a1 + a1 + a1 = constant_op.constant([[1]]) + a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 1 - a2 = constant_op.constant(1) - a2t = a2 + a2 + a2 = constant_op.constant([[1]]) + a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) |