path: root/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py
author    Brennan Saeta <saeta@google.com>  2018-10-09 11:54:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-09 11:58:43 -0700
commit    072fcb995a3fd658ee2461b59b159498c710513d (patch)
tree      f3def3d3ac6e270ad32e428889a79d662c8bc9cf /tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py
parent    12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (diff)
[tf.data] NUMA-aware MapAndBatch dataset.
PiperOrigin-RevId: 216395709
Diffstat (limited to 'tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py')
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py  95
1 file changed, 95 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..04aab329cd
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapAndBatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapAndBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testNumParallelBatches(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_batches = 2
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ ds = dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_batches=num_parallel_batches,
+ drop_remainder=drop_remainder))
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ return ds.with_options(options)
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
+ def testNumParallelCalls(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_calls = 7
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ ds = dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ drop_remainder=drop_remainder))
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ return ds.with_options(options)
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
+
+if __name__ == "__main__":
+ test.main()
+
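
For context, a minimal usage sketch of how a tf.data pipeline opts into the NUMA-aware map_and_batch path exercised by this test. It mirrors the internal imports and the experimental_numa_aware option shown in the diff above; the concrete range, batch size, and parallelism values are illustrative only and not part of this commit.

    # Minimal sketch, assuming the TF 1.x-era experimental tf.data API used in the test above.
    from tensorflow.python.data.experimental.ops import batching
    from tensorflow.python.data.ops import dataset_ops

    # Fused map + batch: square each element, then batch the results.
    ds = dataset_ops.Dataset.range(100).apply(
        batching.map_and_batch(
            map_func=lambda x: x * x,
            batch_size=10,
            num_parallel_calls=4))

    # Opt in to the NUMA-aware MapAndBatch implementation added by this commit.
    options = dataset_ops.Options()
    options.experimental_numa_aware = True
    ds = ds.with_options(options)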