aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/xent_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/xent_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/xent_op_test.py81
1 files changed, 80 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index e3e120a4eb..60c726d54c 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -18,10 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import itertools
+import sys
+
import numpy as np
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
@@ -88,7 +94,7 @@ class XentTest(test.TestCase):
4.]]]).astype(dtype)
np_labels = np.array([[[0., 0., 0., 1.]], [[0., .5, .5,
0.]]]).astype(dtype)
- self.assertRaisesRegexp(ValueError, "must be rank 2",
+ self.assertRaisesRegexp(ValueError, "rank 2, but is rank 3",
gen_nn_ops.softmax_cross_entropy_with_logits,
np_features, np_labels)
@@ -128,6 +134,24 @@ class XentTest(test.TestCase):
self.assertAllClose(
np.array([1.3862, 1.9401]), np_loss, rtol=1.e-3, atol=1.e-3)
+ def testShapeBroadcast(self):
+ np_f = np.array([[1., 2., 3., 4.],
+ [1., 2., 3., 4.]]).astype(np.float32)
+ np_l = np.array([[0., 0., 0., 1.],
+ [0., .5, .5, 0.]]).astype(np.float32)
+ np_loss, np_backprop = self._npXent(np_f, np_l)
+ tf_f = constant_op.constant(
+ np.array([[1., 2., 3., 4.]]).astype(np.float32))
+ tf_l = constant_op.constant(
+ np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float32))
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
+ tf_f, tf_l)
+ tf_loss, tf_backprop = sess.run([loss, backprop])
+ self.assertAllCloseAccordingToType(np_loss, tf_loss)
+ self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
+
def testShapeMismatch(self):
with self.test_session():
with self.assertRaises(ValueError):
@@ -260,5 +284,60 @@ class XentTest(test.TestCase):
self.assertAllEqual(np_loss, tf_loss)
+class XentBenchmark(test.Benchmark):
+
+ def benchmarkZeroDimension(self):
+ for (m, n, p, use_gpu) in itertools.product(
+ [128],
+ [10, 100, 1000, 10000, 100000],
+ [0.001, 0.01, 0.5, 0.99, 1.0],
+ [False]):
+ k = int(p * n)
+ if k == 0:
+ continue
+ name = "zero_dimension_m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu)
+ device = "/%s:0" % ("gpu" if use_gpu else "cpu")
+ with ops.Graph().as_default():
+ with ops.device(device):
+ labels = array_ops.zeros([0, 2, 4], dtype=dtypes.float32)
+ logits = array_ops.zeros([0, 2, 4], dtype=dtypes.float32)
+ op = nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ with session.Session() as sess:
+ r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
+ gb_processed_input = m * n / 1.0e9
+ throughput = gb_processed_input / r["wall_time"]
+ print("Benchmark: %s \t wall_time: %0.03g s \t "
+ "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
+ sys.stdout.flush()
+
+ def benchmarkSingleClass(self):
+ for (m, n, p, use_gpu) in itertools.product(
+ [128],
+ [10, 100, 1000, 10000, 100000],
+ [0.001, 0.01, 0.5, 0.99, 1.0],
+ [False]):
+ k = int(p * n)
+ if k == 0:
+ continue
+ name = "single_class_m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu)
+ device = "/%s:0" % ("gpu" if use_gpu else "cpu")
+ with ops.Graph().as_default():
+ with ops.device(device):
+ labels = constant_op.constant([[1.], [-1.], [0.]],
+ dtype=dtypes.float32)
+ logits = constant_op.constant([[-1.], [0.], [1.]],
+ dtype=dtypes.float32)
+ op = nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ with session.Session() as sess:
+ r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
+ gb_processed_input = m * n / 1.0e9
+ throughput = gb_processed_input / r["wall_time"]
+ print("Benchmark: %s \t wall_time: %0.03g s \t "
+ "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
+ sys.stdout.flush()
+
+
if __name__ == "__main__":
test.main()