author    Debidatta Dwibedi <debidatta@google.com>    2018-09-24 20:22:15 -0700
committer    TensorFlower Gardener <gardener@tensorflow.org>    2018-09-24 20:26:27 -0700
commit    4dc77744ff6a6854cf4aa2934eb4501bc22c3465 (patch)
tree    149d889779c64a0ffe5ec5e63e8762df0c1d4650 /tensorflow/python/ops
parent    e9cdf9f412a3aea324a4a1655d3bffb87abaff0d (diff)
Documentation for tf.map_fn in Eager mode.
PiperOrigin-RevId: 214376416
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/functional_ops.py  40
1 file changed, 37 insertions(+), 3 deletions(-)
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a4e7c84ae4..119d9522bd 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -263,7 +264,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
@tf_export("map_fn")
-def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
+def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
swap_memory=False, infer_shape=True, name=None):
"""map on the list of tensors unpacked from `elems` on dimension 0.
@@ -305,6 +306,25 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
instead.
+ When executing eagerly, map_fn does not execute in parallel even if
+ `parallel_iterations` is set to a value > 1. You can still get the
+ performance benefits of running a function in parallel by using the
+ `tf.contrib.eager.defun` decorator:
+
+ ```python
+ # Assume the function being used in map_fn is fn.
+ # To ensure map_fn calls fn in parallel, use the defun decorator.
+ @tf.contrib.eager.defun
+ def func(tensor):
+   return tf.map_fn(fn, tensor)
+ ```
+
+ Note that if you use the defun decorator, any non-TensorFlow Python code
+ that you may have written in your function won't get executed. See
+ `tf.contrib.eager.defun` for more details. The recommended workflow is to
+ debug without defun, then switch to defun to get the performance benefits
+ of running map_fn in parallel.
+
Args:
fn: The callable to be performed. It accepts one argument, which will
have the same (possibly nested) structure as `elems`. Its output
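For the defun recommendation documented in the hunk above, here is a minimal usage sketch under eager execution. The `fn`, the tensor shape, and the TF 1.x-style `tf.enable_eager_execution()` call are illustrative assumptions, not part of this change.

```python
import tensorflow as tf

tf.enable_eager_execution()  # assumed TF 1.x-era eager setup for this sketch

# Hypothetical per-element function to map over dimension 0.
def fn(x):
  return tf.reduce_sum(x * x)

@tf.contrib.eager.defun
def func(tensor):
  # Inside defun, map_fn is traced into a graph and can run its
  # iterations in parallel.
  return tf.map_fn(fn, tensor)

batch = tf.random_normal([8, 16])
result = func(batch)  # shape [8]: one reduced value per row
```

Without the decorator, the same call runs `fn` eagerly, one row at a time.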
@@ -317,7 +337,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
of Tensors differing from the structure of `elems`, then `dtype` is not
optional and must have the same structure as the output of `fn`.
parallel_iterations: (optional) The number of iterations allowed to run
- in parallel.
+ in parallel. When graph building, the default value is 10. When executing
+ eagerly, the default value is 1.
back_prop: (optional) True enables support for back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
infer_shape: (optional) False disables tests for consistent output shapes.
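To make the revised default concrete, a hypothetical graph-mode snippet; the placeholder shape and `tf.nn.relu` mapper are assumptions for illustration only.

```python
import tensorflow as tf

# Graph building (eager execution not enabled): leaving parallel_iterations
# unset falls back to the previous default of 10.
elems = tf.placeholder(tf.float32, shape=[None, 4])
mapped_default = tf.map_fn(tf.nn.relu, elems)                        # default 10
mapped_serial = tf.map_fn(tf.nn.relu, elems, parallel_iterations=1)  # forced serial
```

Under eager execution the same unset argument resolves to 1 instead.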
@@ -363,6 +384,20 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
" SparseTensor(input.indices, map_fn(fn, input.values), "
"input.dense_shape)")
+ in_graph_mode = not context.executing_eagerly()
+ # Set the default number of parallel_iterations depending on graph/eager mode.
+ if in_graph_mode and not parallel_iterations:
+   parallel_iterations = 10
+ elif not in_graph_mode and not parallel_iterations:
+   parallel_iterations = 1
+
+ if not in_graph_mode and parallel_iterations > 1:
+   logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
+       "effect when executing eagerly. Consider calling map_fn"
+       " with tf.contrib.eager.defun to execute fn in "
+       "parallel.", 1)
+   parallel_iterations = 1
input_is_sequence = nest.is_sequence(elems)
input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
def input_pack(x):
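A hypothetical eager-mode call that would exercise the warning path added in the hunk above; the input tensor and the lambda are made up for illustration.

```python
import tensorflow as tf

tf.enable_eager_execution()

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
# Requesting parallel_iterations > 1 while executing eagerly logs a one-time
# warning, and the value is clamped back to 1 before the loop runs.
y = tf.map_fn(lambda row: tf.reduce_max(row), x, parallel_iterations=4)
print(y)  # [2. 4.]
```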
@@ -381,7 +416,6 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems)
- in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager