diff options
author    | 2018-08-28 10:07:45 -0700
committer | 2018-08-28 10:15:09 -0700
commit | b7f2d11cc308631a8f0b733a1b2db39696507155 (patch)
tree   | 7a450e82844f11eeb60df737ed65bac402c155f0 /tensorflow/contrib/data/python
parent | 00045099ee05f85f05c8367a122bcd9ef6fc6b07 (diff)
[tf.data] Enable optimizations for input pipelines with stateful functions.
PiperOrigin-RevId: 210559796
Diffstat (limited to 'tensorflow/contrib/data/python')
3 files changed, 36 insertions, 14 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index 57bf22591a..e2c9bc82df 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -122,15 +122,12 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
     base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
                                                            [3, 4]]).repeat(5)
-    _, optimized = self._get_test_datasets(
+    unoptimized, optimized = self._get_test_datasets(
         base_dataset, map_fn, expect_optimized=False)
-    nxt = optimized.make_one_shot_iterator().get_next()
-
-    # NOTE: Right now, it raises an error because we can't save datasets that
-    # are stateful, and we rely on this saving mechanism to optimize datasets,
-    # so stateful functions can't be optimized.
-    with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"):
-      self.evaluate(nxt)
+    self._assert_datasets_raise_same_error(
+        unoptimized, optimized, errors.InvalidArgumentError,
+        [("OneShotIterator", "OneShotIterator_1", 1),
+         ("IteratorGetNext", "IteratorGetNext_1", 1)])

   def testOptimizationIgnoreRagged(self):
     # Make sure we ignore inputs that might not be uniformly sized
@@ -151,8 +148,10 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
     base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
     unoptimized, optimized = self._get_test_datasets(
         base_dataset, map_fn, expect_optimized=False)
-    self._assert_datasets_raise_same_error(unoptimized, optimized,
-                                           errors.InvalidArgumentError)
+    self._assert_datasets_raise_same_error(
+        unoptimized, optimized, errors.InvalidArgumentError,
+        [("OneShotIterator", "OneShotIterator_1", 1),
+         ("IteratorGetNext", "IteratorGetNext_1", 1)])


 class MapVectorizationBenchmark(test.Benchmark):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index ec43bc3653..446bf8d749 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -22,6 +22,7 @@ from absl.testing import parameterized
 from tensorflow.contrib.data.python.ops import optimization
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import errors
+from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -100,6 +101,18 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)

+  def testStatefulFunctionOptimization(self):
+    dataset = dataset_ops.Dataset.range(10).apply(
+        optimization.assert_next([
+            "MapAndBatch"
+        ])).map(lambda _: random_ops.random_uniform([])).batch(10).apply(
+            optimization.optimize(["map_and_batch_fusion"]))
+    iterator = dataset.make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(get_next)
+

 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
index 1b962b3418..1d70b16041 100644
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import re
+
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import errors
 from tensorflow.python.platform import test
@@ -45,7 +47,11 @@ class DatasetTestBase(test.TestCase):
     for i in range(len(op1)):
       self.assertAllEqual(op1[i], op2[i])

-  def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class):
+  def _assert_datasets_raise_same_error(self,
+                                        dataset1,
+                                        dataset2,
+                                        exception_class,
+                                        replacements=None):
     next1 = dataset1.make_one_shot_iterator().get_next()
     next2 = dataset2.make_one_shot_iterator().get_next()
     with self.test_session() as sess:
@@ -53,8 +59,12 @@ class DatasetTestBase(test.TestCase):
         sess.run(next1)
         raise ValueError(
             "Expected dataset to raise an error of type %s, but it did not." %
-            repr(exc_class))
-      except exc_class as e:
+            repr(exception_class))
+      except exception_class as e:
+        expected_message = e.message
+        for old, new, count in replacements:
+          expected_message = expected_message.replace(old, new, count)
         # Check that the first segment of the error messages are the same.
-        with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]):
+        with self.assertRaisesRegexp(exception_class,
+                                     re.escape(expected_message)):
           sess.run(next2)