aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
diff options
context:
space:
mode:
authorGravatar Piotr Padlewski <prazek@google.com>2018-08-24 18:33:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-24 18:37:10 -0700
commitd936338f32b826cc6b7ab835efed427d78741f81 (patch)
treebf5fe98d10555acbbd4b3992c32412eb3d05b482 /tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
parentb26553dc1f461e388963efcec9e77f5a1f61c093 (diff)
Filter fusion
This patch introduces FilterFusion optimization which can fuse multiple FilterDataset operations. PiperOrigin-RevId: 210189643
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index 2d8a4a583d..586b4bee5f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -161,6 +161,64 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
self._testMapAndFilter(dataset, function, predicate)
+ @staticmethod
+ def filter_functions():
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
+
+ tests = []
+ filters = [take_all, is_zero, greater]
+ identity = lambda x: x
+ for x, predicate_1 in enumerate(filters):
+ for y, predicate_2 in enumerate(filters):
+ tests.append(("mixed_{}_{}".format(x, y), identity,
+ [predicate_1, predicate_2]))
+ for z, predicate_3 in enumerate(filters):
+ tests.append(("mixed_{}_{}_{}".format(x, y, z), identity,
+ [predicate_1, predicate_2, predicate_3]))
+
+ take_all_multiple = lambda x, y: constant_op.constant(True)
+ # Multi output
+ tests.append(("multiOne", lambda x: (x, x),
+ [take_all_multiple, take_all_multiple]))
+ tests.append(("multiTwo", lambda x: (x, 2), [
+ take_all_multiple,
+ lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+ ]))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*filter_functions.__func__())
+ def testFilterFusion(self, map_function, predicates):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(["Map", "Filter",
+ "Prefetch"])).map(map_function)
+ for predicate in predicates:
+ dataset = dataset.filter(predicate)
+
+ dataset = dataset.prefetch(0).apply(
+ optimization.optimize(["filter_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ for x in range(5):
+ r = map_function(x)
+ filtered = False
+ for predicate in predicates:
+ if isinstance(r, tuple):
+ b = predicate(*r) # Pass tuple as multiple arguments.
+ else:
+ b = predicate(r)
+ if not sess.run(b):
+ filtered = True
+ break
+
+ if not filtered:
+ result = sess.run(get_next)
+ self.assertAllEqual(r, result)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()