diff options
author | Justine Tunney <jart@google.com> | 2016-12-14 16:30:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-14 16:43:13 -0800 |
commit | 5866e065bc95c1d7de8a27413b368016941889a6 (patch) | |
tree | 55b7db600e38b3a799ab39053cd99e61204f840b /tensorflow/python/kernel_tests/xent_op_test.py | |
parent | 38a664cd961762e64899187a31a1b86cbe5a992e (diff) |
Remove hourglass imports from kernel_tests
Change: 142080137
Diffstat (limited to 'tensorflow/python/kernel_tests/xent_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/xent_op_test.py | 52 |
1 files changed, 31 insertions, 21 deletions
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index 2ae4fa1396..ac56f567ce 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -12,19 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Tests for SoftmaxCrossEntropyWithLogits op.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -import tensorflow as tf +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import nn_ops +import tensorflow.python.ops.nn_grad # pylint: disable=unused-import +from tensorflow.python.platform import test -class XentTest(tf.test.TestCase): +class XentTest(test.TestCase): def _npXent(self, features, labels, dim=-1): if dim is -1: @@ -51,7 +56,7 @@ class XentTest(tf.test.TestCase): def _testXentWrapper(self, np_features, np_labels, dim=-1, use_gpu=False): np_loss, _ = self._npXent(np_features, np_labels, dim=dim) with self.test_session(use_gpu=use_gpu) as sess: - loss = tf.nn.softmax_cross_entropy_with_logits( + loss = nn_ops.softmax_cross_entropy_with_logits( np_features, np_labels, dim=dim) tf_loss = sess.run(loss) print("np_loss:", np_loss) @@ -113,12 +118,14 @@ class XentTest(tf.test.TestCase): # The loss for this batch is [0.5 * -log(0.087), 0.5 * -log(0.237)] # = [1.3862, 1.9401] np_loss, np_backprop = self._npXent(np.array(features), np.array(labels)) - self.assertAllClose(np.array([[0.25, 0.25, 0.25, -0.75], - [0.0321, -0.4129, -0.2632, 0.6439]]), - np_backprop, - rtol=1.e-3, atol=1.e-3) - self.assertAllClose(np.array([1.3862, 1.9401]), np_loss, - rtol=1.e-3, atol=1.e-3) + self.assertAllClose( + np.array([[0.25, 0.25, 0.25, -0.75], + [0.0321, -0.4129, -0.2632, 0.6439]]), + np_backprop, + rtol=1.e-3, + atol=1.e-3) + self.assertAllClose( + np.array([1.3862, 1.9401]), np_loss, rtol=1.e-3, atol=1.e-3) def testShapeMismatch(self): with self.test_session(): @@ -149,16 +156,18 @@ class XentTest(tf.test.TestCase): def testGradient(self): with self.test_session(): - l = tf.constant([0.0, 0.0, 1.0, 0.0, - 1.0, 0.0, 0.0, 0.0, - 0.0, 0.5, 0.0, 0.5], shape=[3, 4], - dtype=tf.float64, name="l") - f = tf.constant([0.1, 0.2, 0.3, 0.4, - 0.1, 0.4, 0.9, 1.6, - 0.1, 0.8, 2.7, 6.4], shape=[3, 4], - dtype=tf.float64, name="f") - x = tf.nn.softmax_cross_entropy_with_logits(f, l, name="xent") - err = tf.test.compute_gradient_error(f, [3, 4], x, [3]) + l = constant_op.constant( + [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5], + shape=[3, 4], + dtype=dtypes.float64, + name="l") + f = constant_op.constant( + [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4], + shape=[3, 4], + dtype=dtypes.float64, + name="f") + x = nn_ops.softmax_cross_entropy_with_logits(f, l, name="xent") + err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3]) print("cross entropy gradient err = ", err) self.assertLess(err, 5e-8) @@ -177,5 +186,6 @@ class XentTest(tf.test.TestCase): self._testXentWrapper(features, labels, dim=-1, use_gpu=False) self._testXentWrapper(features, labels, dim=-1, use_gpu=True) + if __name__ == "__main__": - tf.test.main() + test.main() |