aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-07 15:20:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 15:24:30 -0700
commit8f82f7cd650efbed650ac5756e8e1260283b2371 (patch)
tree631ee18589525c808976043b27152e049c74a686 /tensorflow/contrib/data
parent0fa21fcfa1ff9ced6b2a2f8a293558ad8845e1d5 (diff)
Accept unknown output_shapes for MapDefun op.
PiperOrigin-RevId: 212045006
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py104
-rw-r--r--tensorflow/contrib/data/python/ops/map_defun.py2
3 files changed, 105 insertions, 3 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 34f594f741..b9320e5fef 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -279,7 +279,9 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 091eb5ce37..61567bc8d7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
+
from tensorflow.contrib.data.python.ops import map_defun
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,10 +28,10 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-
class MapDefunTest(test.TestCase):
def testMapDefunSimple(self):
@@ -146,6 +149,105 @@ class MapDefunTest(test.TestCase):
r"indices = 10 is not in \[0, 5\)"):
self.evaluate(map_defun_op)
+ def testMapDefunWithUnspecifiedOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ res = x * 2 + 3
+ return (res, res + 1, res + 2)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems],
+ [dtypes.int32, dtypes.int32, dtypes.int32],
+ [None, (None,), (2,)])
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
+ self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
+ self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
+
+ def testMapDefunWithDifferentOutputShapeEachRun(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ elems = array_ops.placeholder(dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
+ self.assertAllEqual(
+ sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
+
+ def testMapDefunWithWrongOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunWithInvalidInput(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2
+
+ c = constant_op.constant(2)
+ with self.assertRaises(ValueError):
+ # Fails at graph construction time for inputs with known shapes.
+ r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
+ p = array_ops.placeholder(dtypes.int32)
+ r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(r, feed_dict={p: 0})
+
+
+class MapDefunBenchmark(test.Benchmark):
+
+ def _run(self, op, name=None, num_iters=3000):
+ with session.Session() as sess:
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op)
+ start = time.time()
+ for _ in range(num_iters):
+ sess.run(op)
+ end = time.time()
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ name=name,
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmarkDefunVsMapFn(self):
+ """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
+
+ @function.Defun(dtypes.int32)
+ def defun(x):
+ return array_ops.identity(x)
+
+ def map_fn(x):
+ return array_ops.identity(x)
+
+ base = math_ops.range(100)
+ for input_size in [10, 100, 1000, 10000]:
+ num_iters = 100000 // input_size
+ map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
+ map_fn_op = functional_ops.map_fn(map_fn, base)
+
+ self._run(
+ map_defun_op,
+ "benchmarkMapDefun_size_%d" % input_size,
+ num_iters=num_iters)
+ self._run(
+ map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py
index 54d5cd6da0..3d0d0993c9 100644
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ b/tensorflow/contrib/data/python/ops/map_defun.py
@@ -53,6 +53,4 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
elems = [ops.convert_to_tensor(e) for e in elems]
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- if not all(s.is_fully_defined() for s in output_shapes):
- raise ValueError("All fn output shapes must be fully defined.")
return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)