# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Functional operations.

See the [Higher Order
Functions](https://tensorflow.org/api_guides/python/functional_ops) guide.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
# pylint: disable=unused-import
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export("foldl")
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldl", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
      i = constant_op.constant(1)
    else:
      a = initializer
      i = constant_op.constant(0)

    def compute(i, a):
      elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a = fn(a, elem_i)
      return [i + 1, a]

    _, r_a = control_flow_ops.while_loop(
        lambda i, a: i < n, compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a


@tf_export("foldr")
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn. If `initializer` is None, `elems` must contain at least
  one element, and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldr", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally and not
      # issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      i = n - 1
      a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
    else:
      i = n
      a = initializer

    def compute(i, a):
      i -= 1
      elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a_out = fn(a, elem)
      return [i, a_out]

    _, r_a = control_flow_ops.while_loop(
        lambda i, a: i > 0, compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a


@tf_export("map_fn")
def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
           swap_memory=False, infer_shape=True, name=None):
  """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.
  That is, if `elems` is `(t1, [t2, t3, [t4, t5]])`, then an appropriate
  signature for `fn` is: `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For
  example, `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In
  this case, the `dtype` parameter is not optional: `dtype` must be a type or
  (possibly nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
      input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  When executing eagerly, map_fn does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.contrib.eager.defun` decorator,

  ```python
  # Assume the function being used in map_fn is fn.
  # To ensure map_fn calls fn in parallel, use the defun decorator.
  @tf.contrib.eager.defun
  def func(tensor):
    return tf.map_fn(fn, tensor)
  ```

  Note that if you use the defun decorator, any non-TensorFlow Python code
  that you may have written in your function won't get executed. See
  `tf.contrib.eager.defun` for more details. The recommendation would be to
  debug without defun but switch to defun to get performance benefits of
  running map_fn in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will
      have the same (possibly nested) structure as `elems`.  Its output
      must have the same structure as `dtype` if one is provided, otherwise
      it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel. When graph building, the default value is 10. While
      executing eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.
  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  if isinstance(elems, sparse_tensor.SparseTensor):
    raise TypeError(
        "To perform a map on the values of a sparse tensor use either "
        " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
        " SparseTensor(input.indices, map_fn(fn, input.values), "
        "input.dense_shape)")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager
  # mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1

  if not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
                        "effect when executing eagerly. Consider calling map_fn"
                        " with tf.contrib.eager.defun to execute fn in "
                        "parallel.", 1)
    parallel_iterations = 1

  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if dtype is None:
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(dtype)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
    def output_pack(x):
      return (nest.pack_sequence_as(dtype, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]

    dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
    dtype_flat = output_flatten(dtype)

    # Convert elems to tensor array. n may be known statically.
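    # `map_fn` unpacks along dimension 0, so each entry of `elems_flat` must
    # have rank at least 1; scalars are rejected below with a clear error.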
    static_shape = elems_flat[0].shape
    if static_shape.ndims is not None and static_shape.ndims < 1:
      if len(elems_flat) == 1:
        raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
      else:
        raise ValueError(
            "elements in elems must be 1+ dimensional Tensors, not scalars"
        )
    n = static_shape[0].value or array_ops.shape(elems_flat[0])[0]

    # TensorArrays are always flat
    elems_ta = [
        tensor_array_ops.TensorArray(dtype=elem.dtype, size=n,
                                     dynamic_size=False,
                                     infer_shape=True)
        for elem in elems_flat]
    # Unpack elements
    elems_ta = [
        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]

    i = constant_op.constant(0)

    accs_ta = [
        tensor_array_ops.TensorArray(dtype=dt, size=n,
                                     dynamic_size=False,
                                     infer_shape=infer_shape)
        for dt in dtype_flat]

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
      packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_fn_values = fn(packed_values)
      nest.assert_same_structure(dtype or elems, packed_fn_values)
      flat_fn_values = output_flatten(packed_fn_values)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)]
      return (i + 1, tas)

    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n, compute, (i, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)
    results_flat = [r.stack() for r in r_a]

    n_static = elems_flat[0].get_shape().with_rank_at_least(1)[0]
    for elem in elems_flat[1:]:
      n_static.merge_with(elem.get_shape().with_rank_at_least(1)[0])
    for r in results_flat:
      r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
          r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return output_pack(results_flat)


@tf_export("scan")
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
         swap_memory=False, infer_shape=True, reverse=False, name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems` on dimension 0. The callable fn takes two
  tensors as arguments. The first argument is the accumulated value computed
  from the preceding invocation of fn. If `initializer` is None, `elems` must
  contain at least one element, and its first element is used as the
  initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If reverse=True, it's `fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and accumulator.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.
  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed.  It accepts two arguments.  The first
      will have the same structure as `initializer` if one is provided,
      otherwise it will have the same structure as `elems`.  The second
      will have the same (possibly nested) structure as `elems`.  Its output
      must have the same structure as `initializer` if one is provided,
      otherwise it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first
      to last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.
  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if initializer is None:
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(initializer)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
    def output_pack(x):
      return (nest.pack_sequence_as(initializer, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "scan", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]

    # n may be known statically.
    n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]

    # TensorArrays are always flat
    elems_ta = [
        tensor_array_ops.TensorArray(dtype=elem.dtype, size=n,
                                     dynamic_size=False,
                                     infer_shape=True)
        for elem in elems_flat]
    # Unpack elements
    elems_ta = [
        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]

    if initializer is None:
      a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
      i = constant_op.constant(1)
    else:
      initializer_flat = output_flatten(initializer)
      a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
      i = constant_op.constant(0)

    # Create a tensor array to store the intermediate values.
    accs_ta = [
        tensor_array_ops.TensorArray(
            dtype=init.dtype, size=n,
            element_shape=init.shape if infer_shape else None,
            dynamic_size=False,
            infer_shape=infer_shape) for init in a_flat]

    if initializer is None:
      accs_ta = [acc_ta.write(n - 1 if reverse else 0, a)
                 for (acc_ta, a) in zip(accs_ta, a_flat)]

    def compute(i, a_flat, tas):
      """The loop body of scan.

      Args:
        i: the loop counter.
        a_flat: the accumulator value(s), flattened.
        tas: the output accumulator TensorArray(s), flattened.
      Returns:
        [i + 1, a_flat, tas]: the updated counter + new accumulator values +
          updated TensorArrays

      Raises:
        TypeError: if initializer and fn() output structure do not match
        ValueError: if initializer and fn() output lengths do not match
      """
      packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_a = output_pack(a_flat)
      a_out = fn(packed_a, packed_elems)
      nest.assert_same_structure(
          elems if initializer is None else initializer, a_out)
      flat_a_out = output_flatten(a_out)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
      if reverse:
        next_i = i - 1
      else:
        next_i = i + 1
      return (next_i, flat_a_out, tas)

    if reverse:
      initial_i = n - 1 - i
      condition = lambda i, _1, _2: i >= 0
    else:
      initial_i = i
      condition = lambda i, _1, _2: i < n

    _, _, r_a = control_flow_ops.while_loop(
        condition, compute, (initial_i, a_flat, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop, swap_memory=swap_memory,
        maximum_iterations=n)

    results_flat = [r.stack() for r in r_a]

    n_static = elems_flat[0].get_shape().with_rank_at_least(1)[0]
    for elem in elems_flat[1:]:
      n_static.merge_with(elem.get_shape().with_rank_at_least(1)[0])
    for r in results_flat:
      r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
          r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return output_pack(results_flat)


# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is
      a numerical value, non-zero means True and zero means False; if the
      scalar is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs) or
    else_branch(inputs).
  """
  # pylint: disable=protected-access
  return gen_functional_ops._if(
      cond,
      inputs, [_.type for _ in then_branch.definition.signature.output_arg],
      then_branch,
      else_branch,
      name=name)


def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for.  The function 'f'
      must be a numerical function which takes N inputs and produces M
      outputs.  Its gradient function 'g' takes N + M inputs and produces
      N outputs.

      I.e. if we have
        (y1, y2, ..., yM) = f(x1, x2, ..., xN),
      then, g is
        (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN,
                                          dL/dy1, dL/dy2, ..., dL/dyM),
      where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the loss
      function). dL/dxi is the partial derivative of L with respect to xi.

    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhfiengc): Needs some math expert to say the comment above better.
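  # The gradient of `f` produces one output per input of `f`, so the `Tout`
  # passed to `symbolic_gradient` mirrors the input types of `f`.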
  tlist = [_.type for _ in f.definition.signature.input_arg]
  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)


# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects. A list of input tensors whose types
      are T.
    cond: A function that takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean type, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical value,
      non-zero means True and zero means False; if the scalar is a string,
      non-empty means True and empty means False. If the tensor is not a
      scalar, non-emptiness means True and emptiness means False.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a
      host memory tensor.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  ret = gen_functional_ops._while(input_, cond, body, name=name)
  if hostmem:
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret


# b/36459430
#
# Ideally, we do not need to rewrite the For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a for loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a for loop, we need to rewrite it using a While op.
#
# It should be possible and probably better to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While."""
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over the range(n) and use iter * delta + start as the real
  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
  # 100).
  d = math_ops.abs(delta)
  # XLA on TPUs doesn't support integer division
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input
  # list of the WhileBody function. WhileCond does not call forbody, and so
  # does not depend on any of forbody's extra_args. Since WhileCond and
  # WhileBody must have identical inputs, we have to augment the cond
  # signature to take the same types as the carried loop variables.
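  # The first four loop variables are (i, n, start, delta), all int32; they
  # are followed by forbody's remaining declared inputs (forbody's first
  # declared input is the loop index, which WhileBody computes itself).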
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]
  cond_sig = body_sig + [t.dtype for t in forbody.captured_inputs]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*cond_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    extra_args = tuple(function.get_extra_args())
    return (i + 1, n, start, delta) + tuple(for_result) + extra_args

  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs + WhileBody.captured_inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Slice off the loop-carried captured inputs.
  return list(results[4:len(results) - len(WhileBody.captured_inputs)])


def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects. A list of input tensors whose types
      are T.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a host
      memory tensor. In other words, the (i+1)-th argument of the body
      function expects a host memory tensor.
    rewrite_with_while: If True, use a While op to implement the For loop.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
  if body.captured_inputs:
    wrapper_name = "%s_BodyWrapper" % body.name

    @function.Defun(*body.declared_input_types, func_name=wrapper_name)
    def BodyWrapper(*args):
      """A wrapper for body that handles loop-carried captured inputs."""
      body_result = body(*args)
      extra_args = tuple(function.get_extra_args())
      # Nullary functions return an Operation. Normal functions can't do this
      # because their return values are converted to Tensors.
      if isinstance(body_result, ops.Operation):
        return extra_args
      # Unary functions return a single Tensor value.
      elif not isinstance(body_result, tuple):
        return (body_result,) + extra_args
      # N-ary functions return a tuple of Tensors.
      else:
        return body_result + extra_args

    inputs += BodyWrapper.captured_inputs
    ret = gen_functional_ops._for(
        start, limit, delta, inputs, BodyWrapper, name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(BodyWrapper.captured_inputs)]
  else:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  if hostmem:
    num_for_params = 3  # start/limit/delta
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + i for i in hostmem])
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
# pylint: enable=invalid-name,protected-access


def partitioned_call(args, f, tout=None, executing_eagerly=None):
  """Executes a function while respecting device annotations.

  Currently, only those functions that execute within the same address space
  can be executed.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: a list containing the output dtypes enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled,
    or the `Operation` if not.
  """

  if tout is None:
    tout = tuple(x.type for x in f.definition.signature.output_arg)

  if executing_eagerly is None:
    executing_eagerly = context.executing_eagerly()

  if executing_eagerly or len(tout):
    if f.stateful_ops:
      outputs = gen_functional_ops.stateful_partitioned_call(
          args=args, Tout=tout, f=f)
    else:
      outputs = gen_functional_ops.partitioned_call(args=args, Tout=tout, f=f)
    return outputs if outputs else None

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.internal_convert_to_tensor(x) for x in args]
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]))
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout))
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=f.name))

  graph = ops.get_default_graph()
  f.add_to_graph(graph)
  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
  op = graph.create_op(
      op_name,
      args,
      tout,
      compute_shapes=False,
      name="PartitionedFunctionCall",
      attrs={"Tin": tin_attr, "Tout": tout_attr, "f": func_attr})
  outputs = op.outputs
  return outputs if outputs else op
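

# Example usage of `If` (a minimal sketch, kept in a comment so nothing runs
# at import time; `cond_scalar` and `x` below are illustrative placeholders,
# not part of this module):
#
#   from tensorflow.python.framework import dtypes
#   from tensorflow.python.framework import function
#   from tensorflow.python.ops import functional_ops
#
#   @function.Defun(dtypes.float32)
#   def TimesTwo(t):
#     return t * 2.0
#
#   @function.Defun(dtypes.float32)
#   def PlusOne(t):
#     return t + 1.0
#
#   # Returns TimesTwo(x) when `cond_scalar` is non-zero, else PlusOne(x).
#   y = functional_ops.If(cond_scalar, [x], TimesTwo, PlusOne)[0]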