aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py104
1 files changed, 103 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 091eb5ce37..61567bc8d7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
+
from tensorflow.contrib.data.python.ops import map_defun
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,10 +28,10 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-
class MapDefunTest(test.TestCase):
def testMapDefunSimple(self):
@@ -146,6 +149,105 @@ class MapDefunTest(test.TestCase):
r"indices = 10 is not in \[0, 5\)"):
self.evaluate(map_defun_op)
+ def testMapDefunWithUnspecifiedOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ res = x * 2 + 3
+ return (res, res + 1, res + 2)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems],
+ [dtypes.int32, dtypes.int32, dtypes.int32],
+ [None, (None,), (2,)])
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
+ self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
+ self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
+
+ def testMapDefunWithDifferentOutputShapeEachRun(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ elems = array_ops.placeholder(dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
+ self.assertAllEqual(
+ sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
+
+ def testMapDefunWithWrongOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunWithInvalidInput(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2
+
+ c = constant_op.constant(2)
+ with self.assertRaises(ValueError):
+ # Fails at graph construction time for inputs with known shapes.
+ r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
+ p = array_ops.placeholder(dtypes.int32)
+ r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(r, feed_dict={p: 0})
+
+
+class MapDefunBenchmark(test.Benchmark):
+
+ def _run(self, op, name=None, num_iters=3000):
+ with session.Session() as sess:
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op)
+ start = time.time()
+ for _ in range(num_iters):
+ sess.run(op)
+ end = time.time()
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ name=name,
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmarkDefunVsMapFn(self):
+ """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
+
+ @function.Defun(dtypes.int32)
+ def defun(x):
+ return array_ops.identity(x)
+
+ def map_fn(x):
+ return array_ops.identity(x)
+
+ base = math_ops.range(100)
+ for input_size in [10, 100, 1000, 10000]:
+ num_iters = 100000 // input_size
+ map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
+ map_fn_op = functional_ops.map_fn(map_fn, base)
+
+ self._run(
+ map_defun_op,
+ "benchmarkMapDefun_size_%d" % input_size,
+ num_iters=num_iters)
+ self._run(
+ map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
if __name__ == "__main__":
test.main()