aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-08-21 15:52:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 15:56:26 -0700
commit62fcb03449c13281935a154e3ea7c11614ffb678 (patch)
tree4f60de30aaf9cc243d57d139ff11296248800d60 /tensorflow/contrib/data
parent7989b2bc99aa5207dccae9e7bedfc3cb140349cd (diff)
[tf.data] Add an optimization that vectorizes map functions and swaps the order of Map->Batch dataset transformations to Batch->Map
PiperOrigin-RevId: 209674669
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD13
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py67
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py123
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/test_utils.py60
4 files changed, 234 insertions, 29 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 4df75c1edb..cd46e382eb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -230,12 +230,15 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":stats_dataset_test_base",
+ ":test_utils",
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/contrib/data/python/ops:stats_ops",
+ "//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
@@ -548,3 +551,13 @@ py_test(
"//tensorflow/python/data/ops:readers",
],
)
+
+# Test utilities for tf.data functionality (sources in test_utils.py).
+py_library(
+ name = "test_utils",
+ srcs = ["test_utils.py"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index a711325dae..73cde40305 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -31,47 +31,57 @@ from tensorflow.python.platform import test
class MapDefunTest(test.TestCase):
- def testMapDefun_Simple(self):
+ def testMapDefunSimple(self):
@function.Defun(dtypes.int32)
def simple_fn(x):
return x * 2 + 3
- with self.test_session():
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
- expected = elems * 2 + 3
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
- def testMapDefun_MismatchedTypes(self):
+ def testMapDefunMismatchedTypes(self):
@function.Defun(dtypes.int32)
def fn(x):
return math_ops.cast(x, dtypes.float64)
- with self.test_session():
- nums = [1, 2, 3, 4, 5, 6]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(r)
+ nums = [1, 2, 3, 4, 5, 6]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunReduceDim(self):
+ # Tests the case where the output has a different rank from the input: the
+ # function gathers index 0 of its argument, and the declared output shape
+ # is a scalar, `()`.
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return array_ops.gather(x, 0)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+ expected = constant_op.constant([1, 3, 5])
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
- def testMapDefun_MultipleOutputs(self):
+ def testMapDefunMultipleOutputs(self):
@function.Defun(dtypes.int32)
def fn(x):
return (x, math_ops.cast(x * 2 + 3, dtypes.float64))
- with self.test_session():
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64],
- [(2,), (2,)])
- expected = [elems, elems * 2 + 3]
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,),
+ (2,)])
+ expected = [elems, elems * 2 + 3]
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
- def testMapDefun_ShapeInference(self):
+ def testMapDefunShapeInference(self):
@function.Defun(dtypes.int32)
def fn(x):
@@ -82,7 +92,7 @@ class MapDefunTest(test.TestCase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
self.assertEqual(result.get_shape(), (3, 2))
- def testMapDefun_PartialShapeInference(self):
+ def testMapDefunPartialShapeInference(self):
@function.Defun(dtypes.int32)
def fn(x):
@@ -92,7 +102,7 @@ class MapDefunTest(test.TestCase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
self.assertEqual(result[0].get_shape().as_list(), [None, 2])
- def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self):
+ def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
@function.Defun(dtypes.int32, dtypes.int32)
def fn(x, y):
@@ -108,7 +118,7 @@ class MapDefunTest(test.TestCase):
"All inputs must have the same dimension 0."):
sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
- def testMapDefun_RaisesDefunError(self):
+ def testMapDefunRaisesDefunError(self):
@function.Defun(dtypes.int32)
def fn(x):
@@ -117,9 +127,8 @@ class MapDefunTest(test.TestCase):
elems = constant_op.constant([0, 0, 0, 37, 0])
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
- with self.test_session():
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(result)
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(result)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index ae147b4fa7..76aa1c3cfd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -20,12 +20,16 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
+from tensorflow.contrib.data.python.kernel_tests import test_utils
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -277,5 +281,124 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
"record_latency_PrefetchDataset/_6", 1)
+class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
+ """Tests for the "map_vectorization" dataset optimization."""
+
+ def _get_test_datasets(self,
+ base_dataset,
+ map_fn,
+ num_parallel_calls=None,
+ expect_optimized=True):
+ """Given base dataset and map fn, creates test datasets.
+
+ Returns a tuple of (unoptimized dataset, optimized dataset). Both datasets
+ apply `map_fn` to `base_dataset` and then batch the result. The unoptimized
+ dataset has the assertion that Batch follows Map. The optimized dataset has
+ the "map_vectorization" optimization applied and, when `expect_optimized`
+ is True, has the assertion that Map follows Batch; otherwise it has the
+ same assertion as the unoptimized dataset.
+
+ Args:
+ base_dataset: Input dataset to map->batch
+ map_fn: Map function to use
+ num_parallel_calls: (Optional.) num_parallel_calls argument for map
+ expect_optimized: (Optional.) Whether we expect the optimization to take
+ place, in which case we will assert that Batch is followed by Map,
+ otherwise Map followed by Batch. Defaults to True.
+
+ Returns:
+ Tuple of (unoptimized dataset, optimized dataset).
+ """
+ map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
+ batch_size = 100
+
+ def _make_dataset(node_names):
+ return base_dataset.apply(optimization.assert_next(node_names)).map(
+ map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
+
+ unoptimized = _make_dataset([map_node_name, "Batch"])
+ optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
+ [map_node_name, "Batch"]).apply(
+ optimization.optimize(["map_vectorization"]))
+
+ return unoptimized, optimized
+
+ @parameterized.named_parameters(
+ ("Basic", lambda x: (x, x + 1), None),
+ ("Parallel", lambda x: (x, x + 1), 12),
+ ("Gather", lambda x: array_ops.gather(x, 0), 12),
+ )
+ def testOptimization(self, map_fn, num_parallel_calls):
+ # Tests that the optimized and unoptimized datasets produce equal elements.
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
+ num_parallel_calls)
+ self._assert_datasets_equal(unoptimized, optimized)
+
+ def testOptimizationBadMapFn(self):
+ # Test map functions that give an error
+ def map_fn(x):
+ # x has leading dimension 5, this will raise an error
+ return array_ops.gather(x, 10)
+
+ base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
+ 5, drop_remainder=True)
+ _, optimized = self._get_test_datasets(base_dataset, map_fn)
+ nxt = optimized.make_one_shot_iterator().get_next()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(nxt)
+
+ def testOptimizationWithCapturedInputs(self):
+ # Tests that vectorization works with captured inputs
+ def map_fn(x):
+ return x + y
+
+ y = constant_op.constant(1, shape=(2,))
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ # TODO(rachelim): when this optimization works, turn on expect_optimized
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_equal(optimized, unoptimized)
+
+ def testOptimizationIgnoreStateful(self):
+ # Tests a map function that runs an assert op as a control dependency; see
+ # the NOTE below for the behavior that is currently expected.
+
+ def map_fn(x):
+ with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
+ return array_ops.identity(x)
+
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ _, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ nxt = optimized.make_one_shot_iterator().get_next()
+
+ # NOTE: Right now, it raises an error because we can't save datasets that
+ # are stateful, and we rely on this saving mechanism to optimize datasets,
+ # so stateful functions can't be optimized.
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"):
+ self.evaluate(nxt)
+
+ def testOptimizationIgnoreRagged(self):
+ # Make sure we ignore inputs that might not be uniformly sized
+ def map_fn(x):
+ return array_ops.gather(x, 0)
+
+ # output_shape = (?,)
+ base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_equal(unoptimized, optimized)
+
+ def testOptimizationIgnoreRaggedMap(self):
+ # Don't optimize when the output shapes of the map fn are unknown.
+ def map_fn(x):
+ return array_ops.tile(x, x)
+
+ base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_raise_same_error(unoptimized, optimized,
+ errors.InvalidArgumentError)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
new file mode 100644
index 0000000000..1b962b3418
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
@@ -0,0 +1,60 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.data functionality."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class DatasetTestBase(test.TestCase):
+ """Base class for dataset tests."""
+
+ def _assert_datasets_equal(self, dataset1, dataset2):
+ # TODO(rachelim): support sparse tensor outputs
+ next1 = dataset1.make_one_shot_iterator().get_next()
+ next2 = dataset2.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ while True:
+ try:
+ op1 = sess.run(next1)
+ except errors.OutOfRangeError:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next2)
+ break
+ op2 = sess.run(next2)
+
+ op1 = nest.flatten(op1)
+ op2 = nest.flatten(op2)
+ assert len(op1) == len(op2)
+ for i in range(len(op1)):
+ self.assertAllEqual(op1[i], op2[i])
+
+ def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class):
+ next1 = dataset1.make_one_shot_iterator().get_next()
+ next2 = dataset2.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ try:
+ sess.run(next1)
+ raise ValueError(
+ "Expected dataset to raise an error of type %s, but it did not." %
+ repr(exc_class))
+ except exc_class as e:
+ # Check that the first segment of the error messages are the same.
+ with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]):
+ sess.run(next2)