diff options
Diffstat (limited to 'tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py')
-rw-r--r-- | tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py new file mode 100644 index 0000000000..c6ee88c676 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmarks FilterDataset input pipeline op.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class FilterBenchmark(test.Benchmark): + + # This benchmark compares the performance of pipeline with multiple chained + # filter with and without filter fusion. + def benchmarkFilters(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkFilters(chain_length, False) + self._benchmarkFilters(chain_length, True) + + def _benchmarkFilters(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(5).repeat(None) + for _ in range(chain_length): + dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0)) + if optimize_dataset: + dataset = dataset.apply(optimization.optimize(["filter_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(10): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Filter dataset {} chain length: {} Median wall time: {}".format( + opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_filter_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + +if __name__ == "__main__": + test.main() |