diff options
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py | 104 |
1 files changed, 103 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 091eb5ce37..61567bc8d7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -17,7 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from tensorflow.contrib.data.python.ops import map_defun +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -25,10 +28,10 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test - class MapDefunTest(test.TestCase): def testMapDefunSimple(self): @@ -146,6 +149,105 @@ class MapDefunTest(test.TestCase): r"indices = 10 is not in \[0, 5\)"): self.evaluate(map_defun_op) + def testMapDefunWithUnspecifiedOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + res = x * 2 + 3 + return (res, res + 1, res + 2) + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], + [dtypes.int32, dtypes.int32, dtypes.int32], + [None, (None,), (2,)]) + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected)) + self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1)) + self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2)) + + def testMapDefunWithDifferentOutputShapeEachRun(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + elems = array_ops.placeholder(dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0] + with session.Session() as sess: + self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3]) + self.assertAllEqual( + sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]]) + + def testMapDefunWithWrongOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefunWithInvalidInput(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + + c = constant_op.constant(2) + with self.assertRaises(ValueError): + # Fails at graph construction time for inputs with known shapes. + r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0] + p = array_ops.placeholder(dtypes.int32) + r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0] + with session.Session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(r, feed_dict={p: 0}) + + +class MapDefunBenchmark(test.Benchmark): + + def _run(self, op, name=None, num_iters=3000): + with session.Session() as sess: + # Warm up the session + for _ in range(5): + sess.run(op) + start = time.time() + for _ in range(num_iters): + sess.run(op) + end = time.time() + mean_us = (end - start) * 1e6 / num_iters + self.report_benchmark( + name=name, + iters=num_iters, + wall_time=mean_us, + extras={"examples_per_sec": num_iters / (end - start)}) + + def benchmarkDefunVsMapFn(self): + """Benchmarks to compare the performance of MapDefun vs tf.map_fn.""" + + @function.Defun(dtypes.int32) + def defun(x): + return array_ops.identity(x) + + def map_fn(x): + return array_ops.identity(x) + + base = math_ops.range(100) + for input_size in [10, 100, 1000, 10000]: + num_iters = 100000 // input_size + map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()]) + map_fn_op = functional_ops.map_fn(map_fn, base) + + self._run( + map_defun_op, + "benchmarkMapDefun_size_%d" % input_size, + num_iters=num_iters) + self._run( + map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters) if __name__ == "__main__": test.main() |