Diffstat (limited to 'tensorflow/python/ops/parallel_for/control_flow_ops.py')
-rw-r--r--  tensorflow/python/ops/parallel_for/control_flow_ops.py  123
1 file changed, 123 insertions, 0 deletions
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
new file mode 100644
index 0000000000..ccf2eb8214
--- /dev/null
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""for_loop and pfor ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops.parallel_for.pfor import PFor
+from tensorflow.python.util import nest
+
+
+def for_loop(loop_fn, loop_fn_dtypes, iters):
+  """Runs `loop_fn` `iters` times and stacks the outputs.
+
+
+  Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
+  stacks corresponding outputs of the different runs.
+
+  Args:
+    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
+      the iteration number, and returns a possibly nested structure of tensor
+      objects. The shape of these outputs should not depend on the input.
+    loop_fn_dtypes: dtypes for the outputs of loop_fn.
+    iters: Number of iterations for which to run loop_fn.
+
+  Returns:
+    A nested structure of stacked output tensor objects with the same nested
+    structure as the output of `loop_fn`.
+  """
+
+  flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
+
+  def while_body(i, *ta_list):
+    """Body of while loop."""
+    fn_output = nest.flatten(loop_fn(i))
+    if len(fn_output) != len(flat_loop_fn_dtypes):
+      raise ValueError(
+          "Number of expected outputs, %d, does not match the number of "
+          "actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
+                                                len(fn_output)))
+    outputs = []
+    for out, ta in zip(fn_output, ta_list):
+      # TODO(agarwal): support returning Operation objects from loop_fn.
+      assert isinstance(out, ops.Tensor)
+      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
+    return tuple([i + 1] + outputs)
+
+  ta_list = control_flow_ops.while_loop(
+      lambda i, *ta: i < iters, while_body, [0] + [
+          tensor_array_ops.TensorArray(dtype, iters)
+          for dtype in flat_loop_fn_dtypes
+      ])[1:]
+
+  # TODO(rachelim): enable this for sparse tensors
+  return nest.pack_sequence_as(loop_fn_dtypes, [ta.concat() for ta in ta_list])
+
+
+def pfor(loop_fn, iters):
+  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.
+
+  `pfor` has functionality similar to `for_loop`, i.e. it runs `loop_fn`
+  `iters` times, with inputs from 0 to `iters - 1`, and stacks the
+  corresponding outputs of each iteration. However, the implementation does
+  not use a tf.while_loop. Instead it adds new operations to the graph that
+  collectively compute the same value as running `loop_fn` in a loop would.
+
+
+  This is an experimental feature and currently has a lot of limitations:
+  - There should be no data dependency between the different iterations. For
+    example, a future iteration should not depend on a value or side-effect of
+    a previous iteration.
+  - Stateful kernels are mostly not supported since these often imply a data
+    dependency or an ordering of the iterations. However, a limited set of
+    such stateful kernels is supported (like RandomFoo, Variable operations
+    like reads, etc).
+  - Conversion works only on a limited set of kernels for which a converter
+    has been registered.
+  - `loop_fn` currently cannot contain control flow operations like
+    tf.while_loop or tf.cond.
+  - `loop_fn` should return a nested structure of Tensors or Operations.
+    However, if an Operation is returned, it should have zero outputs.
+  - The shape and dtype of `loop_fn` outputs should not depend on the input
+    to `loop_fn`.
+
+  Args:
+    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
+      the iteration number, and returns a possibly nested structure of Tensor or
+      Operation objects.
+    iters: Number of iterations for which to run loop_fn.
+
+  Returns:
+    A nested structure of stacked tensor objects with the same nested
+    structure as the output of `loop_fn`.
+  """
+  existing_ops = set(ops.get_default_graph().get_operations())
+  with ops.name_scope("loop_body"):
+    loop_var = array_ops.placeholder(dtypes.int32, shape=[])
+    loop_fn_outputs = loop_fn(loop_var)
+  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
+  iters = ops.convert_to_tensor(iters)
+  with ops.name_scope("pfor"):
+    converter = PFor(loop_var, iters, new_ops)
+    outputs = []
+    for loop_fn_output in nest.flatten(loop_fn_outputs):
+      outputs.append(converter.convert(loop_fn_output))
+    return nest.pack_sequence_as(loop_fn_outputs, outputs)
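
As a quick illustration of the API added in this file, here is a minimal usage sketch of `for_loop`. It assumes a TF 1.x graph-mode session and that the module is importable from the path added above; the computation inside `loop_fn` is purely illustrative.

import tensorflow as tf
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_ops

x = tf.constant([1., 2., 3., 4., 5.])

def loop_fn(i):
  # `i` is the int32 scalar Tensor holding the iteration number.
  return tf.gather(x, i) * tf.cast(i, tf.float32)

# Runs loop_fn for i in [0, 5) and stacks the scalar outputs into shape [5].
y = pfor_ops.for_loop(loop_fn, tf.float32, iters=5)

with tf.Session() as sess:
  print(sess.run(y))  # [0., 2., 6., 12., 20.]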
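
A similar sketch for `pfor`, which builds a vectorized graph rather than a tf.while_loop. Same assumptions as above, plus the assumption that the ops created inside `loop_fn` (Gather and Sin here) have converters registered in pfor.py.

import tensorflow as tf
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_ops

x = tf.random_normal([5, 3])

def loop_fn(i):
  # Each "iteration" reads row i and applies an elementwise op to it.
  return tf.sin(tf.gather(x, i))

# No tf.while_loop is built; the graph computes the same value as stacking
# loop_fn(0), ..., loop_fn(4), i.e. a tensor equal to tf.sin(x).
y = pfor_ops.pfor(loop_fn, iters=5)

with tf.Session() as sess:
  print(sess.run(y).shape)  # (5, 3)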