diff options
author | 2018-09-24 20:22:15 -0700 | |
---|---|---|
committer | 2018-09-24 20:26:27 -0700 | |
commit | 4dc77744ff6a6854cf4aa2934eb4501bc22c3465 (patch) | |
tree | 149d889779c64a0ffe5ec5e63e8762df0c1d4650 /tensorflow/python/ops | |
parent | e9cdf9f412a3aea324a4a1655d3bffb87abaff0d (diff) |
Documentation for tf.map_fn in Eager mode.
PiperOrigin-RevId: 214376416
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/functional_ops.py | 40 |
1 files changed, 37 insertions, 3 deletions
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index a4e7c84ae4..119d9522bd 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops.gen_functional_ops import remote_call # pylint: enable=unused-import from tensorflow.python.ops.gen_functional_ops import symbolic_gradient +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -263,7 +264,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, @tf_export("map_fn") -def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, +def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True, swap_memory=False, infer_shape=True, name=None): """map on the list of tensors unpacked from `elems` on dimension 0. @@ -305,6 +306,25 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, instead. + When executing eagerly, map_fn does not execute in parallel even if + `parallel_iterations` is set to a value > 1. You can still get the + performance benefits of running a function in parallel by using the + `tf.contrib.eager.defun` decorator, + + ```python + # Assume the function being used in map_fn is fn. + # To ensure map_fn calls fn in parallel, use the defun decorator. + @tf.contrib.eager.defun + def func(tensor): + return tf.map_fn(fn, tensor) + ``` + + Note that if you use the defun decorator, any non-TensorFlow Python code + that you may have written in your function won't get executed. See + `tf.contrib.eager.defun` for more details. The recommendation would be to + debug without defun but switch to defun to get performance benefits of + running map_fn in parallel. + Args: fn: The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as `elems`. Its output @@ -317,7 +337,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, of Tensors differing from the structure of `elems`, then `dtype` is not optional and must have the same structure as the output of `fn`. parallel_iterations: (optional) The number of iterations allowed to run - in parallel. + in parallel. When graph building, the default value is 10. While executing + eagerly, the default value is set to 1. back_prop: (optional) True enables support for back propagation. swap_memory: (optional) True enables GPU-CPU memory swapping. infer_shape: (optional) False disables tests for consistent output shapes. @@ -363,6 +384,20 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, " SparseTensor(input.indices, map_fn(fn, input.values), " "input.dense_shape)") + in_graph_mode = not context.executing_eagerly() + # Set the default number of parallel_iterations depending on graph/eager mode. + if in_graph_mode and not parallel_iterations: + parallel_iterations = 10 + elif not in_graph_mode and not parallel_iterations: + parallel_iterations = 1 + + if not in_graph_mode and parallel_iterations > 1: + logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no " + "effect when executing eagerly. Consider calling map_fn" + " with tf.contrib.eager.defun to execute fn in " + "parallel.", 1) + parallel_iterations = 1 + input_is_sequence = nest.is_sequence(elems) input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x] def input_pack(x): @@ -381,7 +416,6 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, elems_flat = input_flatten(elems) - in_graph_mode = not context.executing_eagerly() with ops.name_scope(name, "map", elems_flat): # TODO(akshayka): Remove the in_graph_mode check once caching devices are # supported in Eager |