diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/xent_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/xent_op_test.py | 35 |
1 files changed, 24 insertions, 11 deletions
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index d037ceac61..4b3dadc112 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -157,7 +157,7 @@ class XentTest(test.TestCase): np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64)) def testGradient(self): - with self.test_session(): + with self.test_session() as sess: l = constant_op.constant( [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5], shape=[3, 4], @@ -171,14 +171,21 @@ class XentTest(test.TestCase): x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f, name="xent") err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3]) + + # Check that no extra computation performed. When only first derivative is requested, + # second derivative must not be computed. So when there is no second derivative, + # there is no `BatchMatMul` op in the graph. + op_names = [op.op_def.name for op in sess.graph.get_operations() if op.op_def] + self.assertNotIn('BatchMatMul', op_names) + print("cross entropy gradient err = ", err) self.assertLess(err, 5e-8) def testSecondGradient(self): - with self.test_session(): - l = constant_op.constant([0.0, 0.0, 1.0, 0.0, - 1.0, 0.0, 0.0, 0.0, - 0.0, 0.5, 0.0, 0.5], shape=[12], + with self.test_session() as sess: + l = constant_op.constant([0.0, 0.0, 1.0/3, 0.0, + 1.0/3, 0.0, 0.0, 0.0, + 0.0, 0.5/3, 0.0, 0.5/3], shape=[12], dtype=dtypes.float64, name="l") f = constant_op.constant([0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, @@ -186,13 +193,19 @@ class XentTest(test.TestCase): dtype=dtypes.float64, name="f") x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f, name="xent") - loss = math_ops.reduce_mean(x) + loss = math_ops.reduce_sum(x) - # Taking ths second gradient should fail, since it is not - # yet supported. - with self.assertRaisesRegexp(LookupError, - "explicitly disabled"): - _ = gradients_impl.hessians(loss, [f]) + gradients = gradients_impl.gradients(loss, [f])[0] + + err = gradient_checker.compute_gradient_error(f, [12], gradients, [12]) + + # Check that second derivative is calculated. + # (it is equivalent to being `BatchMatMul` op in the graph because of implementation of xentropy grad) + op_names = [op.op_def.name for op in sess.graph.get_operations() if op.op_def] + self.assertIn('BatchMatMul', op_names) + + print("cross entropy hessian err = ", err) + self.assertLess(err, 5e-8) def testWrapper(self): features = np.array( |