aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-16 18:06:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-16 18:10:53 -0700
commitecaa2eee832bd5b4286377f0f853c961c6ac2ab2 (patch)
tree892c9cc2b6cb8901dcb00ed1ea4c51b31ba6af8f /tensorflow/contrib/compiler
parent5c5dc8d5641b7c915f681109921dfb2b3e082a9b (diff)
math_grad: Fast path for when broadcasting is not needed.
PiperOrigin-RevId: 172407754
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r--tensorflow/contrib/compiler/jit_test.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 94aff13a49..2108e42bce 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -173,12 +173,12 @@ class CompilationEnabledInGradientTest(test.TestCase):
def testCompilationInGradient(self):
with self.test_session():
- x = constant_op.constant(3)
- y_nc = math_ops.add(x, x, name="not_compiled")
+ x = constant_op.constant([[3]])
+ y_nc = math_ops.matmul(x, x, name="not_compiled")
with jit.experimental_jit_scope():
- y_c = math_ops.add(y_nc, y_nc, name="compiled")
+ y_c = math_ops.matmul(y_nc, y_nc, name="compiled")
x_grads = gradients.gradients([y_c], [x])[0]
- operations = x_grads.graph.get_operations()
+ operations = x.graph.get_operations()
c_grad_ops = [
op for op in operations if "gradients/compiled" in op.name]
nc_grad_ops = [
@@ -191,19 +191,19 @@ class CompilationEnabledInGradientTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "No attr named"):
ncg.get_attr("_XlaCompile")
- # d/dx (4 * x)
- self.assertAllClose(4, x_grads.eval())
+ # d/dx (x ** 4) = 4 * (x ** 3)
+ self.assertAllClose([[108]], x_grads.eval())
def testCompilationGradientScopeNames(self):
with self.test_session(graph=ops.Graph()):
with jit.experimental_jit_scope():
# XlaScope 0
- a1 = constant_op.constant(1)
- a1t = a1 + a1
+ a1 = constant_op.constant([[1]])
+ a1t = math_ops.matmul(a1, a1)
with jit.experimental_jit_scope():
# XlaScope 1
- a2 = constant_op.constant(1)
- a2t = a2 + a2
+ a2 = constant_op.constant([[1]])
+ a2t = math_ops.matmul(a2, a2)
self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
@@ -220,12 +220,12 @@ class CompilationEnabledInGradientTest(test.TestCase):
with self.test_session(graph=ops.Graph()):
with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
# XlaScope 0
- a1 = constant_op.constant(1)
- a1t = a1 + a1
+ a1 = constant_op.constant([[1]])
+ a1t = math_ops.matmul(a1, a1)
with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
# XlaScope 1
- a2 = constant_op.constant(1)
- a2t = a2 + a2
+ a2 = constant_op.constant([[1]])
+ a2t = math_ops.matmul(a2, a2)
self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))