# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""for_loop and pfor ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.util import nest


def for_loop(loop_fn, loop_fn_dtypes, iters):
  """Calls `loop_fn` for each iteration index and stacks the results.

  `loop_fn` is driven by a while loop with the iteration numbers
  `0, 1, ..., iters - 1`. The outputs produced by the individual iterations
  are stacked together.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object holding
      the iteration number and returns a possibly nested structure of tensor
      objects. The shapes of these outputs should not depend on the input.
    loop_fn_dtypes: dtypes of the outputs of `loop_fn`, given in a structure
      that flattens to one dtype per flattened output.
    iters: Number of iterations for which to run `loop_fn`.

  Returns:
    A nested structure of stacked output tensor objects, packed with the
    structure of `loop_fn_dtypes`. Positions for which `loop_fn` returned
    None hold None.

  Raises:
    ValueError: If the number of flattened outputs of `loop_fn` differs from
      the number of flattened entries in `loop_fn_dtypes`.
  """
  expected_dtypes = nest.flatten(loop_fn_dtypes)
  # Records which flattened outputs of `loop_fn` were None. It is filled in
  # from inside the loop body and read once the while loop has been built.
  is_none_list = []

  def _loop_cond(i, *unused_arrays):
    """Keeps looping while the iteration counter is below `iters`."""
    return i < iters

  def _loop_body(i, *arrays):
    """Runs one iteration and records its outputs in the TensorArrays."""
    flat_outputs = nest.flatten(loop_fn(i))
    num_expected = len(expected_dtypes)
    num_actual = len(flat_outputs)
    if num_actual != num_expected:
      raise ValueError(
          "Number of expected outputs, %d, does not match the number of "
          "actual outputs, %d, from loop_fn" % (num_expected, num_actual))
    # Replace the contents in place so the enclosing function sees them.
    is_none_list[:] = [value is None for value in flat_outputs]
    # TODO(agarwal): support returning Operation objects from loop_fn.
    # Each result gets a leading axis of size 1 before being written, so that
    # `concat()` on the array yields the per-iteration results stacked.
    # None outputs leave their array untouched.
    next_arrays = [
        array if value is None
        else array.write(i, array_ops.expand_dims(value, 0))
        for value, array in zip(flat_outputs, arrays)
    ]
    return (i + 1,) + tuple(next_arrays)

  # Loop variables: the iteration counter followed by one TensorArray per
  # expected output.
  initial_vars = [0] + [
      tensor_array_ops.TensorArray(dtype, iters) for dtype in expected_dtypes
  ]
  final_vars = control_flow_ops.while_loop(
      _loop_cond, _loop_body, initial_vars)
  final_arrays = final_vars[1:]  # Drop the iteration counter.

  # TODO(rachelim): enable this for sparse tensors
  stacked = []
  for array, is_none in zip(final_arrays, is_none_list):
    stacked.append(None if is_none else array.concat())
  return nest.pack_sequence_as(loop_fn_dtypes, stacked)


def pfor(loop_fn, iters):
  """Computes the stacked outputs of `iters` runs of `loop_fn` without a loop.

  `pfor` offers functionality similar to `for_loop`: it is equivalent to
  running `loop_fn` `iters` times, with inputs from 0 to `iters - 1`, and
  stacking the corresponding outputs of each iteration. The implementation,
  however, does not use a tf.while_loop. Instead it adds new operations to the
  graph that collectively compute the same value as what running `loop_fn` in
  a loop would compute.

  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations.
      For example, a future iteration should not depend on a value or
      side-effect of a previous iteration.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency or ordering of the iterations. A limited set of such
      stateful kernels is supported though (like RandomFoo, Variable
      operations like reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - `loop_fn` cannot currently contain control flow operations like
      tf.while_loop or tf.cond.
    - `loop_fn` should return a nested structure of Tensors or Operations.
      However, if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to `loop_fn`.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object holding
      the iteration number and returns a possibly nested structure of Tensor
      or Operation objects.
    iters: Number of iterations for which to run `loop_fn`.

  Returns:
    A nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.
  """
  # Snapshot the graph's operations so that the ones added while tracing
  # `loop_fn` can be identified afterwards.
  ops_before = set(ops.get_default_graph().get_operations())
  with ops.name_scope("loop_body"):
    # `loop_fn` is traced once, with a scalar int32 placeholder standing in
    # for the iteration number.
    loop_var = array_ops.placeholder(dtypes.int32, shape=[])
    loop_fn_outputs = loop_fn(loop_var)
  ops_after = set(ops.get_default_graph().get_operations())
  loop_body_ops = ops_after.difference(ops_before)

  iters = ops.convert_to_tensor(iters)
  with ops.name_scope("pfor"):
    converter = PFor(loop_var, iters, loop_body_ops)
    converted = [
        converter.convert(flat_output)
        for flat_output in nest.flatten(loop_fn_outputs)
    ]
  return nest.pack_sequence_as(loop_fn_outputs, converted)