aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/where_op_test.py
diff options
context:
space:
mode:
authorGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-07-26 11:27:40 +0800
committerGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-07-26 12:20:47 +0800
commit23f826271a5956982df17980bca3ac7513ec4ee4 (patch)
tree8aabd4443c311164af9431aae66a29c70a4ce2d5 /tensorflow/python/kernel_tests/where_op_test.py
parent15b155e929f2eb3e30c1194fa9afc1ea40e330a4 (diff)
A faster BatchSelectFunctor for tf.where on CPU.
Op 'tf.where(c, t, e)' supports 't' and 'e' being N-D tensors while 'c' is a 1-D tensor; that case calls BatchSelectFunctor to compute the result. But the basic implementation broadcasts 'c' to the same dimensions as 't' and 'e', which is inefficient on CPU for large tensors. Here a loop-based implementation is adopted to make this operation faster on CPU.
Diffstat (limited to 'tensorflow/python/kernel_tests/where_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/where_op_test.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index 17575da6f1..53324d5b20 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -135,6 +135,15 @@ class WhereOpTest(test.TestCase):
tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval()
self.assertAllEqual(tf_val, np_val)
+ def testBatchSelect(self):
+ x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192) # [16384, 192]
+ c_mat = np.array([[False] * 192, [True] * 192] * 8192) # [16384, 192]
+ c_vec = np.array([False, True] * 8192) # [16384]
+ np_val = np.where(c_mat, x * x, -x)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ tf_val = array_ops.where(c_vec, x * x, -x).eval()
+ self.assertAllEqual(tf_val, np_val)
class WhereBenchmark(test.Benchmark):
@@ -163,5 +172,33 @@ class WhereBenchmark(test.Benchmark):
"Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
sys.stdout.flush()
+ def benchmarkBatchSelect(self):
+ for (m, n, use_gpu) in itertools.product(
+ [1000, 10000, 100000],
+ [10, 100, 1000],
+ [False, True]):
+ name = "m_%d_n_%d_use_gpu_%s" % (m, n, use_gpu)
+ device = "/%s:0" % ("gpu" if use_gpu else "cpu")
+ with ops.Graph().as_default():
+ with ops.device(device):
+ x_gen = random_ops.random_uniform([m, n], dtype=dtypes.float32)
+ y_gen = random_ops.random_uniform([m, n], dtype=dtypes.float32)
+ c_gen = random_ops.random_uniform([m], dtype=dtypes.float32) <= 0.5
+ x = resource_variable_ops.ResourceVariable(x_gen)
+ y = resource_variable_ops.ResourceVariable(y_gen)
+ c = resource_variable_ops.ResourceVariable(c_gen)
+ op = array_ops.where(c, x, y)
+ with session.Session() as sess:
+ x.initializer.run()
+ y.initializer.run()
+ c.initializer.run()
+ r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
+        # approximate data processed: m*n*2 float32 values in total
+        # (one element read from each of x and y), i.e. 8 bytes per output element.
+ gb_processed = m * n * 8 / 1.0e9
+ throughput = gb_processed / r["wall_time"]
+ print("Benchmark: %s \t wall_time: %0.03g s \t "
+ "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
+ sys.stdout.flush()
+
if __name__ == "__main__":
test.main()