aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index 270a2297b4..b7025f3802 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -17,19 +17,28 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import hashlib
+import itertools
import os
+import time
import numpy as np
+from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import error_ops
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
+_NUMPY_RANDOM_SEED = 42
+
class MapDatasetTest(test.TestCase):
@@ -135,5 +144,125 @@ class MapDatasetTest(test.TestCase):
sess.run(get_next)
+class MapDatasetBenchmark(test.Benchmark):
+
+ # The purpose of this benchmark is to compare the performance of chaining vs
+ # fusing of the map and batch transformations across various configurations.
+ #
+ # NOTE: It is recommended to build the benchmark with
+ # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
+ # and execute it on a machine with at least 32 CPU cores.
+ def benchmarkMapAndBatch(self):
+
+ # Sequential pipeline configurations.
+ seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
+ seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
+
+ # Parallel pipeline configuration.
+ par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
+ par_batch_size_series = itertools.product([32], [32], [1],
+ [128, 256, 512, 1024])
+ par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
+ par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
+
+ def name(method, label, num_calls, inter_op, element_size, batch_size):
+ return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
+ method,
+ hashlib.sha1(label).hexdigest(),
+ num_calls,
+ inter_op,
+ element_size,
+ batch_size,
+ ))
+
+ def benchmark(label, series):
+
+ print("%s:" % label)
+ for num_calls, inter_op, element_size, batch_size in series:
+
+ num_iters = 1024 // (
+ (element_size * batch_size) // min(num_calls, inter_op))
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
+ element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()
+
+ chained_dataset = dataset.map(
+ math_ops.matmul,
+ num_parallel_calls=num_calls).batch(batch_size=batch_size)
+ chained_iterator = chained_dataset.make_one_shot_iterator()
+ chained_get_next = chained_iterator.get_next()
+
+ chained_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+ for _ in range(5):
+ sess.run(chained_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(chained_get_next.op)
+ end = time.time()
+ chained_deltas.append(end - start)
+
+ fused_dataset = dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=num_calls,
+ batch_size=batch_size))
+ fused_iterator = fused_dataset.make_one_shot_iterator()
+ fused_get_next = fused_iterator.get_next()
+
+ fused_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+
+ for _ in range(5):
+ sess.run(fused_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(fused_get_next.op)
+ end = time.time()
+ fused_deltas.append(end - start)
+
+ print(
+ "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
+ "element size: %d, num iters: %d\nchained wall time: %f (median), "
+ "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: "
+ "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n "
+ "chained/fused: %.2fx (median), %.2fx (mean)" %
+ (batch_size, num_calls, inter_op, element_size, num_iters,
+ np.median(chained_deltas), np.mean(chained_deltas),
+ np.std(chained_deltas), np.min(chained_deltas),
+ np.max(chained_deltas), np.median(fused_deltas),
+ np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
+ np.max(fused_deltas),
+ np.median(chained_deltas) / np.median(fused_deltas),
+ np.mean(chained_deltas) / np.mean(fused_deltas)))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(chained_deltas),
+ name=name("chained", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(fused_deltas),
+ name=name("fused", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ print("")
+
+ np.random.seed(_NUMPY_RANDOM_SEED)
+ benchmark("Sequential element size evaluation", seq_elem_size_series)
+ benchmark("Sequential batch size evaluation", seq_batch_size_series)
+ benchmark("Parallel element size evaluation", par_elem_size_series)
+ benchmark("Parallel batch size evaluation", par_batch_size_series)
+ benchmark("Transformation parallelism evaluation", par_num_calls_series)
+ benchmark("Threadpool size evaluation", par_inter_op_series)
+
if __name__ == "__main__":
test.main()