diff options
Diffstat (limited to 'tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py')
-rw-r--r-- | tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py index 760cd8cc4e..2ef29796ab 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops @@ -59,6 +60,21 @@ class OptimizeDatasetTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testNumaAwareRewrite(self): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next(["NumaMapAndBatch"])).apply( + batching.map_and_batch(lambda x: x * x, 10)) + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testOptimizationStatefulFunction(self): dataset = dataset_ops.Dataset.range(10).map( lambda _: random_ops.random_uniform([])).batch(10) |