aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-10-04 13:01:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 13:12:57 -0700
commit7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (patch)
tree84087a64563d10c3390f991c6263c7fa2cc65b11 /tensorflow/python/data
parent074ff471fefbcf3bfd49914ad80bd9f9751df363 (diff)
[tf.data] Add a notion of `captured args` to MapDefun
PiperOrigin-RevId: 215788485
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py12
-rw-r--r--tensorflow/python/data/experimental/ops/map_defun.py8
2 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
index 612ee332c4..ae9dedb0ab 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -235,6 +235,18 @@ class MapDefunTest(test_base.DatasetTestBase):
sess.close()
thread.join()
+ def testMapDefunWithCapturedInputs(self):
+ c = constant_op.constant(2)
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return x + c
+
+ x = constant_op.constant([1, 2, 3, 4])
+ map_defun_op = map_defun.map_defun(fn, [x], [dtypes.int32], [()])[0]
+ expected = x + c
+ self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
+
class MapDefunBenchmark(test.Benchmark):
diff --git a/tensorflow/python/data/experimental/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py
index 3d0d0993c9..3ac1158d8b 100644
--- a/tensorflow/python/data/experimental/ops/map_defun.py
+++ b/tensorflow/python/data/experimental/ops/map_defun.py
@@ -47,10 +47,12 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
if not isinstance(elems, list):
raise ValueError("`elems` must be a list of tensors.")
if not isinstance(output_dtypes, list):
- raise ValueError("`output_dtypes` must be a list of tensors.")
+ raise ValueError("`output_dtypes` must be a list of `tf.DType` objects.")
if not isinstance(output_shapes, list):
- raise ValueError("`output_shapes` must be a list of tensors.")
+ raise ValueError("`output_shapes` must be a list of `tf.TensorShape` "
+ "objects.")
elems = [ops.convert_to_tensor(e) for e in elems]
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
+ return gen_dataset_ops.map_defun(elems, fn.captured_inputs, output_dtypes,
+ output_shapes, fn)