aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py')
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
index 760cd8cc4e..2ef29796ab 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@@ -59,6 +60,21 @@ class OptimizeDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testNumaAwareRewrite(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(["NumaMapAndBatch"])).apply(
+ batching.map_and_batch(lambda x: x * x, 10))
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testOptimizationStatefulFunction(self):
dataset = dataset_ops.Dataset.range(10).map(
lambda _: random_ops.random_uniform([])).batch(10)