path: root/tensorflow/python/kernel_tests/xent_op_test.py
Diffstat (limited to 'tensorflow/python/kernel_tests/xent_op_test.py')
-rw-r--r-- tensorflow/python/kernel_tests/xent_op_test.py | 35
1 file changed, 24 insertions(+), 11 deletions(-)
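Note: the patch below verifies its behavior by inspecting the ops recorded in the session graph, checking whether the second-derivative path (which shows up as a `BatchMatMul` op in the cross-entropy gradient) was built. The following is a minimal standalone sketch of that technique, not part of the patch; it assumes a TF 1.x-style graph via tf.compat.v1 and prints whether the op is present rather than asserting, since the exact op set can vary across TensorFlow versions.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default() as g:
    labels = tf.constant([[0.0, 0.0, 1.0, 0.0]], dtype=tf.float64)
    logits = tf.constant([[0.1, 0.2, 0.3, 0.4]], dtype=tf.float64)
    xent = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)

    # Build only the first derivative; the patched test expects that this
    # alone does not create a BatchMatMul op in the graph.
    grad = tf.gradients(tf.reduce_sum(xent), [logits])[0]

    # Enumerate the op types registered in the graph, as the test does.
    op_names = [op.op_def.name for op in g.get_operations() if op.op_def]
    print("BatchMatMul present:", "BatchMatMul" in op_names)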
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index d037ceac61..4b3dadc112 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -157,7 +157,7 @@ class XentTest(test.TestCase):
np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64))
def testGradient(self):
- with self.test_session():
+ with self.test_session() as sess:
l = constant_op.constant(
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
shape=[3, 4],
@@ -171,14 +171,21 @@ class XentTest(test.TestCase):
x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f,
name="xent")
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
+
+ # Check that no extra computation is performed. When only the first
+ # derivative is requested, the second derivative must not be computed,
+ # so there should be no `BatchMatMul` op in the graph.
+ op_names = [op.op_def.name for op in sess.graph.get_operations() if op.op_def]
+ self.assertNotIn('BatchMatMul', op_names)
+
print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)
def testSecondGradient(self):
- with self.test_session():
- l = constant_op.constant([0.0, 0.0, 1.0, 0.0,
- 1.0, 0.0, 0.0, 0.0,
- 0.0, 0.5, 0.0, 0.5], shape=[12],
+ with self.test_session() as sess:
+ l = constant_op.constant([0.0, 0.0, 1.0/3, 0.0,
+ 1.0/3, 0.0, 0.0, 0.0,
+ 0.0, 0.5/3, 0.0, 0.5/3], shape=[12],
dtype=dtypes.float64, name="l")
f = constant_op.constant([0.1, 0.2, 0.3, 0.4,
0.1, 0.4, 0.9, 1.6,
@@ -186,13 +193,19 @@ class XentTest(test.TestCase):
dtype=dtypes.float64, name="f")
x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f,
name="xent")
- loss = math_ops.reduce_mean(x)
+ loss = math_ops.reduce_sum(x)
- # Taking ths second gradient should fail, since it is not
- # yet supported.
- with self.assertRaisesRegexp(LookupError,
- "explicitly disabled"):
- _ = gradients_impl.hessians(loss, [f])
+ gradients = gradients_impl.gradients(loss, [f])[0]
+
+ err = gradient_checker.compute_gradient_error(f, [12], gradients, [12])
+
+ # Check that the second derivative is calculated.
+ # (Because of how the cross-entropy gradient is implemented, this is
+ # equivalent to a `BatchMatMul` op appearing in the graph.)
+ op_names = [op.op_def.name for op in sess.graph.get_operations() if op.op_def]
+ self.assertIn('BatchMatMul', op_names)
+
+ print("cross entropy hessian err = ", err)
+ self.assertLess(err, 5e-8)
def testWrapper(self):
features = np.array(