diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-04 13:01:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:12:57 -0700 |
commit | 7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (patch) | |
tree | 84087a64563d10c3390f991c6263c7fa2cc65b11 /tensorflow/python/data | |
parent | 074ff471fefbcf3bfd49914ad80bd9f9751df363 (diff) |
[tf.data] Add a notion of `captured args` to MapDefun
PiperOrigin-RevId: 215788485
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py | 12 | ||||
-rw-r--r-- | tensorflow/python/data/experimental/ops/map_defun.py | 8 |
2 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py index 612ee332c4..ae9dedb0ab 100644 --- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py @@ -235,6 +235,18 @@ class MapDefunTest(test_base.DatasetTestBase): sess.close() thread.join() + def testMapDefunWithCapturedInputs(self): + c = constant_op.constant(2) + + @function.Defun(dtypes.int32) + def fn(x): + return x + c + + x = constant_op.constant([1, 2, 3, 4]) + map_defun_op = map_defun.map_defun(fn, [x], [dtypes.int32], [()])[0] + expected = x + c + self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op)) + class MapDefunBenchmark(test.Benchmark): diff --git a/tensorflow/python/data/experimental/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py index 3d0d0993c9..3ac1158d8b 100644 --- a/tensorflow/python/data/experimental/ops/map_defun.py +++ b/tensorflow/python/data/experimental/ops/map_defun.py @@ -47,10 +47,12 @@ def map_defun(fn, elems, output_dtypes, output_shapes): if not isinstance(elems, list): raise ValueError("`elems` must be a list of tensors.") if not isinstance(output_dtypes, list): - raise ValueError("`output_dtypes` must be a list of tensors.") + raise ValueError("`output_dtypes` must be a list of `tf.DType` objects.") if not isinstance(output_shapes, list): - raise ValueError("`output_shapes` must be a list of tensors.") + raise ValueError("`output_shapes` must be a list of `tf.TensorShape` " + "objects.") elems = [ops.convert_to_tensor(e) for e in elems] output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] - return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) + return gen_dataset_ops.map_defun(elems, fn.captured_inputs, output_dtypes, + output_shapes, fn) |