diff options
author | Rachel Lim <rachelim@google.com> | 2018-08-21 15:52:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-21 15:56:26 -0700 |
commit | 62fcb03449c13281935a154e3ea7c11614ffb678 (patch) | |
tree | 4f60de30aaf9cc243d57d139ff11296248800d60 /tensorflow/contrib/data | |
parent | 7989b2bc99aa5207dccae9e7bedfc3cb140349cd (diff) |
[tf.data] Add an optimization that vectorizes map functions and swaps the order of Map->Batch dataset transformations to Batch->Map
PiperOrigin-RevId: 209674669
Diffstat (limited to 'tensorflow/contrib/data')
4 files changed, 234 insertions, 29 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 4df75c1edb..cd46e382eb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -230,12 +230,15 @@ py_test( srcs_version = "PY2AND3", deps = [ ":stats_dataset_test_base", + ":test_utils", "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", @@ -548,3 +551,13 @@ py_test( "//tensorflow/python/data/ops:readers", ], ) + +py_library( + name = "test_utils", + srcs = ["test_utils.py"], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/util:nest", + ], +) diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index a711325dae..73cde40305 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -31,47 +31,57 @@ from tensorflow.python.platform import test class MapDefunTest(test.TestCase): - def testMapDefun_Simple(self): + def testMapDefunSimple(self): @function.Defun(dtypes.int32) def simple_fn(x): return x * 2 + 3 - with self.test_session(): - nums = [[1, 2], [3, 4], [5, 6]] - elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") - r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0] - expected = elems * 2 + 3 - self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, 
dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0] + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) - def testMapDefun_MismatchedTypes(self): + def testMapDefunMismatchedTypes(self): @function.Defun(dtypes.int32) def fn(x): return math_ops.cast(x, dtypes.float64) - with self.test_session(): - nums = [1, 2, 3, 4, 5, 6] - elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") - r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0] - with self.assertRaises(errors.InvalidArgumentError): - self.evaluate(r) + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefunReduceDim(self): + # Tests where the output has a different rank from the input + + @function.Defun(dtypes.int32) + def fn(x): + return array_ops.gather(x, 0) + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0] + expected = constant_op.constant([1, 3, 5]) + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) - def testMapDefun_MultipleOutputs(self): + def testMapDefunMultipleOutputs(self): @function.Defun(dtypes.int32) def fn(x): return (x, math_ops.cast(x * 2 + 3, dtypes.float64)) - with self.test_session(): - nums = [[1, 2], [3, 4], [5, 6]] - elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") - r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], - [(2,), (2,)]) - expected = [elems, elems * 2 + 3] - self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,), 
+ (2,)]) + expected = [elems, elems * 2 + 3] + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) - def testMapDefun_ShapeInference(self): + def testMapDefunShapeInference(self): @function.Defun(dtypes.int32) def fn(x): @@ -82,7 +92,7 @@ class MapDefunTest(test.TestCase): result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0] self.assertEqual(result.get_shape(), (3, 2)) - def testMapDefun_PartialShapeInference(self): + def testMapDefunPartialShapeInference(self): @function.Defun(dtypes.int32) def fn(x): @@ -92,7 +102,7 @@ class MapDefunTest(test.TestCase): result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)]) self.assertEqual(result[0].get_shape().as_list(), [None, 2]) - def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self): + def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self): @function.Defun(dtypes.int32, dtypes.int32) def fn(x, y): @@ -108,7 +118,7 @@ class MapDefunTest(test.TestCase): "All inputs must have the same dimension 0."): sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]}) - def testMapDefun_RaisesDefunError(self): + def testMapDefunRaisesDefunError(self): @function.Defun(dtypes.int32) def fn(x): @@ -117,9 +127,8 @@ class MapDefunTest(test.TestCase): elems = constant_op.constant([0, 0, 0, 37, 0]) result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()]) - with self.test_session(): - with self.assertRaises(errors.InvalidArgumentError): - self.evaluate(result) + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(result) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index ae147b4fa7..76aa1c3cfd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -20,12 +20,16 @@ from __future__ import print_function from absl.testing 
import parameterized from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base +from tensorflow.contrib.data.python.kernel_tests import test_utils from tensorflow.contrib.data.python.ops import optimization from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -277,5 +281,124 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): "record_latency_PrefetchDataset/_6", 1) +class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): + + def _get_test_datasets(self, + base_dataset, + map_fn, + num_parallel_calls=None, + expect_optimized=True): + """Given base dataset and map fn, creates test datasets. + + Returns a tuple of (unoptimized dataset, optimized dataset). The + unoptimized dataset has the assertion that Batch follows Map. The optimized + dataset has the assertion that Map follows Batch, and has the + "map_vectorization" optimization applied. + + Args: + base_dataset: Input dataset to map->batch + map_fn: Map function to use + num_parallel_calls: (Optional.) num_parallel_calls argument for map + expect_optimized: (Optional.) Whether we expect the optimization to take + place, in which case we will assert that Batch is followed by Map, + otherwise Map followed by Batch. Defaults to True. + + Returns: + Tuple of (unoptimized dataset, optimized dataset). 
+ """ + map_node_name = "Map" if num_parallel_calls is None else "ParallelMap" + batch_size = 100 + + def _make_dataset(node_names): + return base_dataset.apply(optimization.assert_next(node_names)).map( + map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size) + + unoptimized = _make_dataset([map_node_name, "Batch"]) + optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else + [map_node_name, "Batch"]).apply( + optimization.optimize(["map_vectorization"])) + + return unoptimized, optimized + + @parameterized.named_parameters( + ("Basic", lambda x: (x, x + 1), None), + ("Parallel", lambda x: (x, x + 1), 12), + ("Gather", lambda x: array_ops.gather(x, 0), 12), + ) + def testOptimization(self, map_fn, num_parallel_calls): + base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], + [3, 4]]).repeat(5) + unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn, + num_parallel_calls) + self._assert_datasets_equal(unoptimized, optimized) + + def testOptimizationBadMapFn(self): + # Test map functions that give an error + def map_fn(x): + # x has leading dimension 5, this will raise an error + return array_ops.gather(x, 10) + + base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch( + 5, drop_remainder=True) + _, optimized = self._get_test_datasets(base_dataset, map_fn) + nxt = optimized.make_one_shot_iterator().get_next() + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r"indices = 10 is not in \[0, 5\)"): + self.evaluate(nxt) + + def testOptimizationWithCapturedInputs(self): + # Tests that vectorization works with captured inputs + def map_fn(x): + return x + y + + y = constant_op.constant(1, shape=(2,)) + base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], + [3, 4]]).repeat(5) + # TODO(rachelim): when this optimization works, turn on expect_optimized + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self._assert_datasets_equal(optimized, 
unoptimized) + + def testOptimizationIgnoreStateful(self): + + def map_fn(x): + with ops.control_dependencies([check_ops.assert_equal(x, 0)]): + return array_ops.identity(x) + + base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], + [3, 4]]).repeat(5) + _, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + nxt = optimized.make_one_shot_iterator().get_next() + + # NOTE: Right now, it raises an error because we can't save datasets that + # are stateful, and we rely on this saving mechanism to optimize datasets, + # so stateful functions can't be optimized. + with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"): + self.evaluate(nxt) + + def testOptimizationIgnoreRagged(self): + # Make sure we ignore inputs that might not be uniformly sized + def map_fn(x): + return array_ops.gather(x, 0) + + # output_shape = (?,) + base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False) + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self._assert_datasets_equal(unoptimized, optimized) + + def testOptimizationIgnoreRaggedMap(self): + # Don't optimize when the output of the map fn shapes are unknown. + def map_fn(x): + return array_ops.tile(x, x) + + base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True) + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self._assert_datasets_raise_same_error(unoptimized, optimized, + errors.InvalidArgumentError) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py new file mode 100644 index 0000000000..1b962b3418 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py @@ -0,0 +1,60 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utilities for tf.data functionality.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.util import nest +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class DatasetTestBase(test.TestCase): + """Base class for dataset tests.""" + + def _assert_datasets_equal(self, dataset1, dataset2): + # TODO(rachelim): support sparse tensor outputs + next1 = dataset1.make_one_shot_iterator().get_next() + next2 = dataset2.make_one_shot_iterator().get_next() + with self.test_session() as sess: + while True: + try: + op1 = sess.run(next1) + except errors.OutOfRangeError: + with self.assertRaises(errors.OutOfRangeError): + sess.run(next2) + break + op2 = sess.run(next2) + + op1 = nest.flatten(op1) + op2 = nest.flatten(op2) + assert len(op1) == len(op2) + for i in range(len(op1)): + self.assertAllEqual(op1[i], op2[i]) + + def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class): + next1 = dataset1.make_one_shot_iterator().get_next() + next2 = dataset2.make_one_shot_iterator().get_next() + with self.test_session() as sess: + try: + sess.run(next1) + raise ValueError( + "Expected dataset to raise an error of type %s, but it did not." 
% + repr(exc_class)) + except exc_class as e: + # Check that the first segments of the error messages are the same. + with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]): + sess.run(next2) |