aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/xent_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/xent_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/xent_op_test.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index ac56f567ce..e1e0566124 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -57,7 +57,7 @@ class XentTest(test.TestCase):
np_loss, _ = self._npXent(np_features, np_labels, dim=dim)
with self.test_session(use_gpu=use_gpu) as sess:
loss = nn_ops.softmax_cross_entropy_with_logits(
- np_features, np_labels, dim=dim)
+ labels=np_labels, logits=np_features, dim=dim)
tf_loss = sess.run(loss)
print("np_loss:", np_loss)
print("tf_loss:", tf_loss)
@@ -166,7 +166,8 @@ class XentTest(test.TestCase):
shape=[3, 4],
dtype=dtypes.float64,
name="f")
- x = nn_ops.softmax_cross_entropy_with_logits(f, l, name="xent")
+ x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f,
+ name="xent")
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)