author    Brennan Saeta <saeta@google.com>  2018-10-09 11:54:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-09 11:58:43 -0700
commit    072fcb995a3fd658ee2461b59b159498c710513d
tree      f3def3d3ac6e270ad32e428889a79d662c8bc9cf /tensorflow/python
parent    12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c
[tf.data] NUMA-aware MapAndBatch dataset.
PiperOrigin-RevId: 216395709
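
All of the test changes below follow one opt-in pattern: build a map_and_batch pipeline as before, then request the NUMA-aware implementation through the new Options field. A minimal sketch of that pattern, using the internal module paths exercised in this change:

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(100).apply(
    batching.map_and_batch(lambda x: x * x, batch_size=10))

# NUMA awareness is off by default; opting in is a per-pipeline Options flag.
options = dataset_ops.Options()
options.experimental_numa_aware = True
dataset = dataset.with_options(options)

iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()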
Diffstat (limited to 'tensorflow/python')
 tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py                                          | 280
 tensorflow/python/data/experimental/kernel_tests/optimization/BUILD                                             |   2
 tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py                          |  11
 tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py                       |  16
 tensorflow/python/data/experimental/kernel_tests/serialization/BUILD                                            |  15
 tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py |  95
 tensorflow/python/data/experimental/ops/BUILD                                                                   |   1
 tensorflow/python/data/ops/dataset_ops.py                                                                       |   7
 8 files changed, 357 insertions(+), 70 deletions(-)
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
index d444c4082e..5ead6d1c75 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
@@ -38,12 +39,17 @@ from tensorflow.python.platform import test
class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
- ("Default", None, None),
- ("SequentialCalls", 1, None),
- ("ParallelCalls", 2, None),
- ("ParallelBatches", None, 10),
+ ("Default", None, None, False),
+ ("SequentialCalls", 1, None, False),
+ ("ParallelCalls", 2, None, False),
+ ("ParallelBatches", None, 10, False),
+ ("DefaultNUMA", None, None, True),
+ ("SequentialCallsNUMA", 1, None, True),
+ ("ParallelCallsNUMA", 2, None, True),
+ ("ParallelBatchesNUMA", None, 10, True),
)
- def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
+ def testMapAndBatch(self, num_parallel_calls, num_parallel_batches,
+ numa_aware):
"""Test a dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset ->
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
@@ -57,14 +63,20 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
- iterator = (
+ dataset = (
dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
batching.map_and_batch(
map_func=_map_fn,
batch_size=batch_size,
num_parallel_calls=num_parallel_calls,
- num_parallel_batches=num_parallel_batches))
- .make_initializable_iterator())
+ num_parallel_batches=num_parallel_batches)))
+
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+
+ iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -115,16 +127,25 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
@parameterized.named_parameters(
- ("Even", False),
- ("Uneven", True),
+ ("Even", False, False),
+ ("Uneven", True, False),
+ ("EvenNUMA", False, True),
+ ("UnevenNUMA", True, True),
)
- def testMapAndBatchPartialBatch(self, drop_remainder):
- iterator = (
+ def testMapAndBatchPartialBatch(self, drop_remainder, numa_aware):
+ dataset = (
dataset_ops.Dataset.range(10).apply(
batching.map_and_batch(
lambda x: array_ops.reshape(x * x, [1]),
batch_size=4,
- drop_remainder=drop_remainder)).make_one_shot_iterator())
+ drop_remainder=drop_remainder)))
+
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+
if drop_remainder:
self.assertEqual([4, 1], iterator.output_shapes.as_list())
else:
@@ -138,11 +159,21 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
- def testMapAndBatchYieldsPartialBatch(self):
- iterator = (dataset_ops.Dataset.range(10)
- .apply(batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]), 4))
- .make_one_shot_iterator())
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchYieldsPartialBatch(self, numa_aware):
+ dataset = (
+ dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]), 4)))
+
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+
+ iterator = dataset.make_one_shot_iterator()
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
@@ -152,10 +183,19 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
- def testMapAndBatchParallelGetNext(self):
- iterator = (dataset_ops.Dataset.range(50000)
- .apply(batching.map_and_batch(lambda x: x, batch_size=100))
- .make_one_shot_iterator())
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchParallelGetNext(self, numa_aware):
+ dataset = dataset_ops.Dataset.range(50000).apply(
+ batching.map_and_batch(lambda x: x, batch_size=100))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+
elements = []
for _ in range(100):
elements.append(iterator.get_next())
@@ -165,17 +205,26 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
- def testMapAndBatchParallelGetNextDropRemainder(self):
- iterator = (
- dataset_ops.Dataset.range(49999).apply(
- batching.map_and_batch(
- lambda x: x, batch_size=100, drop_remainder=True))
- .make_one_shot_iterator())
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchParallelGetNextDropRemainder(self, numa_aware):
+ dataset = dataset_ops.Dataset.range(49999).apply(
+ batching.map_and_batch(
+ lambda x: x, batch_size=100, drop_remainder=True))
+
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+
elements = []
for _ in range(100):
elements.append(iterator.get_next())
@@ -185,19 +234,29 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
- def testMapAndBatchSparse(self):
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchSparse(self, numa_aware):
def _sparse(i):
return sparse_tensor.SparseTensorValue(
indices=[[0]], values=(i * [1]), dense_shape=[1])
- iterator = dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
+ dataset = dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(_sparse, 5))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_initializable_iterator()
+
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -214,21 +273,33 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testMapAndBatchFails(self):
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchFails(self, numa_aware):
"""Test a dataset that maps a TF function across its input elements."""
dataset = dataset_ops.Dataset.from_tensors(
array_ops.check_numerics(
constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
+ dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_initializable_iterator()
+
init_op = iterator.initializer
with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
- def testMapAndBatchShapeMismatch(self):
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchShapeMismatch(self, numa_aware):
"""Test a dataset that maps a TF function across its input elements."""
def generator():
@@ -240,9 +311,13 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int32)
batch_size = 4
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
+ dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_initializable_iterator()
+
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
@@ -251,7 +326,11 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
"number of elements does not match"):
sess.run(get_next)
- def testMapAndBatchImplicitDispose(self):
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchImplicitDispose(self, numa_aware):
# Tests whether a map and batch dataset will be cleaned up correctly when
# the pipeline does not run it until exhaustion.
# The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
@@ -266,6 +345,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
dataset = dataset.prefetch(5)
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
@@ -274,26 +357,38 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(get_next)
@parameterized.named_parameters(
- ("1", 0),
- ("2", 5),
- ("3", 10),
- ("4", 90),
- ("5", 95),
- ("6", 99),
+ ("1", 0, False),
+ ("2", 5, False),
+ ("3", 10, False),
+ ("4", 90, False),
+ ("5", 95, False),
+ ("6", 99, False),
+ ("1NUMA", 0, True),
+ ("2NUMA", 5, True),
+ ("3NUMA", 10, True),
+ ("4NUMA", 90, True),
+ ("5NUMA", 95, True),
+ ("6NUMA", 99, True),
)
- def testMapAndBatchOutOfRangeError(self, threshold):
+ def testMapAndBatchOutOfRangeError(self, threshold, numa_aware):
def raising_py_fn(i):
- if i >= threshold:
+ if i == threshold:
raise StopIteration()
+ elif i > threshold:
+ raise RuntimeError("Alternate error; you shouldn't see me! (i: %s)" % i)
else:
return i
- iterator = (
- dataset_ops.Dataset.range(100).apply(
- batching.map_and_batch(
- lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
- batch_size=10)).make_one_shot_iterator())
+ dataset = dataset_ops.Dataset.range(100).apply(
+ batching.map_and_batch(
+ lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
+ batch_size=10))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
@@ -307,25 +402,42 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(get_next)
@parameterized.named_parameters(
- ("1", False, dtypes.bool),
- ("2", -42, dtypes.int8),
- ("3", -42, dtypes.int16),
- ("4", -42, dtypes.int32),
- ("5", -42, dtypes.int64),
- ("6", 42, dtypes.uint8),
- ("7", 42, dtypes.uint16),
- ("8", 42.0, dtypes.float16),
- ("9", 42.0, dtypes.float32),
- ("10", 42.0, dtypes.float64),
- ("11", b"hello", dtypes.string),
+ ("1", False, dtypes.bool, False),
+ ("2", -42, dtypes.int8, False),
+ ("3", -42, dtypes.int16, False),
+ ("4", -42, dtypes.int32, False),
+ ("5", -42, dtypes.int64, False),
+ ("6", 42, dtypes.uint8, False),
+ ("7", 42, dtypes.uint16, False),
+ ("8", 42.0, dtypes.float16, False),
+ ("9", 42.0, dtypes.float32, False),
+ ("10", 42.0, dtypes.float64, False),
+ ("11", b"hello", dtypes.string, False),
+ ("1NUMA", False, dtypes.bool, True),
+ ("2NUMA", -42, dtypes.int8, True),
+ ("3NUMA", -42, dtypes.int16, True),
+ ("4NUMA", -42, dtypes.int32, True),
+ ("5NUMA", -42, dtypes.int64, True),
+ ("6NUMA", 42, dtypes.uint8, True),
+ ("7NUMA", 42, dtypes.uint16, True),
+ ("8NUMA", 42.0, dtypes.float16, True),
+ ("9NUMA", 42.0, dtypes.float32, True),
+ ("10NUMA", 42.0, dtypes.float64, True),
+ ("11NUMA", b"hello", dtypes.string, True),
)
- def testMapAndBatchTypes(self, element, dtype):
+ def testMapAndBatchTypes(self, element, dtype, numa_aware):
+
def gen():
yield element
dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
batching.map_and_batch(lambda x: x, batch_size=10))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
@@ -363,6 +475,40 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(iterator.initializer, feed_dict={captured_t: 42})
self.assertAllEqual([42] * 10, sess.run(get_next))
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchControlFlow(self, numa_aware):
+
+ def map_fn(x):
+ previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
+ return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x)
+ control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value
+ return return_value
+
+ dataset = dataset_ops.Dataset.range(100).apply(
+ batching.map_and_batch(map_fn, batch_size=10))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for i in range(10):
+ print("Case %d" % i)
+ if i < 5:
+ self.assertAllEqual([i * 10 + j + 1 for j in range(10)],
+ sess.run(get_next))
+ else:
+ self.assertAllEqual(
+ [((i * 10) + j) * ((i * 10) + j) for j in range(10)],
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
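
One subtlety in the hunk for testMapAndBatchOutOfRangeError above: raising_py_fn used to raise StopIteration for every i >= threshold, but now only i == threshold signals end-of-sequence, and any larger i raises a loud RuntimeError. Presumably this tightening ensures the implementation under test (in particular the NUMA variant) never keeps mapping elements past the end-of-sequence signal. Restated with comments:

def raising_py_fn(i):
  if i == threshold:
    # Exactly one element produces the clean end-of-sequence signal.
    raise StopIteration()
  elif i > threshold:
    # Only reachable if the runner mapped elements past end-of-sequence;
    # raising RuntimeError turns that into a visible test failure.
    raise RuntimeError("Alternate error; you shouldn't see me! (i: %s)" % i)
  else:
    return i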
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
index c92bb8b9bc..5a0a73fd83 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -161,6 +161,7 @@ py_test(
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -199,6 +200,7 @@ py_test(
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
index 82516356df..d38255a6ea 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import time
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import batching
@@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class ModelDatasetTest(test_base.DatasetTestBase):
+class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def testModelMap(self):
k = 1024 * 1024
@@ -82,7 +83,11 @@ class ModelDatasetTest(test_base.DatasetTestBase):
(np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
np.max(deltas)))
- def testModelMapAndBatch(self):
+ @parameterized.named_parameters(
+ ("Default", False),
+ ("NUMA", True),
+ )
+ def testModelMapAndBatch(self, numa_aware):
batch_size = 16
k = 1024 * 1024
dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
@@ -95,6 +100,8 @@ class ModelDatasetTest(test_base.DatasetTestBase):
batch_size=batch_size))
options = dataset_ops.Options()
options.experimental_autotune = True
+ if numa_aware:
+ options.experimental_numa_aware = True
iterator = dataset.with_options(options).make_one_shot_iterator()
get_next = iterator.get_next()
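
As the hunk above shows, the NUMA rewrite composes with autotuning: both are plain boolean fields on the same Options object, so enabling them together is just:

options = dataset_ops.Options()
options.experimental_autotune = True    # modeling/autotune pass
options.experimental_numa_aware = True  # NUMA-aware rewrite from this change
iterator = dataset.with_options(options).make_one_shot_iterator()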
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
index 760cd8cc4e..2ef29796ab 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@@ -59,6 +60,21 @@ class OptimizeDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testNumaAwareRewrite(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(["NumaMapAndBatch"])).apply(
+ batching.map_and_batch(lambda x: x * x, 10))
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testOptimizationStatefulFunction(self):
dataset = dataset_ops.Dataset.range(10).map(
lambda _: random_ops.random_uniform([])).batch(10)
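
testNumaAwareRewrite above pins down the mechanism: with experimental_numa_aware set, the tf.data graph rewrite is expected to replace the MapAndBatch transformation with NumaMapAndBatch, which optimization.assert_next verifies against the op that follows it after rewriting. A hypothetical negative counterpart (not part of this commit) would assert the default name when the option is left unset:

# Hypothetical: with no Options set, the op name should stay "MapAndBatch",
# so this pipeline should iterate cleanly, confirming the rewrite only
# fires on opt-in.
dataset = dataset_ops.Dataset.range(10).apply(
    optimization.assert_next(["MapAndBatch"])).apply(
        batching.map_and_batch(lambda x: x * x, 10))
get_next = dataset.make_one_shot_iterator().get_next()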
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
index e556b65b7c..a97cff9fbb 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -307,6 +307,21 @@ py_test(
)
py_test(
+ name = "numa_map_and_batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["numa_map_and_batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
name = "map_dataset_serialization_test",
size = "medium",
srcs = ["map_dataset_serialization_test.py"],
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..04aab329cd
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapAndBatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapAndBatchDatasetSerializationTest(
+    dataset_serialization_test_base.DatasetSerializationTestBase):
+
+  def testNumParallelBatches(self):
+    range_size = 11
+    num_repeats = 2
+    batch_size = 5
+    total_outputs = range_size * num_repeats
+    num_outputs_drop_remainder = total_outputs // batch_size
+    num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+    num_parallel_batches = 2
+
+    def build_ds(range_start, drop_remainder=False):
+
+      def _map_fn(x):
+        return math_ops.square(x)
+
+      ds = dataset_ops.Dataset.range(
+          range_start, range_start + range_size).repeat(num_repeats).apply(
+              batching.map_and_batch(
+                  map_func=_map_fn,
+                  batch_size=batch_size,
+                  num_parallel_batches=num_parallel_batches,
+                  drop_remainder=drop_remainder))
+      options = dataset_ops.Options()
+      options.experimental_numa_aware = True
+      return ds.with_options(options)
+
+    self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+                        num_outputs_keep_remainder)
+    self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+                        num_outputs_drop_remainder)
+
+  def testNumParallelCalls(self):
+    range_size = 11
+    num_repeats = 2
+    batch_size = 5
+    total_outputs = range_size * num_repeats
+    num_outputs_drop_remainder = total_outputs // batch_size
+    num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+    num_parallel_calls = 7
+
+    def build_ds(range_start, drop_remainder=False):
+
+      def _map_fn(x):
+        return math_ops.square(x)
+
+      ds = dataset_ops.Dataset.range(
+          range_start, range_start + range_size).repeat(num_repeats).apply(
+              batching.map_and_batch(
+                  map_func=_map_fn,
+                  batch_size=batch_size,
+                  num_parallel_calls=num_parallel_calls,
+                  drop_remainder=drop_remainder))
+      options = dataset_ops.Options()
+      options.experimental_numa_aware = True
+      return ds.with_options(options)
+
+    self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+                        num_outputs_keep_remainder)
+    self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+                        num_outputs_drop_remainder)
+
+
+if __name__ == "__main__":
+  test.main()
+
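
The expected output counts in the new serialization tests are easy to verify by hand; note that they rely on the from __future__ import division at the top of the file, so total_outputs / batch_size is true division even under Python 2:

range_size, num_repeats, batch_size = 11, 2, 5
total_outputs = range_size * num_repeats    # 22 elements
total_outputs // batch_size                 # 4 batches when drop_remainder=True
int(math.ceil(total_outputs / batch_size))  # 5 batches otherwise (last batch has 2)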
diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD
index 915d399f1b..46a9552b61 100644
--- a/tensorflow/python/data/experimental/ops/BUILD
+++ b/tensorflow/python/data/experimental/ops/BUILD
@@ -122,6 +122,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index cf52f7529a..6195747671 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1410,6 +1410,8 @@ class Options(object):
"Whether to eliminate no-op transformations."),
("experimental_shuffle_and_repeat_fusion", bool,
"Whether to fuse shuffle and repeat transformations."),
+ ("experimental_numa_aware", bool,
+ "Whether to use NUMA-aware operations."),
]:
def _make_getter(name): # pylint: disable=no-self-argument
@@ -1458,6 +1460,9 @@ class Options(object):
for exp_opt in experimental_optimizations:
if getattr(self, "experimental_" + exp_opt):
result.append(exp_opt)
+
+ if getattr(self, "experimental_numa_aware"):
+ result.append("map_and_batch_numa_aware_replacement")
return result
def merge(self, options):
@@ -1485,7 +1490,7 @@ class Options(object):
"experimental_map_and_filter_fusion", "experimental_map_fusion",
"experimental_map_parallelization", "experimental_map_vectorization",
"experimental_noop_elimination",
- "experimental_shuffle_and_repeat_fusion"
+ "experimental_shuffle_and_repeat_fusion", "experimental_numa_aware",
]:
this = getattr(result, name)
that = getattr(other, name)
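
For context on how the new flag reaches the runtime: the last hunk above appends the graph-rewrite name map_and_batch_numa_aware_replacement to the list of static optimizations that tf.data hands to its optimizer, which is the rewrite that substitutes NumaMapAndBatch in testNumaAwareRewrite. A minimal sketch, assuming the enclosing method (its name is not shown in the hunk) is Options._static_optimizations as in contemporaneous revisions of dataset_ops.py:

options = dataset_ops.Options()
options.experimental_numa_aware = True
# Assumed private method name; the hunk only shows its body being extended.
print(options._static_optimizations())
# ['map_and_batch_numa_aware_replacement']  (plus any other enabled rewrites)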