author    Jiri Simsa <jsimsa@google.com>  2018-08-28 10:07:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-28 10:15:09 -0700
commit    b7f2d11cc308631a8f0b733a1b2db39696507155 (patch)
tree      7a450e82844f11eeb60df737ed65bac402c155f0 /tensorflow/contrib/data/python
parent    00045099ee05f85f05c8367a122bcd9ef6fc6b07 (diff)
[tf.data] Enable optimizations for input pipelines with stateful functions.
PiperOrigin-RevId: 210559796
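
The optimization pass previously relied on serializing the input dataset, which fails for stateful datasets, so pipelines whose map functions were stateful could not be optimized (see the NOTE removed from map_vectorization_test.py below). A minimal sketch of the pattern this change enables, closely following the new test added to optimize_dataset_op_test.py and assuming the TF 1.x contrib.data APIs used in this diff:

import tensorflow as tf

from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import random_ops

# The map function is stateful because it draws a fresh random value per element.
dataset = dataset_ops.Dataset.range(10).map(
    lambda _: random_ops.random_uniform([])).batch(10).apply(
        optimization.optimize(["map_and_batch_fusion"]))
get_next = dataset.make_one_shot_iterator().get_next()

# With this change the map_and_batch_fusion rewrite is applied even though the
# map function is stateful, and evaluating the iterator no longer raises
# InvalidArgumentError.
with tf.Session() as sess:
  print(sess.run(get_next))  # one batch of 10 random floats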
Diffstat (limited to 'tensorflow/contrib/data/python')
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py | 19
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py            | 13
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/test_utils.py                          | 18
3 files changed, 36 insertions(+), 14 deletions(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index 57bf22591a..e2c9bc82df 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -122,15 +122,12 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
[3, 4]]).repeat(5)
- _, optimized = self._get_test_datasets(
+ unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- nxt = optimized.make_one_shot_iterator().get_next()
-
- # NOTE: Right now, it raises an error because we can't save datasets that
- # are stateful, and we rely on this saving mechanism to optimize datasets,
- # so stateful functions can't be optimized.
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"):
- self.evaluate(nxt)
+ self._assert_datasets_raise_same_error(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
def testOptimizationIgnoreRagged(self):
# Make sure we ignore inputs that might not be uniformly sized
@@ -151,8 +148,10 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(unoptimized, optimized,
- errors.InvalidArgumentError)
+ self._assert_datasets_raise_same_error(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
class MapVectorizationBenchmark(test.Benchmark):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index ec43bc3653..446bf8d749 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -22,6 +22,7 @@ from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -100,6 +101,18 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testStatefulFunctionOptimization(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next([
+ "MapAndBatch"
+ ])).map(lambda _: random_ops.random_uniform([])).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
index 1b962b3418..1d70b16041 100644
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import re
+
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
@@ -45,7 +47,11 @@ class DatasetTestBase(test.TestCase):
for i in range(len(op1)):
self.assertAllEqual(op1[i], op2[i])
- def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class):
+ def _assert_datasets_raise_same_error(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
next1 = dataset1.make_one_shot_iterator().get_next()
next2 = dataset2.make_one_shot_iterator().get_next()
with self.test_session() as sess:
@@ -53,8 +59,12 @@ class DatasetTestBase(test.TestCase):
sess.run(next1)
raise ValueError(
"Expected dataset to raise an error of type %s, but it did not." %
- repr(exc_class))
- except exc_class as e:
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ for old, new, count in replacements:
+ expected_message = expected_message.replace(old, new, count)
# Check that the first segment of the error messages are the same.
- with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]):
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
sess.run(next2)
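
For context on the test_utils.py change above: _assert_datasets_raise_same_error now rewrites the message of the error raised by the first dataset using the supplied (old, new, count) string replacements, then asserts via an re.escape'd pattern that the second dataset raises an error with a matching message. The replacements let callers compensate for op names that differ between the two one-shot iterators (e.g. "OneShotIterator" vs. "OneShotIterator_1"), as in the two call sites updated in map_vectorization_test.py. A standalone sketch of that matching step, using a made-up error message:

import re

# Hypothetical message from the first (unoptimized) dataset's error.
message = "Error caught in OneShotIterator while running IteratorGetNext."
replacements = [("OneShotIterator", "OneShotIterator_1", 1),
                ("IteratorGetNext", "IteratorGetNext_1", 1)]

expected_message = message
for old, new, count in replacements:
  expected_message = expected_message.replace(old, new, count)

# The helper passes re.escape(expected_message) to assertRaisesRegexp, so the
# second dataset's error message must contain this rewritten text literally.
assert re.search(re.escape(expected_message),
                 "Error caught in OneShotIterator_1 while running IteratorGetNext_1.")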