diff options
author | Piotr Padlewski <prazek@google.com> | 2018-08-24 18:33:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-24 18:37:10 -0700 |
commit | d936338f32b826cc6b7ab835efed427d78741f81 (patch) | |
tree | bf5fe98d10555acbbd4b3992c32412eb3d05b482 /tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py | |
parent | b26553dc1f461e388963efcec9e77f5a1f61c093 (diff) |
Filter fusion
This patch introduces FilterFusion optimization which can fuse
multiple FilterDataset operations.
PiperOrigin-RevId: 210189643
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py index 2d8a4a583d..586b4bee5f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -161,6 +161,64 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): self._testMapAndFilter(dataset, function, predicate) + @staticmethod + def filter_functions(): + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + + tests = [] + filters = [take_all, is_zero, greater] + identity = lambda x: x + for x, predicate_1 in enumerate(filters): + for y, predicate_2 in enumerate(filters): + tests.append(("mixed_{}_{}".format(x, y), identity, + [predicate_1, predicate_2])) + for z, predicate_3 in enumerate(filters): + tests.append(("mixed_{}_{}_{}".format(x, y, z), identity, + [predicate_1, predicate_2, predicate_3])) + + take_all_multiple = lambda x, y: constant_op.constant(True) + # Multi output + tests.append(("multiOne", lambda x: (x, x), + [take_all_multiple, take_all_multiple])) + tests.append(("multiTwo", lambda x: (x, 2), [ + take_all_multiple, + lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0) + ])) + return tuple(tests) + + @parameterized.named_parameters(*filter_functions.__func__()) + def testFilterFusion(self, map_function, predicates): + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(["Map", "Filter", + "Prefetch"])).map(map_function) + for predicate in predicates: + dataset = dataset.filter(predicate) + + dataset = dataset.prefetch(0).apply( + optimization.optimize(["filter_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(5): + r = map_function(x) + filtered = False + for predicate in predicates: + if isinstance(r, tuple): + b = predicate(*r) # Pass tuple as multiple arguments. + else: + b = predicate(r) + if not sess.run(b): + filtered = True + break + + if not filtered: + result = sess.run(get_next) + self.assertAllEqual(r, result) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() |